Repository: blackfeather-wang/GFNet-Pytorch
Branch: master
Commit: 8a6775c8267c
Files: 103
Total size: 610.4 KB
Directory structure:
gitextract_6lcfcbge/
├── .gitignore
├── LICENSE
├── README.md
├── configs.py
├── inference.py
├── models/
│ ├── __init__.py
│ ├── activations/
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── activations_autofn.py
│ │ ├── activations_jit.py
│ │ └── config.py
│ ├── config.py
│ ├── conv2d_layers.py
│ ├── densenet.py
│ ├── efficientnet_builder.py
│ ├── gen_efficientnet.py
│ ├── helpers.py
│ ├── mobilenetv3.py
│ ├── model_factory.py
│ ├── resnet.py
│ └── version.py
├── network.py
├── pycls/
│ ├── __init__.py
│ ├── cfgs/
│ │ ├── RegNetY-1.6GF_dds_8gpu.yaml
│ │ ├── RegNetY-600MF_dds_8gpu.yaml
│ │ └── RegNetY-800MF_dds_8gpu.yaml
│ ├── core/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── losses.py
│ │ ├── model_builder.py
│ │ ├── old_config.py
│ │ └── optimizer.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── cifar10.py
│ │ ├── imagenet.py
│ │ ├── loader.py
│ │ ├── paths.py
│ │ └── transforms.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── anynet.py
│ │ ├── effnet.py
│ │ ├── regnet.py
│ │ └── resnet.py
│ └── utils/
│ ├── __init__.py
│ ├── benchmark.py
│ ├── checkpoint.py
│ ├── distributed.py
│ ├── error_handler.py
│ ├── io.py
│ ├── logging.py
│ ├── lr_policy.py
│ ├── meters.py
│ ├── metrics.py
│ ├── multiprocessing.py
│ ├── net.py
│ ├── plotting.py
│ └── timer.py
├── simplejson/
│ ├── __init__.py
│ ├── _speedups.c
│ ├── compat.py
│ ├── decoder.py
│ ├── encoder.py
│ ├── errors.py
│ ├── ordered_dict.py
│ ├── raw_json.py
│ ├── scanner.py
│ ├── tests/
│ │ ├── __init__.py
│ │ ├── test_bigint_as_string.py
│ │ ├── test_bitsize_int_as_string.py
│ │ ├── test_check_circular.py
│ │ ├── test_decimal.py
│ │ ├── test_decode.py
│ │ ├── test_default.py
│ │ ├── test_dump.py
│ │ ├── test_encode_basestring_ascii.py
│ │ ├── test_encode_for_html.py
│ │ ├── test_errors.py
│ │ ├── test_fail.py
│ │ ├── test_float.py
│ │ ├── test_for_json.py
│ │ ├── test_indent.py
│ │ ├── test_item_sort_key.py
│ │ ├── test_iterable.py
│ │ ├── test_namedtuple.py
│ │ ├── test_pass1.py
│ │ ├── test_pass2.py
│ │ ├── test_pass3.py
│ │ ├── test_raw_json.py
│ │ ├── test_recursion.py
│ │ ├── test_scanstring.py
│ │ ├── test_separators.py
│ │ ├── test_speedups.py
│ │ ├── test_str_subclass.py
│ │ ├── test_subclass.py
│ │ ├── test_tool.py
│ │ ├── test_tuple.py
│ │ └── test_unicode.py
│ └── tool.py
├── train.py
├── utils.py
└── yacs/
├── __init__.py
├── config.py
└── tests.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
models/.DS_Store
figures/.DS_Store
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Glance-and-Focus Networks (PyTorch)
This repo contains the official code and pre-trained models for the Glance-and-Focus Network (GFNet).
- (NeurIPS 2020) [Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification](https://arxiv.org/abs/2010.05300)
- (T-PAMI) [Glance and Focus Networks for Dynamic Visual Recognition](https://arxiv.org/abs/2201.03014)
**Update on 2020/12/28: Release Training Code.**
**Update on 2020/10/08: Release Pre-trained Models and the Inference Code on ImageNet.**
## Introduction
Inspired by the fact that not all regions in an image are task-relevant, we propose a novel framework that performs efficient image classification by processing a sequence of relatively small inputs, which are strategically cropped from the original image.
Experiments on ImageNet show that our method consistently improves the computational efficiency of a wide variety of deep models.
For example, it further reduces the average latency of the highly efficient MobileNet-V3 on an iPhone XS Max by 20% without sacrificing accuracy.
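At inference time, the model first takes a glance at a downscaled copy of the image and then iteratively focuses on patches proposed by a policy network, exiting as soon as the prediction is confident enough. Below is a simplified sketch of this loop (it follows the structure of `generate_logits` in `inference.py` with condensed signatures; `get_prime` and `get_patch` are the helpers used by `inference.py`, and in the released code the confidence thresholds from `dynamic_threshold` are applied offline in `dynamic_evaluate`):
```
import torch
import torch.nn.functional as F

def glance_and_focus(x, model_prime, model, fc, policy, memory,
                     prime_size, patch_size, T, thresholds):
    with torch.no_grad():
        # Glance step: classify a downscaled copy of the whole image.
        output, state = model_prime(get_prime(x, prime_size))
        logits = fc(output, restart=True)
        for t in range(1, T):
            # Early exit once the softmax confidence reaches the stage threshold.
            if F.softmax(logits, 1).max() >= thresholds[t - 1]:
                break
            # Focus step: crop the patch proposed by the patch proposal network.
            action = policy.act(state, memory, restart_batch=(t == 1))
            output, state = model(get_patch(x, action, patch_size))
            logits = fc(output, restart=False)
    return logits
```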
## Citation
```
@inproceedings{NeurIPS2020_7866,
title={Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification},
author={Wang, Yulin and Lv, Kangchen and Huang, Rui and Song, Shiji and Yang, Le and Huang, Gao},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2020},
}
@article{huang2023glance,
title={Glance and Focus Networks for Dynamic Visual Recognition},
author={Huang, Gao and Wang, Yulin and Lv, Kangchen and Jiang, Haojun and Huang, Wenhui and Qi, Pengfei and Song, Shiji},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2023},
volume={45},
number={4},
pages={4605-4621},
doi={10.1109/TPAMI.2022.3196959}
}
```
## Results
- Top-1 accuracy on ImageNet vs. Multiply-Adds
- Top-1 accuracy on ImageNet vs. Inference Latency (ms) on an iPhone XS Max
- Visualization
## Pre-trained Models
|Backbone CNNs|Patch Size|T|Links|
|-----|------|-----|-----|
|ResNet-50| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4c55dd9472b4416cbdc9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Iun8o4o7cQL-7vSwKyNfefOgwb9-o9kD/view?usp=sharing)|
|ResNet-50| 128x128| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1cbed71346e54a129771/?dl=1) / [Google Drive](https://drive.google.com/file/d/1cEj0dXO7BfzQNd5fcYZOQekoAe3_DPia/view?usp=sharing)|
|DenseNet-121| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c75c77d2f2054872ac20/?dl=1) / [Google Drive](https://drive.google.com/file/d/1UflIM29Npas0rTQSxPqwAT6zHbFkQq6R/view?usp=sharing)|
|DenseNet-169| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/83fef24e667b4dccace1/?dl=1) / [Google Drive](https://drive.google.com/file/d/1pBo22i6VsWJWtw2xJw1_bTSMO3HgJNDL/view?usp=sharing)|
|DenseNet-201| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/85a57b82e592470892e0/?dl=1) / [Google Drive](https://drive.google.com/file/d/1sETDr7dP5Q525fRMTIl2jlt9Eg2qh2dx/view?usp=sharing)|
|RegNet-Y-600MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/2638f038d3b1465da59e/?dl=1) / [Google Drive](https://drive.google.com/file/d/1FCR14wUiNrIXb81cU1pDPcy4bPSRXrqe/view?usp=sharing)|
|RegNet-Y-800MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/686e411e72894b789dde/?dl=1) / [Google Drive](https://drive.google.com/file/d/1On39MwbJY5Zagz7gNtKBFwMWhhHskfZq/view?usp=sharing)|
|RegNet-Y-1.6GF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/90116ad21ee74843b0ef/?dl=1) / [Google Drive](https://drive.google.com/file/d/1rMe0LU8m4BF3udII71JT0VLPBE2G-eCJ/view?usp=sharing)|
|MobileNet-V3-Large (1.00)| 96x96| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4a4e8486b83b4dbeb06c/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Dw16jPlw2hR8EaWbd_1Ujd6Df9n1gHgj/view?usp=sharing)|
|MobileNet-V3-Large (1.00)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/ab0f6fc3997d4771a4c9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Ud_olyer-YgAb667YUKs38C2O1Yb6Nom/view?usp=sharing)|
|MobileNet-V3-Large (1.25)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b2052c3af7734f688bc7/?dl=1) / [Google Drive](https://drive.google.com/file/d/14zj1Ci0i4nYceu-f2ZckFMjRmYDGtJpl/view?usp=sharing)|
|EfficientNet-B2| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1a490deecd34470580da/?dl=1) / [Google Drive](https://drive.google.com/file/d/1LBBPrrYZzKKqCnoZH1kPfQQZ5ixkEmjz/view?usp=sharing)|
|EfficientNet-B3| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/d5182a2257bb481ea622/?dl=1) / [Google Drive](https://drive.google.com/file/d/1fdxwimcuQAXBOsbOdGw8Ee43PgeHHZTA/view?usp=sharing)|
|EfficientNet-B3| 144x144| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/f96abfb6de13430aa663/?dl=1) / [Google Drive](https://drive.google.com/file/d/1OVTGI6d2nsN5Hz5T_qLnYeUIBL5oMeeU/view?usp=sharing)|
- What is contained in the checkpoints:
```
**.pth.tar
├── model_name: name of the backbone CNNs (e.g., resnet50, densenet121)
├── patch_size: size of image patches (i.e., H' or W' in the paper)
├── model_prime_state_dict, model_state_dict, fc, policy: state dictionaries of the four components of GFNets
├── model_flops, policy_flops, fc_flops: Multiply-Adds of a single forward pass of the encoder, the patch proposal network, and the classifier, respectively
├── flops: a list containing the Multiply-Adds corresponding to each length of the input sequence during inference
├── anytime_classification: results of anytime prediction (in Top-1 accuracy)
├── dynamic_threshold: the confidence thresholds used in budgeted batch classification
├── budgeted_batch_classification: results of budgeted batch classification (a two-item list: [0] holds the Multiply-Adds and [1] the Top-1 accuracy of the points on the accuracy-FLOPs curve)
```
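For example, the evaluation results stored in a checkpoint can be read directly (this is what `--eval_mode 0` in `inference.py` does):
```
import torch

checkpoint = torch.load('PATH_TO_CHECKPOINTS', map_location='cpu')
print(checkpoint['model_name'], checkpoint['patch_size'])
print(checkpoint['flops'])                   # Multiply-Adds per sequence length
print(checkpoint['anytime_classification'])  # Top-1 accuracy per sequence length
flops_list, acc_list = checkpoint['budgeted_batch_classification']
```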
## Requirements
- python 3.7.7
- pytorch 1.3.1
- torchvision 0.4.2
- pyyaml 5.3.1 (for RegNets)
## Evaluate Pre-trained Models
Read the evaluation results saved in pre-trained models
```
CUDA_VISIBLE_DEVICES=0 python inference.py --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 0
```
Read the confidence thresholds saved in pre-trained models and infer the model on the validation set
```
CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 1
```
Determine confidence thresholds on the training set and infer the model on the validation set
```
CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 2
```
The dataset is expected to be prepared as follows:
```
ImageNet
├── train
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...
├── val
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...
```
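With this layout, `inference.py` builds the validation loader with a resize-and-center-crop pipeline whose parameters (`image_size`, `crop_pct`, interpolation) come from `configs.py`; a condensed version:
```
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from configs import model_configurations

cfg = model_configurations['resnet50']
val_set = datasets.ImageFolder('ImageNet/val', transforms.Compose([
    transforms.Resize(int(cfg['image_size'] / cfg['crop_pct']),
                      interpolation=cfg['dataset_interpolation']),
    transforms.CenterCrop(cfg['image_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]))
```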
## Training
- Here we take training ResNet-50 (96x96, T=5) as an example. All the initialization models and stage-1/2 checkpoints used can be found in [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/ac7c47b3f9b04e098862/) / [Google Drive](https://drive.google.com/drive/folders/1yO2GviOnukSUgcTkptNLBBttSJQZk9yn?usp=sharing). Currently, these links include ResNet and MobileNet-V3; we will update them as soon as possible. If you need further help, feel free to contact us.
- The results in the paper are based on 2 Tesla V100 GPUs. For most experiments, up to 4 Titan Xp GPUs should be sufficient.
For training stage 1, the initialization models of the global encoder (model_prime) and the local encoder (model) are required:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 1 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --model_prime_path PATH_TO_CHECKPOINTS --model_path PATH_TO_CHECKPOINTS
```
For training stage 2, a stage-1 checkpoint is required:
```
CUDA_VISIBLE_DEVICES=0 python train.py --data_url PATH_TO_DATASET --train_stage 2 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS
```
For training stage 3, a stage-2 checkpoint is required:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 3 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS
```
## Contact
If you have any questions, please feel free to contact the authors. Yulin Wang: wang-yl19@mails.tsinghua.edu.cn.
## Acknowledgment
Our code of MobileNet-V3 and EfficientNet is from [here](https://github.com/rwightman/pytorch-image-models). Our code of RegNet is from [here](https://github.com/facebookresearch/pycls).
## To Do
- Update the visualization code.
- Update the code for mixed-precision training.
================================================
FILE: configs.py
================================================
from PIL import Image
model_configurations = {
'resnet50': {
'feature_num': 2048,
'feature_map_channels': 2048,
'policy_conv': False,
'policy_hidden_dim': 1024,
'fc_rnn': True,
'fc_hidden_dim': 1024,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bicubic'
},
'densenet121': {
'feature_num': 1024,
'feature_map_channels': 1024,
'policy_conv': False,
'policy_hidden_dim': 1024,
'fc_rnn': True,
'fc_hidden_dim': 1024,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bilinear'
},
'densenet169': {
'feature_num': 1664,
'feature_map_channels': 1664,
'policy_conv': False,
'policy_hidden_dim': 1024,
'fc_rnn': True,
'fc_hidden_dim': 1024,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bilinear'
},
'densenet201': {
'feature_num': 1920,
'feature_map_channels': 1920,
'policy_conv': False,
'policy_hidden_dim': 1024,
'fc_rnn': True,
'fc_hidden_dim': 1024,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bilinear'
},
'mobilenetv3_large_100': {
'feature_num': 1280,
'feature_map_channels': 960,
'policy_conv': True,
'policy_hidden_dim': 256,
'fc_rnn': False,
'fc_hidden_dim': None,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bicubic'
},
'mobilenetv3_large_125': {
'feature_num': 1280,
'feature_map_channels': 1200,
'policy_conv': True,
'policy_hidden_dim': 256,
'fc_rnn': False,
'fc_hidden_dim': None,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bicubic'
},
'efficientnet_b2': {
'feature_num': 1408,
'feature_map_channels': 1408,
'policy_conv': True,
'policy_hidden_dim': 256,
'fc_rnn': False,
'fc_hidden_dim': None,
'image_size': 260,
'crop_pct': 0.875,
'dataset_interpolation': Image.BICUBIC,
'prime_interpolation': 'bicubic'
},
'efficientnet_b3': {
'feature_num': 1536,
'feature_map_channels': 1536,
'policy_conv': True,
'policy_hidden_dim': 256,
'fc_rnn': False,
'fc_hidden_dim': None,
'image_size': 300,
'crop_pct': 0.904,
'dataset_interpolation': Image.BICUBIC,
'prime_interpolation': 'bicubic'
},
'regnety_600m': {
'feature_num': 608,
'feature_map_channels': 608,
'policy_conv': True,
'policy_hidden_dim': 256,
'fc_rnn': True,
'fc_hidden_dim': 1024,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bilinear',
'cfg_file': 'pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml'
},
'regnety_800m': {
'feature_num': 768,
'feature_map_channels': 768,
'policy_conv': True,
'policy_hidden_dim': 256,
'fc_rnn': True,
'fc_hidden_dim': 1024,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bilinear',
'cfg_file': 'pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml'
},
'regnety_1.6g': {
'feature_num': 888,
'feature_map_channels': 888,
'policy_conv': True,
'policy_hidden_dim': 256,
'fc_rnn': True,
'fc_hidden_dim': 1024,
'image_size': 224,
'crop_pct': 0.875,
'dataset_interpolation': Image.BILINEAR,
'prime_interpolation': 'bilinear',
'cfg_file': 'pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml'
}
}
train_configurations = {
'resnet': {
'backbone_lr': 0.01,
'fc_stage_1_lr': 0.1,
'fc_stage_3_lr': 0.01,
'weight_decay': 1e-4,
'momentum': 0.9,
'Nesterov': True,
'batch_size': 256,
'dsn_ratio': 1,
'epoch_num': 60,
'train_model_prime': True
},
'densenet': {
'backbone_lr': 0.01,
'fc_stage_1_lr': 0.1,
'fc_stage_3_lr': 0.01,
'weight_decay': 1e-4,
'momentum': 0.9,
'Nesterov': True,
'batch_size': 256,
'dsn_ratio': 1,
'epoch_num': 60,
'train_model_prime': True
},
'efficientnet': {
'backbone_lr': 0.005,
'fc_stage_1_lr': 0.1,
'fc_stage_3_lr': 0.01,
'weight_decay': 1e-4,
'momentum': 0.9,
'Nesterov': True,
'batch_size': 256,
'dsn_ratio': 5,
'epoch_num': 30,
'train_model_prime': False
},
'mobilenetv3': {
'backbone_lr': 0.005,
'fc_stage_1_lr': 0.1,
'fc_stage_3_lr': 0.01,
'weight_decay': 1e-4,
'momentum': 0.9,
'Nesterov': True,
'batch_size': 256,
'dsn_ratio': 5,
'epoch_num': 90,
'train_model_prime': False
},
'regnet': {
'backbone_lr': 0.02,
'fc_stage_1_lr': 0.1,
'fc_stage_3_lr': 0.01,
'weight_decay': 5e-5,
'momentum': 0.9,
'Nesterov': True,
'batch_size': 256,
'dsn_ratio': 1,
'epoch_num': 60,
'train_model_prime': True
}
}
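# Usage sketch: model_configurations is keyed by backbone name and is consumed
# by inference.py/train.py to build the policy and classifier heads, e.g.:
#   cfg = model_configurations['resnet50']
#   state_dim = cfg['feature_map_channels'] * math.ceil(patch_size / 32) ** 2
# train_configurations is keyed by backbone family (e.g. 'resnet') and supplies
# the optimizer hyper-parameters for the three training stages.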
================================================
FILE: inference.py
================================================
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from utils import *
from network import *
from configs import *
import math
import argparse
import models.resnet as resnet
import models.densenet as densenet
from models import create_model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser(description='Inference code for GFNet')
parser.add_argument('--data_url', default='./data', type=str,
help='path to the dataset (ImageNet)')
parser.add_argument('--checkpoint_path', default='', type=str,
help='path to the pre-train model (default: none)')
parser.add_argument('--eval_mode', default=2, type=int,
help='mode 0 : read the evaluation results saved in pre-trained models\
mode 1 : read the confidence thresholds saved in pre-trained models and infer the model on the validation set\
mode 2 : determine confidence thresholds on the training set and infer the model on the validation set')
args = parser.parse_args()
def main():
# load pretrained model
checkpoint = torch.load(args.checkpoint_path)
try:
model_arch = checkpoint['model_name']
patch_size = checkpoint['patch_size']
prime_size = checkpoint['patch_size']
flops = checkpoint['flops']
model_flops = checkpoint['model_flops']
policy_flops = checkpoint['policy_flops']
fc_flops = checkpoint['fc_flops']
anytime_classification = checkpoint['anytime_classification']
budgeted_batch_classification = checkpoint['budgeted_batch_classification']
dynamic_threshold = checkpoint['dynamic_threshold']
maximum_length = len(checkpoint['flops'])
    except KeyError:
        print('Error: \n'
              'Please provide essential information '
              'for customized models (as we have done '
              'in pre-trained models)!\n'
              'At least the following information should be given: \n'
              '--model_name: name of the backbone CNNs (e.g., resnet50, densenet121)\n'
              '--patch_size: size of image patches (i.e., H\' or W\' in the paper)\n'
              '--flops: a list containing the Multiply-Adds corresponding to each '
              'length of the input sequence during inference')
        return
model_configuration = model_configurations[model_arch]
if args.eval_mode > 0:
# create model
if 'resnet' in model_arch:
model = resnet.resnet50(pretrained=False)
model_prime = resnet.resnet50(pretrained=False)
elif 'densenet' in model_arch:
model = eval('densenet.' + model_arch)(pretrained=False)
model_prime = eval('densenet.' + model_arch)(pretrained=False)
elif 'efficientnet' in model_arch:
model = create_model(model_arch, pretrained=False, num_classes=1000,
drop_rate=0.3, drop_connect_rate=0.2)
model_prime = create_model(model_arch, pretrained=False, num_classes=1000,
drop_rate=0.3, drop_connect_rate=0.2)
elif 'mobilenetv3' in model_arch:
model = create_model(model_arch, pretrained=False, num_classes=1000,
drop_rate=0.2, drop_connect_rate=0.2)
model_prime = create_model(model_arch, pretrained=False, num_classes=1000,
drop_rate=0.2, drop_connect_rate=0.2)
elif 'regnet' in model_arch:
import pycls.core.model_builder as model_builder
from pycls.core.config import cfg
cfg.merge_from_file(model_configuration['cfg_file'])
cfg.freeze()
model = model_builder.build_model()
model_prime = model_builder.build_model()
traindir = args.data_url + 'train/'
valdir = args.data_url + 'val/'
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_set = datasets.ImageFolder(traindir, transforms.Compose([
transforms.RandomResizedCrop(model_configuration['image_size'], interpolation=model_configuration['dataset_interpolation']),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize ]))
train_set_index = torch.randperm(len(train_set))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, num_workers=32, pin_memory=False,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_set_index[-200000:]))
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(int(model_configuration['image_size']/model_configuration['crop_pct']), interpolation = model_configuration['dataset_interpolation']),
transforms.CenterCrop(model_configuration['image_size']),
transforms.ToTensor(),
normalize])),
batch_size=256, shuffle=False, num_workers=16, pin_memory=False)
state_dim = model_configuration['feature_map_channels'] * math.ceil(patch_size/32) * math.ceil(patch_size/32)
memory = Memory()
policy = ActorCritic(model_configuration['feature_map_channels'], state_dim, model_configuration['policy_hidden_dim'], model_configuration['policy_conv'])
fc = Full_layer(model_configuration['feature_num'], model_configuration['fc_hidden_dim'], model_configuration['fc_rnn'])
model = nn.DataParallel(model.cuda())
model_prime = nn.DataParallel(model_prime.cuda())
policy = policy.cuda()
fc = fc.cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model_prime.load_state_dict(checkpoint['model_prime_state_dict'])
fc.load_state_dict(checkpoint['fc'])
policy.load_state_dict(checkpoint['policy'])
budgeted_batch_flops_list = []
budgeted_batch_acc_list = []
print('generate logits on test samples...')
test_logits, test_targets, anytime_classification = generate_logits(model_prime, model, fc, memory, policy, val_loader, maximum_length, prime_size, patch_size, model_arch)
if args.eval_mode == 2:
print('generate logits on training samples...')
dynamic_threshold = torch.zeros([39, maximum_length])
train_logits, train_targets, _ = generate_logits(model_prime, model, fc, memory, policy, train_loader, maximum_length, prime_size, patch_size, model_arch)
for p in range(1, 40):
print('inference: {}/40'.format(p))
        _p = torch.FloatTensor(1).fill_(p * 1.0 / 20)
        # geometric distribution over exit stages: probs[t] is proportional to _p ** (t + 1)
        probs = torch.exp(torch.log(_p) * torch.arange(1, maximum_length + 1, dtype=torch.float))
probs /= probs.sum()
if args.eval_mode == 2:
dynamic_threshold[p-1] = dynamic_find_threshold(train_logits, train_targets, probs)
acc_step, flops_step = dynamic_evaluate(test_logits, test_targets, flops, dynamic_threshold[p-1])
budgeted_batch_acc_list.append(acc_step)
budgeted_batch_flops_list.append(flops_step)
budgeted_batch_classification = [budgeted_batch_flops_list, budgeted_batch_acc_list]
print('model_arch :', model_arch)
print('patch_size :', patch_size)
print('flops :', flops)
print('model_flops :', model_flops)
print('policy_flops :', policy_flops)
print('fc_flops :', fc_flops)
print('anytime_classification :', anytime_classification)
print('budgeted_batch_classification :', budgeted_batch_classification)
def generate_logits(model_prime, model, fc, memory, policy, dataloader, maximum_length, prime_size, patch_size, model_arch):
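    # One glance step (model_prime on the downscaled input) followed by
    # maximum_length - 1 focus steps (model on policy-proposed patches),
    # recording softmax outputs and cumulative Top-1 accuracy after each step.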
logits_list = []
targets_list = []
top1 = [AverageMeter() for _ in range(maximum_length)]
model.eval()
model_prime.eval()
fc.eval()
for i, (x, target) in enumerate(dataloader):
logits_temp = torch.zeros(maximum_length, x.size(0), 1000)
target_var = target.cuda()
input_var = x.cuda()
input_prime = get_prime(input_var, prime_size, model_configurations[model_arch]['prime_interpolation'])
with torch.no_grad():
output, state = model_prime(input_prime)
if 'resnet' in model_arch or 'densenet' in model_arch:
output = fc(output, restart=True)
elif 'regnet' in model_arch:
_ = fc(output, restart=True)
output = model_prime.module.fc(output)
else:
_ = fc(output, restart=True)
output = model_prime.module.classifier(output)
logits_temp[0] = F.softmax(output, 1)
acc = accuracy(output, target_var, topk=(1,))
top1[0].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0))
for patch_step in range(1, maximum_length):
with torch.no_grad():
if patch_step == 1:
action = policy.act(state, memory, restart_batch=True)
else:
action = policy.act(state, memory)
patches = get_patch(input_var, action, patch_size)
output, state = model(patches)
output = fc(output, restart=False)
logits_temp[patch_step] = F.softmax(output, 1)
acc = accuracy(output, target_var, topk=(1,))
top1[patch_step].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0))
logits_list.append(logits_temp)
targets_list.append(target_var)
memory.clear_memory()
anytime_classification = []
for index in range(maximum_length):
anytime_classification.append(top1[index].ave)
return torch.cat(logits_list, 1), torch.cat(targets_list, 0), anytime_classification
def dynamic_find_threshold(logits, targets, p):
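    # Calibrate per-stage confidence thresholds so that, on this set, the
    # fraction of samples exiting at stage k matches the target distribution
    # p[k]: walk the samples sorted by max softmax confidence, skip those that
    # already exited at an earlier stage, and take the confidence of the
    # floor(n_sample * p[k])-th remaining sample as T[k]. The final stage
    # accepts every remaining sample (T = -1e8).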
n_stage, n_sample, c = logits.size()
max_preds, argmax_preds = logits.max(dim=2, keepdim=False)
_, sorted_idx = max_preds.sort(dim=1, descending=True)
filtered = torch.zeros(n_sample)
T = torch.Tensor(n_stage).fill_(1e8)
for k in range(n_stage - 1):
acc, count = 0.0, 0
out_n = math.floor(n_sample * p[k])
for i in range(n_sample):
ori_idx = sorted_idx[k][i]
if filtered[ori_idx] == 0:
count += 1
if count == out_n:
T[k] = max_preds[k][ori_idx]
break
filtered.add_(max_preds[k].ge(T[k]).type_as(filtered))
T[n_stage - 1] = -1e8
return T
def dynamic_evaluate(logits, targets, flops, T):
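    # Simulate budgeted batch classification: each sample exits at the first
    # stage whose max confidence reaches T[k]; report overall Top-1 accuracy
    # and the expected Multiply-Adds (stage cost weighted by exit fractions).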
n_stage, n_sample, c = logits.size()
max_preds, argmax_preds = logits.max(dim=2, keepdim=False)
_, sorted_idx = max_preds.sort(dim=1, descending=True)
acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage)
acc, expected_flops = 0, 0
for i in range(n_sample):
gold_label = targets[i]
for k in range(n_stage):
if max_preds[k][i].item() >= T[k]: # force the sample to exit at k
if int(gold_label.item()) == int(argmax_preds[k][i].item()):
acc += 1
acc_rec[k] += 1
exp[k] += 1
break
acc_all = 0
for k in range(n_stage):
_t = 1.0 * exp[k] / n_sample
expected_flops += _t * flops[k]
acc_all += acc_rec[k]
return acc * 100.0 / n_sample, expected_flops.item()
if __name__ == '__main__':
main()
================================================
FILE: models/__init__.py
================================================
from .gen_efficientnet import *
from .mobilenetv3 import *
from .model_factory import create_model
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
from .activations import *
from .resnet import *
from .densenet import *
================================================
FILE: models/activations/__init__.py
================================================
from .config import *
from .activations_autofn import *
from .activations_jit import *
from .activations import *
_ACT_FN_DEFAULT = dict(
swish=swish,
mish=mish,
relu=F.relu,
relu6=F.relu6,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=hard_sigmoid,
hard_swish=hard_swish,
)
_ACT_FN_AUTO = dict(
swish=swish_auto,
mish=mish_auto,
)
_ACT_FN_JIT = dict(
swish=swish_jit,
mish=mish_jit,
#hard_swish=hard_swish_jit,
#hard_sigmoid_jit=hard_sigmoid_jit,
)
_ACT_LAYER_DEFAULT = dict(
swish=Swish,
mish=Mish,
relu=nn.ReLU,
relu6=nn.ReLU6,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=HardSigmoid,
hard_swish=HardSwish,
)
_ACT_LAYER_AUTO = dict(
swish=SwishAuto,
mish=MishAuto,
)
_ACT_LAYER_JIT = dict(
swish=SwishJit,
mish=MishJit,
#hard_swish=HardSwishJit,
#hard_sigmoid=HardSigmoidJit
)
_OVERRIDE_FN = dict()
_OVERRIDE_LAYER = dict()
def add_override_act_fn(name, fn):
global _OVERRIDE_FN
_OVERRIDE_FN[name] = fn
def update_override_act_fn(overrides):
assert isinstance(overrides, dict)
global _OVERRIDE_FN
_OVERRIDE_FN.update(overrides)
def clear_override_act_fn():
global _OVERRIDE_FN
_OVERRIDE_FN = dict()
def add_override_act_layer(name, fn):
_OVERRIDE_LAYER[name] = fn
def update_override_act_layer(overrides):
assert isinstance(overrides, dict)
global _OVERRIDE_LAYER
_OVERRIDE_LAYER.update(overrides)
def clear_override_act_layer():
global _OVERRIDE_LAYER
_OVERRIDE_LAYER = dict()
def get_act_fn(name='relu'):
""" Activation Function Factory
Fetching activation fns by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if name in _OVERRIDE_FN:
return _OVERRIDE_FN[name]
if not config.is_exportable() and not config.is_scriptable():
# If not exporting or scripting the model, first look for a JIT optimized version
# of our activation, then a custom autograd.Function variant before defaulting to
# a Python or Torch builtin impl
if name in _ACT_FN_JIT:
return _ACT_FN_JIT[name]
if name in _ACT_FN_AUTO:
return _ACT_FN_AUTO[name]
return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
""" Activation Layer Factory
Fetching activation layers by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if name in _OVERRIDE_LAYER:
return _OVERRIDE_LAYER[name]
if not config.is_exportable() and not config.is_scriptable():
if name in _ACT_LAYER_JIT:
return _ACT_LAYER_JIT[name]
if name in _ACT_LAYER_AUTO:
return _ACT_LAYER_AUTO[name]
return _ACT_LAYER_DEFAULT[name]
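# Usage sketch: get_act_layer('swish') returns SwishJit by default, but falls
# back to the plain Swish module after set_exportable(True) or
# set_scriptable(True), so exported or scripted graphs avoid the custom
# autograd Functions.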
================================================
FILE: models/activations/activations.py
================================================
from torch import nn as nn
from torch.nn import functional as F
def swish(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
class Swish(nn.Module):
def __init__(self, inplace: bool = False):
super(Swish, self).__init__()
self.inplace = inplace
def forward(self, x):
return swish(x, self.inplace)
def mish(x, inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    NOTE: the inplace flag is accepted for interface consistency but currently ignored.
    """
    return x.mul(F.softplus(x).tanh())
class Mish(nn.Module):
def __init__(self, inplace: bool = False):
super(Mish, self).__init__()
self.inplace = inplace
def forward(self, x):
return mish(x, self.inplace)
def sigmoid(x, inplace: bool = False):
return x.sigmoid_() if inplace else x.sigmoid()
# PyTorch has this, but not with a consistent inplace argument interface
class Sigmoid(nn.Module):
def __init__(self, inplace: bool = False):
super(Sigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return x.sigmoid_() if self.inplace else x.sigmoid()
def tanh(x, inplace: bool = False):
return x.tanh_() if inplace else x.tanh()
# PyTorch has this, but not with a consistent inplace argument interface
class Tanh(nn.Module):
def __init__(self, inplace: bool = False):
super(Tanh, self).__init__()
self.inplace = inplace
def forward(self, x):
return x.tanh_() if self.inplace else x.tanh()
def hard_swish(x, inplace: bool = False):
inner = F.relu6(x + 3.).div_(6.)
return x.mul_(inner) if inplace else x.mul(inner)
class HardSwish(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwish, self).__init__()
self.inplace = inplace
def forward(self, x):
return hard_swish(x, self.inplace)
def hard_sigmoid(x, inplace: bool = False):
if inplace:
return x.add_(3.).clamp_(0., 6.).div_(6.)
else:
return F.relu6(x + 3.) / 6.
class HardSigmoid(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return hard_sigmoid(x, self.inplace)
================================================
FILE: models/activations/activations_autofn.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F
__all__ = ['swish_auto', 'SwishAuto', 'mish_auto', 'MishAuto']
class SwishAutoFn(torch.autograd.Function):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
Memory efficient variant from:
https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
"""
@staticmethod
def forward(ctx, x):
result = x.mul(torch.sigmoid(x))
ctx.save_for_backward(x)
return result
@staticmethod
def backward(ctx, grad_output):
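        # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        #                       = sigmoid(x) * (1 + x * (1 - sigmoid(x)))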
x = ctx.saved_tensors[0]
x_sigmoid = torch.sigmoid(x)
return grad_output.mul(x_sigmoid * (1 + x * (1 - x_sigmoid)))
def swish_auto(x, inplace=False):
# inplace ignored
return SwishAutoFn.apply(x)
class SwishAuto(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishAuto, self).__init__()
self.inplace = inplace
def forward(self, x):
return SwishAutoFn.apply(x)
class MishAutoFn(torch.autograd.Function):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
Experimental memory-efficient variant
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
y = x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x)))
return y
@staticmethod
def backward(ctx, grad_output):
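        # d/dx [x * tanh(softplus(x))] = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2),
        # where sp = softplus(x)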
x = ctx.saved_tensors[0]
x_sigmoid = torch.sigmoid(x)
x_tanh_sp = F.softplus(x).tanh()
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
def mish_auto(x, inplace=False):
# inplace ignored
return MishAutoFn.apply(x)
class MishAuto(nn.Module):
def __init__(self, inplace: bool = False):
super(MishAuto, self).__init__()
self.inplace = inplace
def forward(self, x):
return MishAutoFn.apply(x)
================================================
FILE: models/activations/activations_jit.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F
__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit']
#'hard_swish_jit', 'HardSwishJit', 'hard_sigmoid_jit', 'HardSigmoidJit']
@torch.jit.script
def swish_jit_fwd(x):
return x.mul(torch.sigmoid(x))
@torch.jit.script
def swish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class SwishJitAutoFn(torch.autograd.Function):
""" torch.jit.script optimised Swish
Inspired by conversation btw Jeremy Howard & Adam Pazske
https://twitter.com/jeremyphoward/status/1188251041835315200
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return swish_jit_bwd(x, grad_output)
def swish_jit(x, inplace=False):
# inplace ignored
return SwishJitAutoFn.apply(x)
class SwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishJit, self).__init__()
self.inplace = inplace
def forward(self, x):
return SwishJitAutoFn.apply(x)
@torch.jit.script
def mish_jit_fwd(x):
return x.mul(torch.tanh(F.softplus(x)))
@torch.jit.script
def mish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
x_tanh_sp = F.softplus(x).tanh()
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
class MishJitAutoFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return mish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return mish_jit_bwd(x, grad_output)
def mish_jit(x, inplace=False):
# inplace ignored
return MishJitAutoFn.apply(x)
class MishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(MishJit, self).__init__()
self.inplace = inplace
def forward(self, x):
return MishJitAutoFn.apply(x)
# @torch.jit.script
# def hard_swish_jit(x, inplace: bool = False):
# return x.mul(F.relu6(x + 3.).mul_(1./6.))
#
#
# class HardSwishJit(nn.Module):
# def __init__(self, inplace: bool = False):
# super(HardSwishJit, self).__init__()
#
# def forward(self, x):
# return hard_swish_jit(x)
#
#
# @torch.jit.script
# def hard_sigmoid_jit(x, inplace: bool = False):
# return F.relu6(x + 3.).mul(1./6.)
#
#
# class HardSigmoidJit(nn.Module):
# def __init__(self, inplace: bool = False):
# super(HardSigmoidJit, self).__init__()
#
# def forward(self, x):
# return hard_sigmoid_jit(x)
================================================
FILE: models/activations/config.py
================================================
""" Global Config and Constants
"""
__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable']
# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False
# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False
def is_exportable():
return _EXPORTABLE
def set_exportable(value):
global _EXPORTABLE
_EXPORTABLE = value
def is_scriptable():
return _SCRIPTABLE
def set_scriptable(value):
global _SCRIPTABLE
_SCRIPTABLE = value
================================================
FILE: models/config.py
================================================
""" Global Config and Constants
"""
__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable']
# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False
# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False
def is_exportable():
return _EXPORTABLE
def set_exportable(value):
global _EXPORTABLE
_EXPORTABLE = value
def is_scriptable():
return _SCRIPTABLE
def set_scriptable(value):
global _SCRIPTABLE
_SCRIPTABLE = value
================================================
FILE: models/conv2d_layers.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._six import container_abcs
from itertools import repeat
from functools import partial
from typing import Union, List, Tuple, Optional, Callable
import numpy as np
import math
from .activations.config import *
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
def _get_padding(kernel_size, stride=1, dilation=1, **_):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def _calc_same_pad(i: int, k: int, s: int, d: int):
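    # Total 'SAME' padding for an axis of length i with kernel k, stride s and
    # dilation d, so that the output has ceil(i / s) positions (TF semantics).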
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
def _same_pad_arg(input_size, kernel_size, stride, dilation):
ih, iw = input_size
kh, kw = kernel_size
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
def _split_channels(num_chan, num_groups):
split = [num_chan // num_groups for _ in range(num_groups)]
split[0] += num_chan - sum(split)
return split
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
ih, iw = x.size()[-2:]
kh, kw = weight.size()[-2:]
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
# pylint: disable=unused-argument
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class Conv2dSameExport(nn.Conv2d):
""" ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
NOTE: This does not currently work with torch.jit.script
"""
# pylint: disable=unused-argument
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSameExport, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
self.pad = None
self.pad_input_size = (0, 0)
def forward(self, x):
input_size = x.size()[-2:]
if self.pad is None:
pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
self.pad = nn.ZeroPad2d(pad_arg)
self.pad_input_size = input_size
else:
assert self.pad_input_size == input_size
x = self.pad(x)
return F.conv2d(
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def get_padding_value(padding, kernel_size, **kwargs):
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if _is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = _get_padding(kernel_size, **kwargs)
else:
# dynamic padding
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = _get_padding(kernel_size, **kwargs)
return padding, dynamic
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
if is_exportable():
assert not is_scriptable()
return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
else:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
class MixedConv2d(nn.ModuleDict):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
super(MixedConv2d, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
num_groups = len(kernel_size)
in_splits = _split_channels(in_channels, num_groups)
out_splits = _split_channels(out_channels, num_groups)
self.in_channels = sum(in_splits)
self.out_channels = sum(out_splits)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
conv_groups = out_ch if depthwise else 1
self.add_module(
str(idx),
create_conv2d_pad(
in_ch, out_ch, k, stride=stride,
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
)
self.splits = in_splits
def forward(self, x):
x_split = torch.split(x, self.splits, 1)
x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())]
x = torch.cat(x_out, 1)
return x
def get_condconv_initializer(initializer, num_experts, expert_shape):
def condconv_initializer(weight):
"""CondConv initializer function."""
num_params = np.prod(expert_shape)
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
weight.shape[1] != num_params):
raise (ValueError(
'CondConv variables must have shape [num_experts, num_params]'))
for i in range(num_experts):
initializer(weight[i].view(expert_shape))
return condconv_initializer
class CondConv2d(nn.Module):
""" Conditional Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
super(CondConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
padding_val, is_padding_dynamic = get_padding_value(
padding, kernel_size, stride=stride, dilation=dilation)
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
self.padding = _pair(padding_val)
self.dilation = _pair(dilation)
self.groups = groups
self.num_experts = num_experts
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight_num_param = 1
for wd in self.weight_shape:
weight_num_param *= wd
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
if bias:
self.bias_shape = (self.out_channels,)
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init_weight = get_condconv_initializer(
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
init_weight(self.weight)
if self.bias is not None:
fan_in = np.prod(self.weight_shape[1:])
bound = 1 / math.sqrt(fan_in)
init_bias = get_condconv_initializer(
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
init_bias(self.bias)
def forward(self, x, routing_weights):
B, C, H, W = x.shape
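        # Mix the expert kernels per sample: routing_weights (B, num_experts)
        # @ self.weight (num_experts, num_params) -> one flattened kernel per sample.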
weight = torch.matmul(routing_weights, self.weight)
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight = weight.view(new_weight_shape)
bias = None
if self.bias is not None:
bias = torch.matmul(routing_weights, self.bias)
bias = bias.view(B * self.out_channels)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W)
if self.dynamic_padding:
out = conv2d_same(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
else:
out = F.conv2d(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
# Literal port (from TF definition)
# x = torch.split(x, 1, 0)
# weight = torch.split(weight, 1, 0)
# if self.bias is not None:
# bias = torch.matmul(routing_weights, self.bias)
# bias = torch.split(bias, 1, 0)
# else:
# bias = [None] * B
# out = []
# for xi, wi, bi in zip(x, weight, bias):
# wi = wi.view(*self.weight_shape)
# if bi is not None:
# bi = bi.view(*self.bias_shape)
# out.append(self.conv_fn(
# xi, wi, bi, stride=self.stride, padding=self.padding,
# dilation=self.dilation, groups=self.groups))
# out = torch.cat(out, 0)
return out
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
assert 'groups' not in kwargs # only use 'depthwise' bool arg
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_chs if depthwise else 1
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
else:
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
return m
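# --- Usage sketch (added for illustration; not part of the original file). ---
# Shows how select_conv2d dispatches: an int kernel_size yields a plain padded
# conv, while passing num_experts yields a CondConv2d whose expert kernels are
# mixed per sample by routing weights. All shapes below are arbitrary assumptions.
if __name__ == '__main__':
    x = torch.randn(2, 16, 32, 32)
    conv = select_conv2d(16, 32, 3, stride=1)       # regular conv path
    print(conv(x).shape)                            # torch.Size([2, 32, 32, 32])
    cond = select_conv2d(16, 32, 3, num_experts=4)  # CondConv2d path
    routing = torch.sigmoid(torch.randn(2, 4))      # normally produced by a routing_fn
    print(cond(x, routing).shape)                   # torch.Size([2, 32, 32, 32])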
================================================
FILE: models/densenet.py
================================================
import re
import torch
import torch.utils.model_zoo as model_zoo
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
def densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" `_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
**kwargs)
if pretrained:
# model.load_state_dict(torch.load(model_urls['densenet121']))
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet121'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model
def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" `_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
**kwargs)
if pretrained:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet169'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model
def densenet201(pretrained=False, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" `_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
**kwargs)
if pretrained:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet201'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model
def densenet161(pretrained=False, **kwargs):
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" `_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
**kwargs)
if pretrained:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet161'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model
class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
self.add_module('norm1', nn.BatchNorm2d(num_input_features))
self.add_module('relu1', nn.ReLU(inplace=True))
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1, bias=False))
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
self.add_module('relu2', nn.ReLU(inplace=True))
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1, bias=False))
self.drop_rate = drop_rate
def forward(self, x):
new_features = super(_DenseLayer, self).forward(x)
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
return torch.cat([x, new_features], 1)
class _DenseBlock(nn.Sequential):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
self.add_module('denselayer%d' % (i + 1), layer)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet(nn.Module):
r"""Densenet-BC model class, based on
`"Densely Connected Convolutional Networks" `_
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
block_config (list of 4 ints) - how many layers in each pooling block
num_init_features (int) - the number of filters to learn in the first convolution layer
bn_size (int) - multiplicative factor for the bottleneck width
(i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
super(DenseNet, self).__init__()
# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
self.feature_num = num_features
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# Linear layer
self.classifier = nn.Linear(num_features, num_classes)
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
features = self.features(x)
x = F.relu(features, inplace=True)
out = self.avgpool(x).view(features.size(0), -1)
return out, x.detach()
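# --- Usage sketch (added for illustration; not part of the original file). ---
# In this repo DenseNet.forward returns (pooled_features, detached_feature_map)
# rather than logits, so the classifier is applied manually by the caller.
if __name__ == '__main__':
    net = densenet121(pretrained=False)
    pooled, fmap = net(torch.randn(1, 3, 224, 224))
    logits = net.classifier(pooled)
    print(pooled.shape, fmap.shape, logits.shape)
    # torch.Size([1, 1024]) torch.Size([1, 1024, 7, 7]) torch.Size([1, 1000])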
================================================
FILE: models/efficientnet_builder.py
================================================
import re
from copy import deepcopy
from .conv2d_layers import *
from .activations import *
# Defaults used for Google/Tensorflow training of mobile networks w/ RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies between .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (w/ .999 in search space) for paper
#
# PyTorch defaults are momentum = .1, eps = 1e-5
#
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
def get_bn_args_tf():
return _BN_ARGS_TF.copy()
def resolve_bn_args(kwargs):
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
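# Example (illustrative): resolve_bn_args({'bn_tf': True, 'bn_eps': 1e-5}) pops
# both keys and returns {'momentum': 0.01, 'eps': 1e-05}: the TF defaults are
# applied first, then any explicit bn_momentum/bn_eps override them.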
_SE_ARGS_DEFAULT = dict(
gate_fn=sigmoid,
act_layer=None, # None == use containing block's activation layer
reduce_mid=False,
divisor=1)
def resolve_se_args(kwargs, in_chs, act_layer=None):
se_kwargs = kwargs.copy() if kwargs is not None else {}
# fill in args that aren't specified with the defaults
for k, v in _SE_ARGS_DEFAULT.items():
se_kwargs.setdefault(k, v)
# some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
if not se_kwargs.pop('reduce_mid'):
se_kwargs['reduced_base_chs'] = in_chs
# act_layer override, if it remains None, the containing block's act_layer will be used
if se_kwargs['act_layer'] is None:
assert act_layer is not None
se_kwargs['act_layer'] = act_layer
return se_kwargs
def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer
def make_divisible(v: int, divisor: int = 8, min_value: int = None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v: # ensure round down does not go down by more than 10%.
new_v += divisor
return new_v
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
channels *= multiplier
return make_divisible(channels, divisor, channel_min)
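# Worked example (illustrative): round_channels(112, multiplier=1.1, divisor=8)
# computes make_divisible(123.2, 8) = max(8, int(123.2 + 4) // 8 * 8) = 120;
# since 120 >= 0.9 * 123.2, no correction is applied and 120 is returned.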
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
"""Apply drop connect."""
if not training:
return inputs
keep_prob = 1 - drop_connect_rate
random_tensor = keep_prob + torch.rand(
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
random_tensor.floor_() # binarize
output = inputs.div(keep_prob) * random_tensor
return output
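# Note (illustrative): drop_connect implements per-sample stochastic depth.
# With drop_connect_rate=0.2, keep_prob=0.8: each sample's residual branch is
# zeroed with probability 0.2 and scaled by 1/0.8 otherwise, so the expected
# output equals the input; at inference (training=False) it is the identity.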
class SqueezeExcite(nn.Module):
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
super(SqueezeExcite, self).__init__()
self.gate_fn = gate_fn
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
def forward(self, x):
# tensor.view + mean bad for ONNX export (produces mess of gather ops that break TensorRT)
x_se = self.avg_pool(x)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
x = x * self.gate_fn(x_se)
return x
class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(ConvBnAct, self).__init__()
assert stride in [1, 2]
norm_kwargs = norm_kwargs or {}
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
self.bn1 = norm_layer(out_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn1(x)
x = self.act1(x)
return x
class DepthwiseSeparableConv(nn.Module):
""" DepthwiseSeparable block
Used for DS convs in MobileNet-V1 and in place of IR blocks with an expansion
factor of 1.0. This is an alternative to having an IR with an optional first pw conv.
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
assert stride in [1, 2]
norm_kwargs = norm_kwargs or {}
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.drop_connect_rate = drop_connect_rate
self.conv_dw = select_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if se_ratio is not None and se_ratio > 0.:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = nn.Identity()
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True) if pw_act else nn.Identity()
def forward(self, x):
residual = x
x = self.conv_dw(x)
x = self.bn1(x)
x = self.act1(x)
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
x = self.act2(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class InvertedResidual(nn.Module):
""" Inverted residual block w/ optional SE"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_connect_rate=0.):
super(InvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs: int = make_divisible(in_chs * exp_ratio)
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
# Point-wise expansion
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Depth-wise convolution
self.conv_dw = select_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True)
# Squeeze-and-excitation
if se_ratio is not None and se_ratio > 0.:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = nn.Identity() # for jit.script compat
# Point-wise linear projection
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs, **norm_kwargs)
def forward(self, x):
residual = x
# Point-wise expansion
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw(x)
x = self.bn2(x)
x = self.act2(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class CondConvResidual(InvertedResidual):
""" Inverted residual block w/ CondConv routing"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
num_experts=0, drop_connect_rate=0.):
self.num_experts = num_experts
conv_kwargs = dict(num_experts=self.num_experts)
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type,
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
drop_connect_rate=drop_connect_rate)
self.routing_fn = nn.Linear(in_chs, self.num_experts)
def forward(self, x):
residual = x
# CondConv routing
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
# Point-wise expansion
x = self.conv_pw(x, routing_weights)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw(x, routing_weights)
x = self.bn2(x)
x = self.act2(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x, routing_weights)
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class EdgeResidual(nn.Module):
""" EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride"""
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
super(EdgeResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
# Expansion convolution
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if se_ratio is not None and se_ratio > 0.:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = nn.Identity()
# Point-wise linear projection
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
def forward(self, x):
residual = x
# Expansion convolution
x = self.conv_exp(x)
x = self.bn1(x)
x = self.act1(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn2(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x
class EfficientNetBuilder:
""" Build Trunk Blocks for Efficient/Mobile Networks
This ended up being somewhat of a cross between
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
and
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
pad_type='', act_layer=None, se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.pad_type = pad_type
self.act_layer = act_layer
self.se_kwargs = se_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.drop_connect_rate = drop_connect_rate
# updated during build
self.in_chs = None
self.block_idx = 0
self.block_count = 0
def _round_channels(self, chs):
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
def _make_block(self, ba):
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['norm_layer'] = self.norm_layer
ba['norm_kwargs'] = self.norm_kwargs
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_kwargs'] = self.se_kwargs
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_kwargs'] = self.se_kwargs
block = DepthwiseSeparableConv(**ba)
elif bt == 'er':
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_kwargs'] = self.se_kwargs
block = EdgeResidual(**ba)
elif bt == 'cn':
block = ConvBnAct(**ba)
else:
assert False, 'Unknown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block
def _make_stack(self, stack_args):
blocks = []
# each stack (stage) contains a list of block arguments
for i, ba in enumerate(stack_args):
if i >= 1:
# only the first block in any stack can have a stride > 1
ba['stride'] = 1
block = self._make_block(ba)
blocks.append(block)
self.block_idx += 1 # incr global idx (across all stacks)
return nn.Sequential(*blocks)
def __call__(self, in_chs, block_args):
""" Build the blocks
Args:
in_chs: Number of input-channels passed to first block
block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
self.in_chs = in_chs
self.block_count = sum([len(x) for x in block_args])
self.block_idx = 0
blocks = []
# outer list of block_args defines the stacks ('stages' by some conventions)
for stack_idx, stack in enumerate(block_args):
assert isinstance(stack, list)
stack = self._make_stack(stack)
blocks.append(stack)
return blocks
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
else:
return [int(k) for k in ss.split('.')]
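# Example (illustrative): _parse_ksize('3') -> 3 (plain conv), while
# _parse_ksize('3.5.7') -> [3, 5, 7], which select_conv2d routes to a
# MixedConv2d with one kernel size per channel group.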
def _decode_block_str(block_str):
""" Decode block definition string
Gets a list of block arg (dicts) through a string notation of arguments.
E.g. ir_r2_k3_s2_e1_c32_se0.25_noskip
All args can exist in any order, with the exception of the leading string, which
is assumed to indicate the block type.
leading string - block type (
ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act,
er = EdgeResidual, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
A list of block args (dicts)
Raises:
ValueError: if the string def is not properly specified (TODO)
"""
assert isinstance(block_str, str)
ops = block_str.split('_')
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
for op in ops:
# string options being checked on individual basis, combine if they grow
if op == 'noskip':
noskip = True
elif op.startswith('n'):
# activation fn
key = op[0]
v = op[1:]
if v == 're':
value = get_act_layer('relu')
elif v == 'r6':
value = get_act_layer('relu6')
elif v == 'hs':
value = get_act_layer('hard_swish')
elif v == 'sw':
value = get_act_layer('swish')
else:
continue
options[key] = value
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# if act_layer is None, the model default (passed to model init) will be used
act_layer = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
)
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
elif block_type == 'ds' or block_type == 'dsa':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'er':
block_args = dict(
block_type=block_type,
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_layer=act_layer,
)
else:
assert False, 'Unknown block type (%s)' % block_type
return block_args, num_repeat
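# Worked example (illustrative): _decode_block_str('ir_r2_k3_s2_e6_c24_se0.25')
# returns block_args = dict(block_type='ir', dw_kernel_size=3, exp_kernel_size=1,
# pw_kernel_size=1, out_chs=24, exp_ratio=6.0, se_ratio=0.25, stride=2,
# act_layer=None, noskip=False) and num_repeat=2, i.e. two inverted-residual
# blocks (the stride 2 applies only to the first block of a stack, see _make_stack).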
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
""" Per-stage depth scaling
Scales the block repeats in each stage. This depth scaling impl maintains
compatibility with the EfficientNet scaling method, while allowing sensible
scaling for other models that may have multiple block arg definitions in each stage.
"""
# We scale the total repeat count for each stage, there may be multiple
# block arg defs per stage so we need to sum.
num_repeat = sum(repeats)
if depth_trunc == 'round':
# Truncating to int by rounding allows stages with few repeats to remain
# proportionally smaller for longer. This is a good choice when stage definitions
# include single repeat stages that we'd prefer to keep that way as long as possible
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
else:
# The default for EfficientNet truncates repeats to int via 'ceil'.
# Any multiplier > 1.0 will result in an increased depth for every stage.
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
# Proportionally distribute repeat count scaling to each block definition in the stage.
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
# The first block makes less sense to repeat in most of the arch definitions.
repeats_scaled = []
for r in repeats[::-1]:
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
repeats_scaled.append(rs)
num_repeat -= r
num_repeat_scaled -= rs
repeats_scaled = repeats_scaled[::-1]
# Apply the calculated scaling to each block arg in the stage
sa_scaled = []
for ba, rep in zip(stack_args, repeats_scaled):
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
return sa_scaled
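# Worked example (illustrative): with repeats=[1, 2] and depth_multiplier=1.4,
# num_repeat_scaled = ceil(3 * 1.4) = 5. Allocating in reverse: the r=2 def gets
# round(2/3 * 5) = 3 repeats, then the r=1 def gets the remaining 2, so the
# stage's block args are duplicated to [2, 3] copies respectively.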
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
arch_args = []
for stack_idx, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list)
stack_args = []
repeats = []
for block_str in block_strings:
assert isinstance(block_str, str)
ba, rep = _decode_block_str(block_str)
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
ba['num_experts'] *= experts_multiplier
stack_args.append(ba)
repeats.append(rep)
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
else:
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
return arch_args
def initialize_weight_goog(m, n='', fix_group_fanout=True):
# weight init as per Tensorflow Official impl
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
if isinstance(m, CondConv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
init_weight_fn = get_condconv_initializer(
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
init_weight_fn(m.weight)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
fan_out = m.weight.size(0) # fan-out
fan_in = 0
if 'routing_fn' in n:
fan_in = m.weight.size(1)
init_range = 1.0 / math.sqrt(fan_in + fan_out)
m.weight.data.uniform_(-init_range, init_range)
m.bias.data.zero_()
def initialize_weight_default(m, n=''):
if isinstance(m, CondConv2d):
init_fn = get_condconv_initializer(partial(
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
init_fn(m.weight)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
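# --- Usage sketch (added for illustration; not part of the original file). ---
# Drives EfficientNetBuilder directly with the block-arg notation documented in
# _decode_block_str; the arch strings and input shape are arbitrary assumptions.
if __name__ == '__main__':
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_se0.25'],  # stage 0: one depthwise-separable block
        ['ir_r2_k3_s2_e6_c24_se0.25'],  # stage 1: two inverted-residual blocks
    ]
    builder = EfficientNetBuilder(act_layer=nn.ReLU)
    blocks = nn.Sequential(*builder(32, decode_arch_def(arch_def)))
    out = blocks(torch.randn(1, 32, 112, 112))
    print(out.shape)  # expected: torch.Size([1, 24, 56, 56])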
================================================
FILE: models/gen_efficientnet.py
================================================
""" Generic Efficient Networks
A generic MobileNet class with building blocks to support a variety of models:
* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports)
- EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
- CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
- Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
- Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
* EfficientNet-Lite
* MixNet (Small, Medium, and Large)
- MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
* MNasNet B1, A1 (SE), Small
- MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626
* FBNet-C
- FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443
* Single-Path NAS Pixel1
- Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877
* And likely more...
Hacked together by Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from .helpers import load_pretrained
from .efficientnet_builder import *
__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140',
'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small',
'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8',
'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el',
'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e',
'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4',
'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3',
'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8',
'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap',
'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap',
'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns',
'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns',
'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475',
'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el',
'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e',
'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3',
'tf_efficientnet_lite4',
'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l']
model_urls = {
'mnasnet_050': None,
'mnasnet_075': None,
'mnasnet_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
'mnasnet_140': None,
'mnasnet_small': None,
'semnasnet_050': None,
'semnasnet_075': None,
'semnasnet_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
'semnasnet_140': None,
'fbnetc_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
'spnasnet_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
'efficientnet_b0':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth',
'efficientnet_b1':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
'efficientnet_b2':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth',
'efficientnet_b3':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra-a5e2fbc7.pth',
'efficientnet_b4': None,
'efficientnet_b5': None,
'efficientnet_b6': None,
'efficientnet_b7': None,
'efficientnet_b8': None,
'efficientnet_l2': None,
'efficientnet_es':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
'efficientnet_em': None,
'efficientnet_el': None,
'efficientnet_cc_b0_4e': None,
'efficientnet_cc_b0_8e': None,
'efficientnet_cc_b1_8e': None,
'efficientnet_lite0': None,
'efficientnet_lite1': None,
'efficientnet_lite2': None,
'efficientnet_lite3': None,
'efficientnet_lite4': None,
'tf_efficientnet_b0':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
'tf_efficientnet_b1':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
'tf_efficientnet_b2':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
'tf_efficientnet_b3':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
'tf_efficientnet_b4':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
'tf_efficientnet_b5':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
'tf_efficientnet_b6':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
'tf_efficientnet_b7':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
'tf_efficientnet_b8':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
'tf_efficientnet_b0_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
'tf_efficientnet_b1_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth',
'tf_efficientnet_b2_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth',
'tf_efficientnet_b3_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth',
'tf_efficientnet_b4_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth',
'tf_efficientnet_b5_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth',
'tf_efficientnet_b6_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth',
'tf_efficientnet_b7_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth',
'tf_efficientnet_b8_ap':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
'tf_efficientnet_b0_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
'tf_efficientnet_b1_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
'tf_efficientnet_b2_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
'tf_efficientnet_b3_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
'tf_efficientnet_b4_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
'tf_efficientnet_b5_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
'tf_efficientnet_b6_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
'tf_efficientnet_b7_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
'tf_efficientnet_l2_ns_475':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
'tf_efficientnet_l2_ns':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
'tf_efficientnet_es':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
'tf_efficientnet_em':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
'tf_efficientnet_el':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
'tf_efficientnet_cc_b0_4e':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth',
'tf_efficientnet_cc_b0_8e':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth',
'tf_efficientnet_cc_b1_8e':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',
'tf_efficientnet_lite0':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
'tf_efficientnet_lite1':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
'tf_efficientnet_lite2':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
'tf_efficientnet_lite3':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
'tf_efficientnet_lite4':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',
'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth',
'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth',
'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth',
'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth',
'tf_mixnet_s':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth',
'tf_mixnet_m':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth',
'tf_mixnet_l':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth',
}
class GenEfficientNet(nn.Module):
""" Generic EfficientNets
An implementation of mobile optimized networks that covers:
* EfficientNet (B0-B8, L2, CondConv, EdgeTPU)
* MixNet (Small, Medium, and Large, XL)
* MNASNet A1, B1, and small
* FBNet C
* Single-Path NAS Pixel1
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
weight_init='goog'):
super(GenEfficientNet, self).__init__()
self.drop_rate = drop_rate
norm_kwargs = norm_kwargs or {}  # guard against **None when the class is constructed directly; the _gen_* factories always pass a dict
if not fix_stem:
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
in_chs = stem_size
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min,
pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate)
self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs
self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type)
self.bn2 = norm_layer(num_features, **norm_kwargs)
self.act2 = act_layer(inplace=True)
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(num_features, num_classes)
for n, m in self.named_modules():
if weight_init == 'goog':
initialize_weight_goog(m, n)
else:
initialize_weight_default(m, n)
self.feature_num = num_features
def features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x)
x = self.conv_head(x)
x = self.bn2(x)
x = self.act2(x)
return x
def as_sequential(self):
layers = [self.conv_stem, self.bn1, self.act1]
layers.extend(self.blocks)
layers.extend([
self.conv_head, self.bn2, self.act2,
self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
def forward(self, x):
feature = self.features(x)
x = self.global_pool(feature)
x = x.flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return x, feature.detach()
def _create_model(model_kwargs, variant, pretrained=False):
as_sequential = model_kwargs.pop('as_sequential', False)
model = GenEfficientNet(**model_kwargs)
if pretrained:
load_pretrained(model, model_urls[variant])
if as_sequential:
model = model.as_sequential()
return model
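# Example (illustrative): models built here return (pooled_features, feature_map)
# rather than logits; the classifier is applied by the caller, e.g.
#     model = efficientnet_b0()
#     pooled, fmap = model(torch.randn(1, 3, 224, 224))
#     logits = model.classifier(pooled)  # shape (1, 1000)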
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a mnasnet-a1 model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
Paper: https://arxiv.org/pdf/1807.11626.pdf.
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_noskip'],
# stage 1, 112x112 in
['ir_r2_k3_s2_e6_c24'],
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25'],
# stage 3, 28x28 in
['ir_r4_k3_s2_e6_c80'],
# stage 4, 14x14 in
['ir_r2_k3_s1_e6_c112_se0.25'],
# stage 5, 14x14 in
['ir_r3_k5_s2_e6_c160_se0.25'],
# stage 6, 7x7 in
['ir_r1_k3_s1_e6_c320'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'relu'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a mnasnet-b1 model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
Paper: https://arxiv.org/pdf/1807.11626.pdf.
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_c16_noskip'],
# stage 1, 112x112 in
['ir_r3_k3_s2_e3_c24'],
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40'],
# stage 3, 28x28 in
['ir_r3_k5_s2_e6_c80'],
# stage 4, 14x14 in
['ir_r2_k3_s1_e6_c96'],
# stage 5, 14x14 in
['ir_r4_k5_s2_e6_c192'],
# stage 6, 7x7 in
['ir_r1_k3_s1_e6_c320_noskip']
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'relu'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a mnasnet-b1 model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
Paper: https://arxiv.org/pdf/1807.11626.pdf.
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
arch_def = [
['ds_r1_k3_s1_c8'],
['ir_r1_k3_s2_e3_c16'],
['ir_r2_k3_s2_e6_c16'],
['ir_r4_k5_s2_e6_c32_se0.25'],
['ir_r3_k3_s1_e6_c32_se0.25'],
['ir_r3_k5_s2_e6_c88_se0.25'],
['ir_r1_k3_s1_e6_c144']
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=8,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'relu'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
""" FBNet-C
Paper: https://arxiv.org/abs/1812.03443
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
NOTE: the ref impl above does not directly correspond to the 'C' variant here; the 'C' config
was derived from the paper, and the ref impl was used to confirm some building block details
"""
arch_def = [
['ir_r1_k3_s1_e1_c16'],
['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
['ir_r4_k5_s2_e6_c184'],
['ir_r1_k3_s1_e6_c352'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=16,
num_features=1984, # paper suggests this value, but it is not 100% clear
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'relu'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates the Single-Path NAS model from search targeted for Pixel1 phone.
Paper: https://arxiv.org/abs/1904.02877
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_c16_noskip'],
# stage 1, 112x112 in
['ir_r3_k3_s2_e3_c24'],
# stage 2, 56x56 in
['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
# stage 3, 28x28 in
['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
# stage 4, 14x14 in
['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
# stage 5, 14x14 in
['ir_r4_k5_s2_e6_c192'],
# stage 6, 7x7 in
['ir_r1_k3_s1_e6_c320_noskip']
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'relu'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25'],
['ir_r2_k3_s2_e6_c24_se0.25'],
['ir_r2_k5_s2_e6_c40_se0.25'],
['ir_r3_k3_s2_e6_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'swish'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
arch_def = [
# NOTE `fc` is present to work around a mismatch between stem channels and in_chs that is
# not present in other models
['er_r1_k3_s1_e4_c24_fc24_noskip'],
['er_r2_k3_s2_e8_c32'],
['er_r4_k3_s2_e8_c48'],
['ir_r5_k5_s2_e8_c96'],
['ir_r4_k5_s1_e8_c144'],
['ir_r2_k5_s2_e8_c192'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'relu'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_efficientnet_condconv(
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
"""Creates an efficientnet-condconv model."""
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25'],
['ir_r2_k3_s2_e6_c24_se0.25'],
['ir_r2_k5_s2_e6_c40_se0.25'],
['ir_r3_k3_s2_e6_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'swish'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates an EfficientNet-Lite model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16'],
['ir_r2_k3_s2_e6_c24'],
['ir_r2_k5_s2_e6_c40'],
['ir_r3_k3_s2_e6_c80'],
['ir_r3_k5_s1_e6_c112'],
['ir_r4_k5_s2_e6_c192'],
['ir_r1_k3_s1_e6_c320'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
num_features=1280,
stem_size=32,
fix_stem=True,
channel_multiplier=channel_multiplier,
act_layer=nn.ReLU6,
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MixNet Small model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
Paper: https://arxiv.org/abs/1907.09595
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'], # relu
# stage 1, 112x112 in
['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu
# stage 2, 56x56 in
['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
# stage 3, 28x28 in
['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish
# stage 4, 14x14 in
['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
# stage 5, 14x14 in
['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
# 7x7
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=1536,
stem_size=16,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'relu'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MixNet Medium-Large model.
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
Paper: https://arxiv.org/abs/1907.09595
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c24'], # relu
# stage 1, 112x112 in
['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu
# stage 2, 56x56 in
['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
# stage 3, 28x28 in
['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish
# stage 4, 14x14 in
['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
# stage 5, 14x14 in
['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
# 7x7
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
num_features=1536,
stem_size=24,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'relu'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def mnasnet_050(pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 0.5. """
model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
return model
def mnasnet_075(pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 0.75. """
model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs)
return model
def mnasnet_100(pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 1.0. """
model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs)
return model
def mnasnet_b1(pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 1.0. """
return mnasnet_100(pretrained, **kwargs)
def mnasnet_140(pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 1.4 """
model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs)
return model
def semnasnet_050(pretrained=False, **kwargs):
""" MNASNet A1 (w/ SE), depth multiplier of 0.5 """
model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs)
return model
def semnasnet_075(pretrained=False, **kwargs):
""" MNASNet A1 (w/ SE), depth multiplier of 0.75. """
model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs)
return model
def semnasnet_100(pretrained=False, **kwargs):
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs)
return model
def mnasnet_a1(pretrained=False, **kwargs):
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
return semnasnet_100(pretrained, **kwargs)
def semnasnet_140(pretrained=False, **kwargs):
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs)
return model
def mnasnet_small(pretrained=False, **kwargs):
""" MNASNet Small, depth multiplier of 1.0. """
model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs)
return model
def fbnetc_100(pretrained=False, **kwargs):
""" FBNet-C """
if pretrained:
# pretrained model trained with non-default BN epsilon
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs)
return model
def spnasnet_100(pretrained=False, **kwargs):
""" Single-Path NAS Pixel1"""
model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs)
return model
def efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0 """
# NOTE for train set drop_rate=0.2, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def efficientnet_b1(pretrained=False, **kwargs):
""" EfficientNet-B1 """
# NOTE for train set drop_rate=0.2, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
def efficientnet_b2(pretrained=False, **kwargs):
""" EfficientNet-B2 """
# NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
def efficientnet_b3(pretrained=False, **kwargs):
""" EfficientNet-B3 """
# NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
def efficientnet_b4(pretrained=False, **kwargs):
""" EfficientNet-B4 """
# NOTE for train set drop_rate=0.4, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
def efficientnet_b5(pretrained=False, **kwargs):
""" EfficientNet-B5 """
# NOTE for train set drop_rate=0.4, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model
def efficientnet_b6(pretrained=False, **kwargs):
""" EfficientNet-B6 """
# NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
return model
def efficientnet_b7(pretrained=False, **kwargs):
""" EfficientNet-B7 """
# NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
return model
def efficientnet_b8(pretrained=False, **kwargs):
""" EfficientNet-B8 """
# NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
model = _gen_efficientnet(
'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
return model
def efficientnet_l2(pretrained=False, **kwargs):
""" EfficientNet-L2. """
# NOTE for train, drop_rate should be 0.5
model = _gen_efficientnet(
'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
return model
def efficientnet_es(pretrained=False, **kwargs):
""" EfficientNet-Edge Small. """
model = _gen_efficientnet_edge(
'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def efficientnet_em(pretrained=False, **kwargs):
""" EfficientNet-Edge-Medium. """
model = _gen_efficientnet_edge(
'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
def efficientnet_el(pretrained=False, **kwargs):
""" EfficientNet-Edge-Large. """
model = _gen_efficientnet_edge(
'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B0 w/ 8 Experts """
# NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
model = _gen_efficientnet_condconv(
'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B0 w/ 8 Experts """
# NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
model = _gen_efficientnet_condconv(
'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
pretrained=pretrained, **kwargs)
return model
def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B1 w/ 8 Experts """
# NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
model = _gen_efficientnet_condconv(
'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
pretrained=pretrained, **kwargs)
return model
def efficientnet_lite0(pretrained=False, **kwargs):
""" EfficientNet-Lite0 """
model = _gen_efficientnet_lite(
'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def efficientnet_lite1(pretrained=False, **kwargs):
""" EfficientNet-Lite1 """
model = _gen_efficientnet_lite(
'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
def efficientnet_lite2(pretrained=False, **kwargs):
""" EfficientNet-Lite2 """
model = _gen_efficientnet_lite(
'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
def efficientnet_lite3(pretrained=False, **kwargs):
""" EfficientNet-Lite3 """
model = _gen_efficientnet_lite(
'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
def efficientnet_lite4(pretrained=False, **kwargs):
""" EfficientNet-Lite4 """
model = _gen_efficientnet_lite(
'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
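# NOTE: the tf_* variants below all force bn_eps=BN_EPS_TF_DEFAULT (Tensorflow's
# 1e-3 BatchNorm epsilon) and pad_type='same' (TF-style asymmetric SAME padding)
# so that weights ported from the official Tensorflow checkpoints match exactly.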
def tf_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0 AutoAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b1(pretrained=False, **kwargs):
""" EfficientNet-B1 AutoAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b2(pretrained=False, **kwargs):
""" EfficientNet-B2 AutoAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b3(pretrained=False, **kwargs):
""" EfficientNet-B3 AutoAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b4(pretrained=False, **kwargs):
""" EfficientNet-B4 AutoAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b5(pretrained=False, **kwargs):
""" EfficientNet-B5 RandAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b6(pretrained=False, **kwargs):
""" EfficientNet-B6 AutoAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b7(pretrained=False, **kwargs):
""" EfficientNet-B7 RandAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b8(pretrained=False, **kwargs):
""" EfficientNet-B8 RandAug. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b0_ap(pretrained=False, **kwargs):
""" EfficientNet-B0 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b1_ap(pretrained=False, **kwargs):
""" EfficientNet-B1 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b2_ap(pretrained=False, **kwargs):
""" EfficientNet-B2 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b3_ap(pretrained=False, **kwargs):
""" EfficientNet-B3 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b4_ap(pretrained=False, **kwargs):
""" EfficientNet-B4 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b5_ap(pretrained=False, **kwargs):
""" EfficientNet-B5 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b6_ap(pretrained=False, **kwargs):
""" EfficientNet-B6 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b7_ap(pretrained=False, **kwargs):
""" EfficientNet-B7 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b8_ap(pretrained=False, **kwargs):
""" EfficientNet-B8 AdvProp. Tensorflow compatible variant
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
"""
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
""" EfficientNet-B0 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b1_ns(pretrained=False, **kwargs):
""" EfficientNet-B1 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b2_ns(pretrained=False, **kwargs):
""" EfficientNet-B2 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b3_ns(pretrained=False, **kwargs):
""" EfficientNet-B3 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
""" EfficientNet-B4 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b5_ns(pretrained=False, **kwargs):
""" EfficientNet-B5 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b6_ns(pretrained=False, **kwargs):
""" EfficientNet-B6 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_b7_ns(pretrained=False, **kwargs):
""" EfficientNet-B7 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
""" EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_l2_ns(pretrained=False, **kwargs):
""" EfficientNet-L2 NoisyStudent. Tensorflow compatible variant
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
"""
# NOTE for train, drop_rate should be 0.5
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_es(pretrained=False, **kwargs):
""" EfficientNet-Edge Small. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_edge(
'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_em(pretrained=False, **kwargs):
""" EfficientNet-Edge-Medium. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_edge(
'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_el(pretrained=False, **kwargs):
""" EfficientNet-Edge-Large. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_edge(
'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B0 w/ 4 Experts """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_condconv(
'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B0 w/ 8 Experts """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_condconv(
'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B1 w/ 8 Experts """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_condconv(
'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_lite0(pretrained=False, **kwargs):
""" EfficientNet-Lite0. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_lite1(pretrained=False, **kwargs):
""" EfficientNet-Lite1. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_lite2(pretrained=False, **kwargs):
""" EfficientNet-Lite2. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_lite3(pretrained=False, **kwargs):
""" EfficientNet-Lite3. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
def tf_efficientnet_lite4(pretrained=False, **kwargs):
""" EfficientNet-Lite4. Tensorflow compatible variant """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_lite(
'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
def mixnet_s(pretrained=False, **kwargs):
"""Creates a MixNet Small model.
"""
# NOTE for train set drop_rate=0.2
model = _gen_mixnet_s(
'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def mixnet_m(pretrained=False, **kwargs):
"""Creates a MixNet Medium model.
"""
# NOTE for train set drop_rate=0.25
model = _gen_mixnet_m(
'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def mixnet_l(pretrained=False, **kwargs):
"""Creates a MixNet Large model.
"""
# NOTE for train set drop_rate=0.25
model = _gen_mixnet_m(
'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
return model
def mixnet_xl(pretrained=False, **kwargs):
"""Creates a MixNet Extra-Large model.
Not a paper spec, experimental def by RW w/ depth scaling.
"""
# NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
model = _gen_mixnet_m(
'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
def mixnet_xxl(pretrained=False, **kwargs):
"""Creates a MixNet Double Extra Large model.
Not a paper spec, experimental def by RW w/ depth scaling.
"""
# NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
model = _gen_mixnet_m(
'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs)
return model
def tf_mixnet_s(pretrained=False, **kwargs):
"""Creates a MixNet Small model. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mixnet_s(
'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def tf_mixnet_m(pretrained=False, **kwargs):
"""Creates a MixNet Medium model. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mixnet_m(
'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
def tf_mixnet_l(pretrained=False, **kwargs):
"""Creates a MixNet Large model. Tensorflow compatible variant
"""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mixnet_m(
'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
return model
================================================
FILE: models/helpers.py
================================================
import torch
import os
from collections import OrderedDict
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
def load_checkpoint(model, checkpoint_path):
if checkpoint_path and os.path.isfile(checkpoint_path):
print("=> Loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(checkpoint)
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
def load_pretrained(model, url, filter_fn=None, strict=True):
if not url:
print("=> Warning: Pretrained model URL is empty, using random initialization.")
return
state_dict = torch.load(url, map_location='cpu')
input_conv = 'conv_stem'
classifier = 'classifier'
in_chans = getattr(model, input_conv).weight.shape[1]
num_classes = getattr(model, classifier).weight.shape[0]
input_conv_weight = input_conv + '.weight'
pretrained_in_chans = state_dict[input_conv_weight].shape[1]
if in_chans != pretrained_in_chans:
if in_chans == 1:
print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
input_conv_weight, pretrained_in_chans))
conv1_weight = state_dict[input_conv_weight]
state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
else:
print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
input_conv_weight, pretrained_in_chans))
del state_dict[input_conv_weight]
strict = False
classifier_weight = classifier + '.weight'
pretrained_num_classes = state_dict[classifier_weight].shape[0]
if num_classes != pretrained_num_classes:
print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
del state_dict[classifier_weight]
del state_dict[classifier + '.bias']
strict = False
if filter_fn is not None:
state_dict = filter_fn(state_dict)
model.load_state_dict(state_dict, strict=strict)
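if __name__ == '__main__':
    # Minimal sketch of the 'module.' prefix handling in load_checkpoint above:
    # checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.', which must be stripped before loading into a bare model.
    sd = OrderedDict([('module.conv_stem.weight', torch.zeros(1)),
                      ('bn1.bias', torch.zeros(1))])
    cleaned = OrderedDict((k[7:] if k.startswith('module') else k, v) for k, v in sd.items())
    assert list(cleaned.keys()) == ['conv_stem.weight', 'bn1.bias']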
================================================
FILE: models/mobilenetv3.py
================================================
""" MobileNet-V3
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from .helpers import load_pretrained
from .efficientnet_builder import *
__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100', 'mobilenetv3_large_125']
model_urls = {
'mobilenetv3_rw':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
'mobilenetv3_large_075': None,
'mobilenetv3_large_100': None,
'mobilenetv3_large_125': None,
'mobilenetv3_large_minimal_100': None,
'mobilenetv3_small_075': None,
'mobilenetv3_small_100': None,
'mobilenetv3_small_minimal_100': None,
'tf_mobilenetv3_large_075':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
'tf_mobilenetv3_large_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
'tf_mobilenetv3_large_minimal_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
'tf_mobilenetv3_small_075':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
'tf_mobilenetv3_small_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
'tf_mobilenetv3_small_minimal_100':
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
}
class MobileNetV3(nn.Module):
""" MobileNet-V3
    This model utilizes the MobileNet-V3 specific 'efficient head', where global pooling is done before the
head convolution without a final batch-norm layer before the classifier.
Paper: https://arxiv.org/abs/1905.02244
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
super(MobileNetV3, self).__init__()
self.drop_rate = drop_rate
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
in_chs = stem_size
builder = EfficientNetBuilder(
channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
self.feature_num = num_features
self.classifier = nn.Linear(num_features, num_classes)
for m in self.modules():
if weight_init == 'goog':
initialize_weight_goog(m)
else:
initialize_weight_default(m)
def as_sequential(self):
layers = [self.conv_stem, self.bn1, self.act1]
layers.extend(self.blocks)
layers.extend([
self.global_pool, self.conv_head, self.act2,
nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
def features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
features = self.blocks(x)
x = self.global_pool(features)
x = self.conv_head(x)
x = self.act2(x)
return x, features.detach()
def forward(self, x):
x, features = self.features(x)
x = x.flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x, features.detach()
def _create_model(model_kwargs, variant, pretrained=False):
as_sequential = model_kwargs.pop('as_sequential', False)
model = MobileNetV3(**model_kwargs)
if pretrained and model_urls[variant]:
load_pretrained(model, model_urls[variant])
if as_sequential:
model = model.as_sequential()
return model
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model (RW variant).
Paper: https://arxiv.org/abs/1905.02244
    This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the
    eventual Tensorflow reference impl but has a few differences:
    1. This model has no bias on the head convolution
    2. This model forces no residual (noskip) on the first DWS block; this is different from MnasNet
    3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer
    from their parent block
    4. This model does not enforce divisible by 8 limitation on the SE reduction channel count
    Overall the changes are fairly minor and result in a very small parameter count difference and no
    top-1/5 performance difference.
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
head_bias=False, # one of my mistakes
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 large/small/minimal models.
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
Paper: https://arxiv.org/abs/1905.02244
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
if 'small' in variant:
num_features = 1024
if 'minimal' in variant:
act_layer = 'relu'
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16'],
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
# stage 2, 28x28 in
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
# stage 3, 14x14 in
['ir_r2_k3_s1_e3_c48'],
# stage 4, 14x14in
['ir_r3_k3_s2_e6_c96'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'],
]
else:
act_layer = 'hard_swish'
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
# stage 2, 28x28 in
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
# stage 3, 14x14 in
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
# stage 4, 14x14in
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'], # hard-swish
]
else:
num_features = 1280
if 'minimal' in variant:
act_layer = 'relu'
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'],
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
# stage 2, 56x56 in
['ir_r3_k3_s2_e3_c40'],
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112'],
# stage 5, 14x14in
['ir_r3_k3_s2_e6_c160'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'],
]
else:
act_layer = 'hard_swish'
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, act_layer),
se_kwargs=dict(
act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, variant, pretrained)
return model
def mobilenetv3_rw(pretrained=False, **kwargs):
""" MobileNet-V3 RW
Attn: See note in gen function for this variant.
"""
# NOTE for train set drop_rate=0.2
if pretrained:
# pretrained model trained with non-default BN epsilon
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 Large 0.75"""
# NOTE for train set drop_rate=0.2
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_large_100(pretrained=False, **kwargs):
""" MobileNet V3 Large 1.0 """
# NOTE for train set drop_rate=0.2
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_large_125(pretrained=False, **kwargs):
""" MobileNet V3 Large 1.25 """
# NOTE for train set drop_rate=0.2
model = _gen_mobilenet_v3('mobilenetv3_large_125', 1.25, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 Large (Minimalistic) 1.0 """
# NOTE for train set drop_rate=0.2
model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_small_075(pretrained=False, **kwargs):
""" MobileNet V3 Small 0.75 """
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_small_100(pretrained=False, **kwargs):
""" MobileNet V3 Small 1.0 """
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model
def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 Small (Minimalistic) 1.0 """
model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 Large 0.75. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
""" MobileNet V3 Large 1.0. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
""" MobileNet V3 Small 0.75. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
""" MobileNet V3 Small 1.0. Tensorflow compat variant."""
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
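if __name__ == '__main__':
    import torch
    # Shape sketch: in this repo, forward() returns (head_features,
    # detached_block_features) rather than logits; GFNet's classifier and policy
    # network consume the two outputs separately.
    m = mobilenetv3_large_100()
    x, feats = m(torch.rand(1, 3, 224, 224))
    print(x.shape, feats.shape)  # torch.Size([1, 1280]) torch.Size([1, 960, 7, 7])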
================================================
FILE: models/model_factory.py
================================================
from .mobilenetv3 import *
from .gen_efficientnet import *
from .helpers import load_checkpoint
def create_model(
model_name='mnasnet_100',
pretrained=None,
num_classes=1000,
in_chans=3,
checkpoint_path='',
**kwargs):
margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)
if model_name in globals():
create_fn = globals()[model_name]
model = create_fn(**margs, **kwargs)
else:
raise RuntimeError('Unknown model (%s)' % model_name)
if checkpoint_path and not pretrained:
load_checkpoint(model, checkpoint_path)
return model
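if __name__ == '__main__':
    # Usage sketch: create_model resolves model_name against the constructors
    # imported above (mobilenetv3_*, efficientnet_*, mixnet_*, mnasnet_*, ...)
    # and forwards any extra kwargs to the chosen constructor.
    model = create_model('mobilenetv3_large_100', num_classes=10)
    print(sum(p.numel() for p in model.parameters()))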
================================================
FILE: models/resnet.py
================================================
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    # the resnext entries were missing, causing a KeyError when loading
    # resnext*(pretrained=True); these are the standard torchvision URLs
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, groups=groups, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.inplanes = 64
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
output = self.avgpool(x)
output = output.view(x.size(0), -1)
return output, x.detach()
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
def resnext50_32x4d(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d']))
return model
def resnext101_32x8d(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d']))
return model
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
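if __name__ == '__main__':
    import torch
    # Note: unlike torchvision's ResNet, forward() here skips the fc classifier
    # and returns (pooled_features, detached_feature_map) for GFNet's heads.
    net = resnet18()
    feat, fmap = net(torch.rand(1, 3, 224, 224))
    print(feat.shape, fmap.shape)  # torch.Size([1, 512]) torch.Size([1, 512, 7, 7])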
================================================
FILE: models/version.py
================================================
__version__ = '0.9.8'
================================================
FILE: network.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import math
class Memory:
def __init__(self):
self.actions = []
self.states = []
self.logprobs = []
self.rewards = []
self.is_terminals = []
self.hidden = []
def clear_memory(self):
del self.actions[:]
del self.states[:]
del self.logprobs[:]
del self.rewards[:]
del self.is_terminals[:]
del self.hidden[:]
class ActorCritic(nn.Module):
def __init__(self, feature_dim, state_dim, hidden_state_dim=1024, policy_conv=True, action_std=0.1):
super(ActorCritic, self).__init__()
# encoder with convolution layer for MobileNetV3, EfficientNet and RegNet
if policy_conv:
self.state_encoder = nn.Sequential(
nn.Conv2d(feature_dim, 32, kernel_size=1, stride=1, padding=0, bias=False),
nn.ReLU(),
nn.Flatten(),
                nn.Linear(int(state_dim * 32 / feature_dim), hidden_state_dim),  # 32 * H * W inputs, where H * W = state_dim / feature_dim
nn.ReLU()
)
# encoder with linear layer for ResNet and DenseNet
else:
self.state_encoder = nn.Sequential(
nn.Linear(state_dim, 2048),
nn.ReLU(),
nn.Linear(2048, hidden_state_dim),
nn.ReLU()
)
self.gru = nn.GRU(hidden_state_dim, hidden_state_dim, batch_first=False)
self.actor = nn.Sequential(
nn.Linear(hidden_state_dim, 2),
nn.Sigmoid())
self.critic = nn.Sequential(
nn.Linear(hidden_state_dim, 1))
        # NOTE: despite the name, this holds the action std; it is used directly as
        # the diagonal of scale_tril (the Cholesky factor) in the Gaussian policy below
        self.action_var = torch.full((2,), action_std).cuda()
self.hidden_state_dim = hidden_state_dim
self.policy_conv = policy_conv
self.feature_dim = feature_dim
self.feature_ratio = int(math.sqrt(state_dim/feature_dim))
def forward(self):
raise NotImplementedError
def act(self, state_ini, memory, restart_batch=False, training=False):
if restart_batch:
del memory.hidden[:]
memory.hidden.append(torch.zeros(1, state_ini.size(0), self.hidden_state_dim).cuda())
if not self.policy_conv:
state = state_ini.flatten(1)
else:
state = state_ini
state = self.state_encoder(state)
state, hidden_output = self.gru(state.view(1, state.size(0), state.size(1)), memory.hidden[-1])
memory.hidden.append(hidden_output)
state = state[0]
action_mean = self.actor(state)
cov_mat = torch.diag(self.action_var).cuda()
dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat)
action = dist.sample().cuda()
if training:
action = F.relu(action)
action = 1 - F.relu(1 - action)
action_logprob = dist.log_prob(action).cuda()
memory.states.append(state_ini)
memory.actions.append(action)
memory.logprobs.append(action_logprob)
else:
action = action_mean
return action.detach()
def evaluate(self, state, action):
seq_l = state.size(0)
batch_size = state.size(1)
if not self.policy_conv:
state = state.flatten(2)
state = state.view(seq_l * batch_size, state.size(2))
else:
state = state.view(seq_l * batch_size, state.size(2), state.size(3), state.size(4))
state = self.state_encoder(state)
state = state.view(seq_l, batch_size, -1)
state, hidden = self.gru(state, torch.zeros(1, batch_size, state.size(2)).cuda())
state = state.view(seq_l * batch_size, -1)
action_mean = self.actor(state)
cov_mat = torch.diag(self.action_var).cuda()
dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat)
action_logprobs = dist.log_prob(torch.squeeze(action.view(seq_l * batch_size, -1))).cuda()
dist_entropy = dist.entropy().cuda()
state_value = self.critic(state)
return action_logprobs.view(seq_l, batch_size), \
state_value.view(seq_l, batch_size), \
dist_entropy.view(seq_l, batch_size)
class PPO:
def __init__(self, feature_dim, state_dim, hidden_state_dim, policy_conv,
action_std=0.1, lr=0.0003, betas=(0.9, 0.999), gamma=0.7, K_epochs=1, eps_clip=0.2):
self.lr = lr
self.betas = betas
self.gamma = gamma
self.eps_clip = eps_clip
self.K_epochs = K_epochs
self.policy = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda()
self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)
self.policy_old = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda()
self.policy_old.load_state_dict(self.policy.state_dict())
self.MseLoss = nn.MSELoss()
def select_action(self, state, memory, restart_batch=False, training=True):
return self.policy_old.act(state, memory, restart_batch, training)
def update(self, memory):
rewards = []
discounted_reward = 0
for reward in reversed(memory.rewards):
discounted_reward = reward + (self.gamma * discounted_reward)
rewards.insert(0, discounted_reward)
rewards = torch.cat(rewards, 0).cuda()
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
old_states = torch.stack(memory.states, 0).cuda().detach()
old_actions = torch.stack(memory.actions, 0).cuda().detach()
old_logprobs = torch.stack(memory.logprobs, 0).cuda().detach()
for _ in range(self.K_epochs):
logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            # PPO clipped surrogate objective (https://arxiv.org/abs/1707.06347):
            # r_t = pi(a|s) / pi_old(a|s); loss = -min(r_t * A, clip(r_t, 1 - eps, 1 + eps) * A),
            # plus a 0.5-weighted value MSE term, minus a 0.01-weighted entropy bonus
            ratios = torch.exp(logprobs - old_logprobs.detach())
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy
self.optimizer.zero_grad()
loss.mean().backward()
self.optimizer.step()
self.policy_old.load_state_dict(self.policy.state_dict())
class Full_layer(torch.nn.Module):
def __init__(self, feature_num, hidden_state_dim=1024, fc_rnn=True, class_num=1000):
super(Full_layer, self).__init__()
self.class_num = class_num
self.feature_num = feature_num
self.hidden_state_dim = hidden_state_dim
self.hidden = None
self.fc_rnn = fc_rnn
# classifier with RNN for ResNet, DenseNet and RegNet
if fc_rnn:
self.rnn = nn.GRU(feature_num, self.hidden_state_dim)
self.fc = nn.Linear(self.hidden_state_dim, class_num)
# cascaded classifier for MobileNetV3 and EfficientNet
else:
self.fc_2 = nn.Linear(self.feature_num * 2, class_num)
self.fc_3 = nn.Linear(self.feature_num * 3, class_num)
self.fc_4 = nn.Linear(self.feature_num * 4, class_num)
self.fc_5 = nn.Linear(self.feature_num * 5, class_num)
def forward(self, x, restart=False):
if self.fc_rnn:
if restart:
output, h_n = self.rnn(x.view(1, x.size(0), x.size(1)), torch.zeros(1, x.size(0), self.hidden_state_dim).cuda())
self.hidden = h_n
else:
output, h_n = self.rnn(x.view(1, x.size(0), x.size(1)), self.hidden)
self.hidden = h_n
return self.fc(output[0])
else:
if restart:
self.hidden = x
else:
self.hidden = torch.cat([self.hidden, x], 1)
            if self.hidden.size(1) == self.feature_num:
                # only one feature chunk buffered so far; no cascaded prediction yet
                return None
elif self.hidden.size(1) == self.feature_num * 2:
return self.fc_2(self.hidden)
elif self.hidden.size(1) == self.feature_num * 3:
return self.fc_3(self.hidden)
elif self.hidden.size(1) == self.feature_num * 4:
return self.fc_4(self.hidden)
elif self.hidden.size(1) == self.feature_num * 5:
return self.fc_5(self.hidden)
else:
print(self.hidden.size())
exit()
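if __name__ == '__main__' and torch.cuda.is_available():
    # Minimal interaction sketch (a hypothetical toy episode, not the actual
    # training loop in train.py). ActorCritic/PPO allocate tensors with .cuda(),
    # so a GPU is required; dimensions are illustrative (64-channel 7x7 states).
    memory = Memory()
    ppo = PPO(feature_dim=64, state_dim=64 * 7 * 7, hidden_state_dim=256, policy_conv=True)
    for t in range(3):  # a 3-step episode over a batch of 2 samples
        state = torch.rand(2, 64, 7, 7).cuda()
        action = ppo.select_action(state, memory, restart_batch=(t == 0), training=True)
        # the per-step reward is stored with shape (1, batch) so that torch.cat
        # in update() yields a (T, batch) tensor matching the evaluated values
        memory.rewards.append(torch.rand(1, 2))
        memory.is_terminals.append(t == 2)
    ppo.update(memory)
    memory.clear_memory()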
================================================
FILE: pycls/__init__.py
================================================
================================================
FILE: pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml
================================================
MODEL:
TYPE: regnet
NUM_CLASSES: 1000
REGNET:
SE_ON: True
DEPTH: 27
W0: 48
WA: 20.71
WM: 2.65
GROUP_W: 24
OPTIM:
LR_POLICY: cos
BASE_LR: 0.8
MAX_EPOCH: 100
MOMENTUM: 0.9
WEIGHT_DECAY: 5e-5
WARMUP_EPOCHS: 5
TRAIN:
DATASET: imagenet
IM_SIZE: 224
BATCH_SIZE: 1024
TEST:
DATASET: imagenet
IM_SIZE: 256
BATCH_SIZE: 800
NUM_GPUS: 1
OUT_DIR: .
================================================
FILE: pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml
================================================
MODEL:
TYPE: regnet
NUM_CLASSES: 1000
REGNET:
SE_ON: True
DEPTH: 15
W0: 48
WA: 32.54
WM: 2.32
GROUP_W: 16
OPTIM:
LR_POLICY: cos
BASE_LR: 0.8
MAX_EPOCH: 100
MOMENTUM: 0.9
WEIGHT_DECAY: 5e-5
WARMUP_EPOCHS: 5
TRAIN:
DATASET: imagenet
IM_SIZE: 224
BATCH_SIZE: 1024
TEST:
DATASET: imagenet
IM_SIZE: 256
BATCH_SIZE: 800
NUM_GPUS: 1
OUT_DIR: .
================================================
FILE: pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml
================================================
MODEL:
TYPE: regnet
NUM_CLASSES: 1000
REGNET:
SE_ON: True
DEPTH: 14
W0: 56
WA: 38.84
WM: 2.4
GROUP_W: 16
OPTIM:
LR_POLICY: cos
BASE_LR: 0.8
MAX_EPOCH: 100
MOMENTUM: 0.9
WEIGHT_DECAY: 5e-5
WARMUP_EPOCHS: 5
TRAIN:
DATASET: imagenet
IM_SIZE: 224
BATCH_SIZE: 1024
TEST:
DATASET: imagenet
IM_SIZE: 256
BATCH_SIZE: 800
NUM_GPUS: 1
OUT_DIR: .
================================================
FILE: pycls/core/__init__.py
================================================
================================================
FILE: pycls/core/config.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Configuration file (powered by YACS)."""
import os
from pycls.utils.io import cache_url
from yacs.config import CfgNode as CN
# Global config object
_C = CN()
# Example usage:
# from core.config import cfg
cfg = _C
# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = ""
# Number of weight layers
_C.MODEL.DEPTH = 0
# Number of classes
_C.MODEL.NUM_CLASSES = 10
# Loss function (see pycls/models/loss.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"
# ---------------------------------------------------------------------------- #
# ResNet options
# ---------------------------------------------------------------------------- #
_C.RESNET = CN()
# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"
# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1
# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64
# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True
# ---------------------------------------------------------------------------- #
# AnyNet options
# ---------------------------------------------------------------------------- #
_C.ANYNET = CN()
# Stem type
_C.ANYNET.STEM_TYPE = "plain_block"
# Stem width
_C.ANYNET.STEM_W = 32
# Block type
_C.ANYNET.BLOCK_TYPE = "plain_block"
# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []
# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []
# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []
# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []
# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False
# SE ratio
_C.ANYNET.SE_R = 0.25
# ---------------------------------------------------------------------------- #
# RegNet options
# ---------------------------------------------------------------------------- #
_C.REGNET = CN()
# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.REGNET.STEM_W = 32
# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
# Stride of each stage
_C.REGNET.STRIDE = 2
# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25
# Depth
_C.REGNET.DEPTH = 10
# Initial width
_C.REGNET.W0 = 32
# Slope
_C.REGNET.WA = 5.0
# Quantization
_C.REGNET.WM = 2.5
# Group width
_C.REGNET.GROUP_W = 16
# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0
# ---------------------------------------------------------------------------- #
# EfficientNet options
# ---------------------------------------------------------------------------- #
_C.EN = CN()
# Stem width
_C.EN.STEM_W = 32
# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []
# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []
# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25
# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []
# Kernel sizes for each stage
_C.EN.KERNELS = []
# Head width
_C.EN.HEAD_W = 1280
# Drop connect ratio
_C.EN.DC_RATIO = 0.0
# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0
# ---------------------------------------------------------------------------- #
# Batch norm options
# ---------------------------------------------------------------------------- #
_C.BN = CN()
# BN epsilon
_C.BN.EPS = 1e-5
# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1
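# (e.g. the default PyTorch BN.MOM of 0.1 corresponds to a Caffe2 momentum of 0.9)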
# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024
# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False
# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
# ---------------------------------------------------------------------------- #
# Optimizer options
# ---------------------------------------------------------------------------- #
_C.OPTIM = CN()
# Base learning rate
_C.OPTIM.BASE_LR = 0.1
# Learning rate policy select from {'cos', 'exp', 'steps'}
_C.OPTIM.LR_POLICY = "cos"
# Exponential decay factor
_C.OPTIM.GAMMA = 0.1
# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []
# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1
# Maximum number of epochs
_C.OPTIM.MAX_EPOCH = 200
# Momentum
_C.OPTIM.MOMENTUM = 0.9
# Momentum dampening
_C.OPTIM.DAMPENING = 0.0
# Nesterov momentum
_C.OPTIM.NESTEROV = True
# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4
# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1
# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0
# ---------------------------------------------------------------------------- #
# Training options
# ---------------------------------------------------------------------------- #
_C.TRAIN = CN()
# Dataset and split
_C.TRAIN.DATASET = ""
_C.TRAIN.SPLIT = "train"
# Total mini-batch size
_C.TRAIN.BATCH_SIZE = 128
# Image size
_C.TRAIN.IM_SIZE = 224
# Evaluate model on test data every eval period epochs
_C.TRAIN.EVAL_PERIOD = 1
# Save model checkpoint every checkpoint period epochs
_C.TRAIN.CHECKPOINT_PERIOD = 1
# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = True
# Weights to start training from
_C.TRAIN.WEIGHTS = ""
# ---------------------------------------------------------------------------- #
# Testing options
# ---------------------------------------------------------------------------- #
_C.TEST = CN()
# Dataset and split
_C.TEST.DATASET = ""
_C.TEST.SPLIT = "val"
# Total mini-batch size
_C.TEST.BATCH_SIZE = 200
# Image size
_C.TEST.IM_SIZE = 256
# Weights to use for testing
_C.TEST.WEIGHTS = ""
# ---------------------------------------------------------------------------- #
# Common train/test data loader options
# ---------------------------------------------------------------------------- #
_C.DATA_LOADER = CN()
# Number of data loader workers per training process
_C.DATA_LOADER.NUM_WORKERS = 4
# Load data to pinned host memory
_C.DATA_LOADER.PIN_MEMORY = True
# ---------------------------------------------------------------------------- #
# Memory options
# ---------------------------------------------------------------------------- #
_C.MEM = CN()
# Perform ReLU inplace
_C.MEM.RELU_INPLACE = True
# ---------------------------------------------------------------------------- #
# CUDNN options
# ---------------------------------------------------------------------------- #
_C.CUDNN = CN()
# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN.BENCHMARK = True
# ---------------------------------------------------------------------------- #
# Precise timing options
# ---------------------------------------------------------------------------- #
_C.PREC_TIME = CN()
# Perform precise timing at the start of training
_C.PREC_TIME.ENABLED = False
# Total mini-batch size
_C.PREC_TIME.BATCH_SIZE = 128
# Number of iterations to warm up the caches
_C.PREC_TIME.WARMUP_ITER = 3
# Number of iterations to compute avg time
_C.PREC_TIME.NUM_ITER = 30
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1
# Output directory
_C.OUT_DIR = "/tmp"
# Config destination (in OUT_DIR)
_C.CFG_DEST = "config.yaml"
# Note that non-determinism may still be present due to non-deterministic
# operator implementations in GPU operator libraries
_C.RNG_SEED = 1
# Log destination ('stdout' or 'file')
_C.LOG_DEST = "stdout"
# Log period in iters
_C.LOG_PERIOD = 10
# Distributed backend
_C.DIST_BACKEND = "nccl"
# Hostname and port for initializing multi-process groups
_C.HOST = "localhost"
_C.PORT = 10001
# Models weights referred to by URL are downloaded to this local cache
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
def assert_and_infer_cfg(cache_urls=True):
"""Checks config values invariants."""
assert (
not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0
), "The first lr step must start at 0"
assert _C.TRAIN.SPLIT in [
"train",
"val",
"test",
], "Train split '{}' not supported".format(_C.TRAIN.SPLIT)
assert (
_C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0
), "Train mini-batch size should be a multiple of NUM_GPUS."
assert _C.TEST.SPLIT in [
"train",
"val",
"test",
], "Test split '{}' not supported".format(_C.TEST.SPLIT)
assert (
_C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0
), "Test mini-batch size should be a multiple of NUM_GPUS."
assert (
not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1
), "Precise BN stats computation not verified for > 1 GPU"
assert _C.LOG_DEST in [
"stdout",
"file",
], "Log destination '{}' not supported".format(_C.LOG_DEST)
assert (
not _C.PREC_TIME.ENABLED or _C.NUM_GPUS == 1
), "Precise iter time computation not verified for > 1 GPU"
if cache_urls:
cache_cfg_urls()
def cache_cfg_urls():
"""Download URLs in the config, cache them locally, and rewrite cfg to make
use of the locally cached file.
"""
_C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
_C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
def dump_cfg():
"""Dumps the config to the output directory."""
cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
with open(cfg_file, "w") as f:
_C.dump(stream=f)
def load_cfg(out_dir, cfg_dest="config.yaml"):
"""Loads config from specified output directory."""
cfg_file = os.path.join(out_dir, cfg_dest)
_C.merge_from_file(cfg_file)
================================================
FILE: pycls/core/losses.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Loss functions."""
import torch.nn as nn
from pycls.core.config import cfg
# Supported loss functions
_loss_funs = {"cross_entropy": nn.CrossEntropyLoss}
def get_loss_fun():
"""Retrieves the loss function."""
    assert (
        cfg.MODEL.LOSS_FUN in _loss_funs.keys()
    ), "Loss function '{}' not supported".format(cfg.MODEL.LOSS_FUN)
return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda()
def register_loss_fun(name, ctor):
"""Registers a loss function dynamically."""
_loss_funs[name] = ctor
================================================
FILE: pycls/core/model_builder.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Model construction functions."""
import pycls.utils.logging as lu
import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet
logger = lu.get_logger(__name__)
# Supported models
_models = {"anynet": AnyNet, "effnet": EffNet, "resnet": ResNet, "regnet": RegNet}
def build_model():
"""Builds the model."""
assert cfg.MODEL.TYPE in _models.keys(), "Model type '{}' not supported".format(
cfg.MODEL.TYPE
)
# print(torch.cuda.device_count())
# print(torch.cuda.current_device())
assert (
cfg.NUM_GPUS <= torch.cuda.device_count()
), "Cannot use more GPU devices than available"
# Construct the model
model = _models[cfg.MODEL.TYPE]()
# Determine the GPU used by the current process
cur_device = torch.cuda.current_device()
# Transfer the model to the current GPU device
model = model.cuda(device=cur_device)
# Use multi-process data parallel model in the multi-gpu setting
if cfg.NUM_GPUS > 1:
# Make model replica operate on the current device
model = torch.nn.parallel.DistributedDataParallel(
module=model, device_ids=[cur_device], output_device=cur_device
)
return model
def register_model(name, ctor):
"""Registers a model dynamically."""
_models[name] = ctor
================================================
FILE: pycls/core/old_config.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Configuration file (powered by YACS)."""
import os
from pycls.utils.io import cache_url
from yacs.config import CfgNode as CN
# Global config object
_C = CN()
# Example usage:
# from core.config import cfg
cfg = _C
# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = ""
# Number of weight layers
_C.MODEL.DEPTH = 0
# Number of classes
_C.MODEL.NUM_CLASSES = 10
# Loss function (see pycls/models/loss.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"
# ---------------------------------------------------------------------------- #
# ResNet options
# ---------------------------------------------------------------------------- #
_C.RESNET = CN()
# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"
# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1
# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64
# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True
# ---------------------------------------------------------------------------- #
# AnyNet options
# ---------------------------------------------------------------------------- #
_C.ANYNET = CN()
# Stem type
_C.ANYNET.STEM_TYPE = "plain_block"
# Stem width
_C.ANYNET.STEM_W = 32
# Block type
_C.ANYNET.BLOCK_TYPE = "plain_block"
# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []
# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []
# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []
# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []
# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False
# SE ratio
_C.ANYNET.SE_R = 0.25
# ---------------------------------------------------------------------------- #
# RegNet options
# ---------------------------------------------------------------------------- #
_C.REGNET = CN()
# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.REGNET.STEM_W = 32
# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
# Stride of each stage
_C.REGNET.STRIDE = 2
# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25
# Depth
_C.REGNET.DEPTH = 10
# Initial width
_C.REGNET.W0 = 32
# Slope
_C.REGNET.WA = 5.0
# Quantization
_C.REGNET.WM = 2.5
# Group width
_C.REGNET.GROUP_W = 16
# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0
# ---------------------------------------------------------------------------- #
# EfficientNet options
# ---------------------------------------------------------------------------- #
_C.EN = CN()
# Stem width
_C.EN.STEM_W = 32
# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []
# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []
# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25
# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []
# Kernel sizes for each stage
_C.EN.KERNELS = []
# Head width
_C.EN.HEAD_W = 1280
# Drop connect ratio
_C.EN.DC_RATIO = 0.0
# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0
# ---------------------------------------------------------------------------- #
# Batch norm options
# ---------------------------------------------------------------------------- #
_C.BN = CN()
# BN epsilon
_C.BN.EPS = 1e-5
# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1
# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024
# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False
# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
# ---------------------------------------------------------------------------- #
# Optimizer options
# ---------------------------------------------------------------------------- #
_C.OPTIM = CN()
# Base learning rate
_C.OPTIM.BASE_LR = 0.1
# Learning rate policy select from {'cos', 'exp', 'steps'}
_C.OPTIM.LR_POLICY = "cos"
# Exponential decay factor
_C.OPTIM.GAMMA = 0.1
# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []
# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1
# Maximum number of epochs
_C.OPTIM.MAX_EPOCH = 200
# Momentum
_C.OPTIM.MOMENTUM = 0.9
# Momentum dampening
_C.OPTIM.DAMPENING = 0.0
# Nesterov momentum
_C.OPTIM.NESTEROV = True
# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4
# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1
# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0
# ---------------------------------------------------------------------------- #
# Training options
# ---------------------------------------------------------------------------- #
_C.TRAIN = CN()
# Dataset and split
_C.TRAIN.DATASET = ""
_C.TRAIN.SPLIT = "train"
# Total mini-batch size
_C.TRAIN.BATCH_SIZE = 128
# Image size
_C.TRAIN.IM_SIZE = 224
# Evaluate model on test data every eval period epochs
_C.TRAIN.EVAL_PERIOD = 1
# Save model checkpoint every checkpoint period epochs
_C.TRAIN.CHECKPOINT_PERIOD = 1
# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = True
# Weights to start training from
_C.TRAIN.WEIGHTS = ""
# ---------------------------------------------------------------------------- #
# Testing options
# ---------------------------------------------------------------------------- #
_C.TEST = CN()
# Dataset and split
_C.TEST.DATASET = ""
_C.TEST.SPLIT = "val"
# Total mini-batch size
_C.TEST.BATCH_SIZE = 200
# Image size
_C.TEST.IM_SIZE = 256
# Weights to use for testing
_C.TEST.WEIGHTS = ""
# ---------------------------------------------------------------------------- #
# Common train/test data loader options
# ---------------------------------------------------------------------------- #
_C.DATA_LOADER = CN()
# Number of data loader workers per training process
_C.DATA_LOADER.NUM_WORKERS = 4
# Load data to pinned host memory
_C.DATA_LOADER.PIN_MEMORY = True
# ---------------------------------------------------------------------------- #
# Memory options
# ---------------------------------------------------------------------------- #
_C.MEM = CN()
# Perform ReLU inplace
_C.MEM.RELU_INPLACE = True
# ---------------------------------------------------------------------------- #
# CUDNN options
# ---------------------------------------------------------------------------- #
_C.CUDNN = CN()
# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN.BENCHMARK = True
# ---------------------------------------------------------------------------- #
# Precise timing options
# ---------------------------------------------------------------------------- #
_C.PREC_TIME = CN()
# Perform precise timing at the start of training
_C.PREC_TIME.ENABLED = False
# Total mini-batch size
_C.PREC_TIME.BATCH_SIZE = 128
# Number of iterations to warm up the caches
_C.PREC_TIME.WARMUP_ITER = 3
# Number of iterations to compute avg time
_C.PREC_TIME.NUM_ITER = 30
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1
# Output directory
_C.OUT_DIR = "/tmp"
# Config destination (in OUT_DIR)
_C.CFG_DEST = "config.yaml"
# Note that non-determinism may still be present due to non-deterministic
# operator implementations in GPU operator libraries
_C.RNG_SEED = 1
# Log destination ('stdout' or 'file')
_C.LOG_DEST = "stdout"
# Log period in iters
_C.LOG_PERIOD = 10
# Distributed backend
_C.DIST_BACKEND = "nccl"
# Hostname and port for initializing multi-process groups
_C.HOST = "localhost"
_C.PORT = 10001
# Models weights referred to by URL are downloaded to this local cache
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
def assert_and_infer_cfg(cache_urls=True):
"""Checks config values invariants."""
assert (
not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0
), "The first lr step must start at 0"
assert _C.TRAIN.SPLIT in [
"train",
"val",
"test",
], "Train split '{}' not supported".format(_C.TRAIN.SPLIT)
assert (
_C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0
), "Train mini-batch size should be a multiple of NUM_GPUS."
assert _C.TEST.SPLIT in [
"train",
"val",
"test",
], "Test split '{}' not supported".format(_C.TEST.SPLIT)
assert (
_C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0
), "Test mini-batch size should be a multiple of NUM_GPUS."
assert (
not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1
), "Precise BN stats computation not verified for > 1 GPU"
assert _C.LOG_DEST in [
"stdout",
"file",
], "Log destination '{}' not supported".format(_C.LOG_DEST)
assert (
not _C.PREC_TIME.ENABLED or _C.NUM_GPUS == 1
), "Precise iter time computation not verified for > 1 GPU"
if cache_urls:
cache_cfg_urls()
def cache_cfg_urls():
"""Download URLs in the config, cache them locally, and rewrite cfg to make
use of the locally cached file.
"""
_C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
_C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
def dump_cfg():
"""Dumps the config to the output directory."""
cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
with open(cfg_file, "w") as f:
_C.dump(stream=f)
def load_cfg(out_dir, cfg_dest="config.yaml"):
"""Loads config from specified output directory."""
cfg_file = os.path.join(out_dir, cfg_dest)
_C.merge_from_file(cfg_file)
================================================
FILE: pycls/core/optimizer.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Optimizer."""
import pycls.utils.lr_policy as lr_policy
import torch
from pycls.core.config import cfg
def construct_optimizer(model):
"""Constructs the optimizer.
Note that the momentum update in PyTorch differs from the one in Caffe2.
In particular,
Caffe2:
V := mu * V + lr * g
p := p - V
PyTorch:
V := mu * V + g
p := p - lr * V
where V is the velocity, mu is the momentum factor, lr is the learning rate,
g is the gradient and p are the parameters.
Since V is defined independently of the learning rate in PyTorch,
when the learning rate is changed there is no need to perform the
momentum correction by scaling V (unlike in the Caffe2 case).
"""
# Batchnorm parameters.
bn_params = []
# Non-batchnorm parameters.
non_bn_parameters = []
for name, p in model.named_parameters():
if "bn" in name:
bn_params.append(p)
else:
non_bn_parameters.append(p)
# Apply different weight decay to Batchnorm and non-batchnorm parameters.
bn_weight_decay = (
cfg.BN.CUSTOM_WEIGHT_DECAY
if cfg.BN.USE_CUSTOM_WEIGHT_DECAY
else cfg.OPTIM.WEIGHT_DECAY
)
optim_params = [
{"params": bn_params, "weight_decay": bn_weight_decay},
{"params": non_bn_parameters, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
]
# Check all parameters will be passed into optimizer.
assert len(list(model.parameters())) == len(non_bn_parameters) + len(
bn_params
), "parameter size does not match: {} + {} != {}".format(
len(non_bn_parameters), len(bn_params), len(list(model.parameters()))
)
return torch.optim.SGD(
optim_params,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV,
)
def get_epoch_lr(cur_epoch):
"""Retrieves the lr for the given epoch (as specified by the lr policy)."""
return lr_policy.get_epoch_lr(cur_epoch)
def set_lr(optimizer, new_lr):
"""Sets the optimizer lr to the specified value."""
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr
================================================
FILE: pycls/datasets/__init__.py
================================================
================================================
FILE: pycls/datasets/cifar10.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""CIFAR10 dataset."""
import os
import pickle
import numpy as np
import pycls.datasets.transforms as transforms
import pycls.utils.logging as lu
import torch
import torch.utils.data
from pycls.core.config import cfg
logger = lu.get_logger(__name__)
# Per-channel mean and SD values in BGR order
_MEAN = [125.3, 123.0, 113.9]
_SD = [63.0, 62.1, 66.7]
class Cifar10(torch.utils.data.Dataset):
"""CIFAR-10 dataset."""
def __init__(self, data_path, split):
assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
assert split in ["train", "test"], "Split '{}' not supported for cifar".format(
split
)
logger.info("Constructing CIFAR-10 {}...".format(split))
self._data_path = data_path
self._split = split
# Data format:
# self._inputs - (split_size, 3, im_size, im_size) ndarray
# self._labels - split_size list
self._inputs, self._labels = self._load_data()
def _load_batch(self, batch_path):
with open(batch_path, "rb") as f:
d = pickle.load(f, encoding="bytes")
return d[b"data"], d[b"labels"]
def _load_data(self):
"""Loads data in memory."""
logger.info("{} data path: {}".format(self._split, self._data_path))
# Compute data batch names
if self._split == "train":
batch_names = ["data_batch_{}".format(i) for i in range(1, 6)]
else:
batch_names = ["test_batch"]
# Load data batches
inputs, labels = [], []
for batch_name in batch_names:
batch_path = os.path.join(self._data_path, batch_name)
inputs_batch, labels_batch = self._load_batch(batch_path)
inputs.append(inputs_batch)
labels += labels_batch
# Combine and reshape the inputs
inputs = np.vstack(inputs).astype(np.float32)
inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE))
return inputs, labels
def _prepare_im(self, im):
"""Prepares the image for network input."""
im = transforms.color_norm(im, _MEAN, _SD)
if self._split == "train":
im = transforms.horizontal_flip(im=im, p=0.5)
im = transforms.random_crop(im=im, size=cfg.TRAIN.IM_SIZE, pad_size=4)
return im
def __getitem__(self, index):
im, label = self._inputs[index, ...].copy(), self._labels[index]
im = self._prepare_im(im)
return im, label
def __len__(self):
return self._inputs.shape[0]
================================================
FILE: pycls/datasets/imagenet.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""ImageNet dataset."""
import os
import re
import cv2
import numpy as np
import pycls.datasets.transforms as transforms
import pycls.utils.logging as lu
import torch
import torch.utils.data
from pycls.core.config import cfg
logger = lu.get_logger(__name__)
# Per-channel mean and SD values in BGR order
_MEAN = [0.406, 0.456, 0.485]
_SD = [0.225, 0.224, 0.229]
# Eig vals and vecs of the cov mat
_EIG_VALS = np.array([[0.2175, 0.0188, 0.0045]])
_EIG_VECS = np.array(
[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
)
class ImageNet(torch.utils.data.Dataset):
"""ImageNet dataset."""
def __init__(self, data_path, split):
assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
assert split in [
"train",
"val",
], "Split '{}' not supported for ImageNet".format(split)
logger.info("Constructing ImageNet {}...".format(split))
self._data_path = data_path
self._split = split
self._construct_imdb()
def _construct_imdb(self):
"""Constructs the imdb."""
# Compile the split data path
split_path = os.path.join(self._data_path, self._split)
logger.info("{} data path: {}".format(self._split, split_path))
        # Images are stored per class in subdirs (format: n<number>)
self._class_ids = sorted(
f for f in os.listdir(split_path) if re.match(r"^n[0-9]+$", f)
)
# Map ImageNet class ids to contiguous ids
self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
# Construct the image db
self._imdb = []
for class_id in self._class_ids:
cont_id = self._class_id_cont_id[class_id]
im_dir = os.path.join(split_path, class_id)
for im_name in os.listdir(im_dir):
self._imdb.append(
{"im_path": os.path.join(im_dir, im_name), "class": cont_id}
)
logger.info("Number of images: {}".format(len(self._imdb)))
logger.info("Number of classes: {}".format(len(self._class_ids)))
def _prepare_im(self, im):
"""Prepares the image for network input."""
# Train and test setups differ
if self._split == "train":
# Scale and aspect ratio
im = transforms.random_sized_crop(
im=im, size=cfg.TRAIN.IM_SIZE, area_frac=0.08
)
# Horizontal flip
im = transforms.horizontal_flip(im=im, p=0.5, order="HWC")
else:
# Scale and center crop
im = transforms.scale(cfg.TEST.IM_SIZE, im)
im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im)
# HWC -> CHW
im = im.transpose([2, 0, 1])
# [0, 255] -> [0, 1]
im = im / 255.0
# PCA jitter
if self._split == "train":
im = transforms.lighting(im, 0.1, _EIG_VALS, _EIG_VECS)
# Color normalization
im = transforms.color_norm(im, _MEAN, _SD)
return im
def __getitem__(self, index):
# Load the image
im = cv2.imread(self._imdb[index]["im_path"])
im = im.astype(np.float32, copy=False)
# Prepare the image for training / testing
im = self._prepare_im(im)
# Retrieve the label
label = self._imdb[index]["class"]
return im, label
def __len__(self):
return len(self._imdb)
================================================
FILE: pycls/datasets/loader.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Data loader."""
import pycls.datasets.paths as dp
import torch
from pycls.core.config import cfg
from pycls.datasets.cifar10 import Cifar10
from pycls.datasets.imagenet import ImageNet
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
# Supported datasets
_DATASET_CATALOG = {"cifar10": Cifar10, "imagenet": ImageNet}
def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
"""Constructs the data loader for the given dataset."""
assert dataset_name in _DATASET_CATALOG.keys(), "Dataset '{}' not supported".format(
dataset_name
)
assert dp.has_data_path(dataset_name), "Dataset '{}' has no data path".format(
dataset_name
)
# Retrieve the data path for the dataset
data_path = dp.get_data_path(dataset_name)
# Construct the dataset
dataset = _DATASET_CATALOG[dataset_name](data_path, split)
# Create a sampler for multi-process training
sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
# Create a loader
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=(False if sampler else shuffle),
sampler=sampler,
num_workers=cfg.DATA_LOADER.NUM_WORKERS,
pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
drop_last=drop_last,
)
return loader
def construct_train_loader():
"""Train loader wrapper."""
return _construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
)
def construct_test_loader():
"""Test loader wrapper."""
return _construct_loader(
dataset_name=cfg.TEST.DATASET,
split=cfg.TEST.SPLIT,
batch_size=int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=False,
drop_last=False,
)
def shuffle(loader, cur_epoch):
""""Shuffles the data."""
assert isinstance(
loader.sampler, (RandomSampler, DistributedSampler)
), "Sampler type '{}' not supported".format(type(loader.sampler))
# RandomSampler handles shuffling automatically
if isinstance(loader.sampler, DistributedSampler):
# DistributedSampler shuffles data based on epoch
loader.sampler.set_epoch(cur_epoch)
================================================
FILE: pycls/datasets/paths.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset paths."""
import os
# Default data directory (/path/pycls/pycls/datasets/data)
_DEF_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
# Data paths
_paths = {
"cifar10": _DEF_DATA_DIR + "/cifar10",
"imagenet": _DEF_DATA_DIR + "/imagenet",
}
def has_data_path(dataset_name):
"""Determines if the dataset has a data path."""
return dataset_name in _paths.keys()
def get_data_path(dataset_name):
"""Retrieves data path for the dataset."""
return _paths[dataset_name]
def register_path(name, path):
"""Registers a dataset path dynamically."""
_paths[name] = path
================================================
FILE: pycls/datasets/transforms.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Image transformations."""
import math
import cv2
import numpy as np
def color_norm(im, mean, std):
"""Performs per-channel normalization (CHW format)."""
for i in range(im.shape[0]):
im[i] = im[i] - mean[i]
im[i] = im[i] / std[i]
return im
def zero_pad(im, pad_size):
"""Performs zero padding (CHW format)."""
pad_width = ((0, 0), (pad_size, pad_size), (pad_size, pad_size))
return np.pad(im, pad_width, mode="constant")
def horizontal_flip(im, p, order="CHW"):
"""Performs horizontal flip (CHW or HWC format)."""
assert order in ["CHW", "HWC"]
if np.random.uniform() < p:
if order == "CHW":
im = im[:, :, ::-1]
else:
im = im[:, ::-1, :]
return im
def random_crop(im, size, pad_size=0):
"""Performs random crop (CHW format)."""
if pad_size > 0:
im = zero_pad(im=im, pad_size=pad_size)
h, w = im.shape[1:]
y = np.random.randint(0, h - size)
x = np.random.randint(0, w - size)
im_crop = im[:, y : (y + size), x : (x + size)]
assert im_crop.shape[1:] == (size, size)
return im_crop
def scale(size, im):
"""Performs scaling (HWC format)."""
h, w = im.shape[:2]
if (w <= h and w == size) or (h <= w and h == size):
return im
h_new, w_new = size, size
if w < h:
h_new = int(math.floor((float(h) / w) * size))
else:
w_new = int(math.floor((float(w) / h) * size))
im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
return im.astype(np.float32)
def center_crop(size, im):
"""Performs center cropping (HWC format)."""
h, w = im.shape[:2]
y = int(math.ceil((h - size) / 2))
x = int(math.ceil((w - size) / 2))
im_crop = im[y : (y + size), x : (x + size), :]
assert im_crop.shape[:2] == (size, size)
return im_crop
def random_sized_crop(im, size, area_frac=0.08, max_iter=10):
"""Performs Inception-style cropping (HWC format)."""
h, w = im.shape[:2]
area = h * w
for _ in range(max_iter):
target_area = np.random.uniform(area_frac, 1.0) * area
aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
w_crop = int(round(math.sqrt(float(target_area) * aspect_ratio)))
h_crop = int(round(math.sqrt(float(target_area) / aspect_ratio)))
if np.random.uniform() < 0.5:
w_crop, h_crop = h_crop, w_crop
if h_crop <= h and w_crop <= w:
y = 0 if h_crop == h else np.random.randint(0, h - h_crop)
x = 0 if w_crop == w else np.random.randint(0, w - w_crop)
im_crop = im[y : (y + h_crop), x : (x + w_crop), :]
assert im_crop.shape[:2] == (h_crop, w_crop)
im_crop = cv2.resize(im_crop, (size, size), interpolation=cv2.INTER_LINEAR)
return im_crop.astype(np.float32)
return center_crop(size, scale(size, im))
def lighting(im, alpha_std, eig_val, eig_vec):
"""Performs AlexNet-style PCA jitter (CHW format)."""
if alpha_std == 0:
return im
alpha = np.random.normal(0, alpha_std, size=(1, 3))
rgb = np.sum(
eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), axis=1
)
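    # Note: this computes rgb = eig_vec @ (alpha * eig_val).ravel(), i.e. a
    # random combination of the PCA eigenvectors weighted by their
    # eigenvalues, which is then added per channel below.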
for i in range(im.shape[0]):
im[i] = im[i] + rgb[2 - i]
return im
================================================
FILE: pycls/models/__init__.py
================================================
================================================
FILE: pycls/models/anynet.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""AnyNet models."""
import pycls.utils.logging as lu
import pycls.utils.net as nu
import torch.nn as nn
from pycls.core.config import cfg
logger = lu.get_logger(__name__)
def get_stem_fun(stem_type):
"""Retrives the stem function by name."""
stem_funs = {
"res_stem_cifar": ResStemCifar,
"res_stem_in": ResStemIN,
"simple_stem_in": SimpleStemIN,
}
assert stem_type in stem_funs.keys(), "Stem type '{}' not supported".format(
stem_type
)
return stem_funs[stem_type]
def get_block_fun(block_type):
"""Retrieves the block function by name."""
block_funs = {
"vanilla_block": VanillaBlock,
"res_basic_block": ResBasicBlock,
"res_bottleneck_block": ResBottleneckBlock,
}
assert block_type in block_funs.keys(), "Block type '{}' not supported".format(
block_type
)
return block_funs[block_type]
class AnyHead(nn.Module):
"""AnyNet head."""
def __init__(self, w_in, nc):
super(AnyHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
# self.fc = nn.Linear(w_in, nc, bias=True)
def forward(self, x):
x = self.avg_pool(x)
# print(x.shape)
x = x.view(x.size(0), -1)
# x = self.fc(x)
return x
class VanillaBlock(nn.Module):
"""Vanilla block: [3x3 conv, BN, Relu] x2"""
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
assert (
bm is None and gw is None and se_r is None
), "Vanilla block does not support bm, gw, and se_r options"
super(VanillaBlock, self).__init__()
self._construct(w_in, w_out, stride)
def _construct(self, w_in, w_out, stride):
# 3x3, BN, ReLU
self.a = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN, ReLU
self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class BasicTransform(nn.Module):
"""Basic transformation: [3x3 conv, BN, Relu] x2"""
def __init__(self, w_in, w_out, stride):
super(BasicTransform, self).__init__()
self._construct(w_in, w_out, stride)
def _construct(self, w_in, w_out, stride):
# 3x3, BN, ReLU
self.a = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN
self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class ResBasicBlock(nn.Module):
"""Residual basic block: x + F(x), F = basic transform"""
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
assert (
bm is None and gw is None and se_r is None
), "Basic transform does not support bm, gw, and se_r options"
super(ResBasicBlock, self).__init__()
self._construct(w_in, w_out, stride)
def _add_skip_proj(self, w_in, w_out, stride):
self.proj = nn.Conv2d(
w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
def _construct(self, w_in, w_out, stride):
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self._add_skip_proj(w_in, w_out, stride)
self.f = BasicTransform(w_in, w_out, stride)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block"""
def __init__(self, w_in, w_se):
super(SE, self).__init__()
self._construct(w_in, w_se)
def _construct(self, w_in, w_se):
# AvgPool
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
# FC, Activation, FC, Sigmoid
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
class BottleneckTransform(nn.Module):
"""Bottlenect transformation: 1x1, 3x3, 1x1"""
def __init__(self, w_in, w_out, stride, bm, gw, se_r):
super(BottleneckTransform, self).__init__()
self._construct(w_in, w_out, stride, bm, gw, se_r)
def _construct(self, w_in, w_out, stride, bm, gw, se_r):
# Compute the bottleneck width
w_b = int(round(w_out * bm))
# Compute the number of groups
num_gs = w_b // gw
# 1x1, BN, ReLU
self.a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=False)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN, ReLU
self.b = nn.Conv2d(
w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_gs, bias=False
)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# Squeeze-and-Excitation (SE)
if se_r:
w_se = int(round(w_in * se_r))
self.se = SE(w_b, w_se)
# 1x1, BN
self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class ResBottleneckBlock(nn.Module):
"""Residual bottleneck block: x + F(x), F = bottleneck transform"""
def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
super(ResBottleneckBlock, self).__init__()
self._construct(w_in, w_out, stride, bm, gw, se_r)
def _add_skip_proj(self, w_in, w_out, stride):
self.proj = nn.Conv2d(
w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
def _construct(self, w_in, w_out, stride, bm, gw, se_r):
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self._add_skip_proj(w_in, w_out, stride)
self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
class ResStemCifar(nn.Module):
"""ResNet stem for CIFAR."""
def __init__(self, w_in, w_out):
super(ResStemCifar, self).__init__()
self._construct(w_in, w_out)
def _construct(self, w_in, w_out):
# 3x3, BN, ReLU
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class ResStemIN(nn.Module):
"""ResNet stem for ImageNet."""
def __init__(self, w_in, w_out):
super(ResStemIN, self).__init__()
self._construct(w_in, w_out)
def _construct(self, w_in, w_out):
# 7x7, BN, ReLU, maxpool
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class SimpleStemIN(nn.Module):
"""Simple stem for ImageNet."""
def __init__(self, in_w, out_w):
super(SimpleStemIN, self).__init__()
self._construct(in_w, out_w)
def _construct(self, in_w, out_w):
# 3x3, BN, ReLU
self.conv = nn.Conv2d(
in_w, out_w, kernel_size=3, stride=2, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(out_w, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class AnyStage(nn.Module):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""
def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
super(AnyStage, self).__init__()
self._construct(w_in, w_out, stride, d, block_fun, bm, gw, se_r)
def _construct(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
# Construct the blocks
for i in range(d):
# Stride and w_in apply to the first block of the stage
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
# Construct the block
self.add_module(
"b{}".format(i + 1), block_fun(b_w_in, w_out, b_stride, bm, gw, se_r)
)
def forward(self, x):
for block in self.children():
x = block(x)
return x
class AnyNet(nn.Module):
"""AnyNet model."""
def __init__(self, **kwargs):
super(AnyNet, self).__init__()
if kwargs:
self._construct(
stem_type=kwargs["stem_type"],
stem_w=kwargs["stem_w"],
block_type=kwargs["block_type"],
ds=kwargs["ds"],
ws=kwargs["ws"],
ss=kwargs["ss"],
bms=kwargs["bms"],
gws=kwargs["gws"],
se_r=kwargs["se_r"],
nc=kwargs["nc"],
)
else:
self._construct(
stem_type=cfg.ANYNET.STEM_TYPE,
stem_w=cfg.ANYNET.STEM_W,
block_type=cfg.ANYNET.BLOCK_TYPE,
ds=cfg.ANYNET.DEPTHS,
ws=cfg.ANYNET.WIDTHS,
ss=cfg.ANYNET.STRIDES,
bms=cfg.ANYNET.BOT_MULS,
gws=cfg.ANYNET.GROUP_WS,
se_r=cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
nc=cfg.MODEL.NUM_CLASSES,
)
self.apply(nu.init_weights)
def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
logger.info("Constructing AnyNet: ds={}, ws={}".format(ds, ws))
# Generate dummy bot muls and gs for models that do not use them
bms = bms if bms else [1.0 for _d in ds]
gws = gws if gws else [1 for _d in ds]
# Group params by stage
stage_params = list(zip(ds, ws, ss, bms, gws))
# Construct the stem
stem_fun = get_stem_fun(stem_type)
self.stem = stem_fun(3, stem_w)
# Construct the stages
block_fun = get_block_fun(block_type)
prev_w = stem_w
for i, (d, w, s, bm, gw) in enumerate(stage_params):
self.add_module(
"s{}".format(i + 1), AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r)
)
prev_w = w
# Construct the head
self.head = AnyHead(w_in=prev_w, nc=nc)
self.fc = nn.Linear(prev_w, nc, bias=True)
self.feature_num = prev_w
def forward(self, x):
# for module in self.children():
# x = module(x)
# return x
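        # Note: unlike stock pycls, AnyHead above only global-pools (its fc is
        # commented out); the self.fc built in _construct is presumably applied
        # by the caller, and the pre-pool feature map is returned detached
        # alongside the pooled features.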
for name, module in self.named_children():
if name != 'head' and name != 'fc':
x = module(x)
# print(x.shape)
output = self.head(x)
# print(output.shape)
return output, x.detach()
================================================
FILE: pycls/models/effnet.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""EfficientNet models."""
import pycls.utils.logging as logging
import pycls.utils.net as nu
import torch
import torch.nn as nn
from pycls.core.config import cfg
logger = logging.get_logger(__name__)
class EffHead(nn.Module):
"""EfficientNet head."""
def __init__(self, w_in, w_out, nc):
super(EffHead, self).__init__()
self._construct(w_in, w_out, nc)
def _construct(self, w_in, w_out, nc):
# 1x1, BN, Swish
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=1, stride=1, padding=0, bias=False
)
self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.conv_swish = Swish()
# AvgPool
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
# Dropout
if cfg.EN.DROPOUT_RATIO > 0.0:
self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
# FC
self.fc = nn.Linear(w_out, nc, bias=True)
def forward(self, x):
x = self.conv_swish(self.conv_bn(self.conv(x)))
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.dropout(x) if hasattr(self, "dropout") else x
x = self.fc(x)
return x
class Swish(nn.Module):
"""Swish activation function: x * sigmoid(x)"""
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * torch.sigmoid(x)
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block w/ Swish."""
def __init__(self, w_in, w_se):
super(SE, self).__init__()
self._construct(w_in, w_se)
def _construct(self, w_in, w_se):
# AvgPool
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
# FC, Swish, FC, Sigmoid
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
Swish(),
nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
class MBConv(nn.Module):
"""Mobile inverted bottleneck block w/ SE (MBConv)."""
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
super(MBConv, self).__init__()
self._construct(w_in, exp_r, kernel, stride, se_r, w_out)
def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out):
# Expansion ratio is wrt the input width
self.exp = None
w_exp = int(w_in * exp_r)
# Include exp ops only if the exp ratio is different from 1
if w_exp != w_in:
# 1x1, BN, Swish
self.exp = nn.Conv2d(
w_in, w_exp, kernel_size=1, stride=1, padding=0, bias=False
)
self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.exp_swish = Swish()
# 3x3 dwise, BN, Swish
self.dwise = nn.Conv2d(
w_exp,
w_exp,
kernel_size=kernel,
stride=stride,
groups=w_exp,
bias=False,
# Hacky padding to preserve res (supports only 3x3 and 5x5)
padding=(1 if kernel == 3 else 2),
)
self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.dwise_swish = Swish()
# Squeeze-and-Excitation (SE)
w_se = int(w_in * se_r)
self.se = SE(w_exp, w_se)
# 1x1, BN
self.lin_proj = nn.Conv2d(
w_exp, w_out, kernel_size=1, stride=1, padding=0, bias=False
)
self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
# Skip connection if in and out shapes are the same (MN-V2 style)
self.has_skip = (stride == 1) and (w_in == w_out)
def forward(self, x):
f_x = x
# Expansion
if self.exp:
f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
# Depthwise
f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
# SE
f_x = self.se(f_x)
# Linear projection
f_x = self.lin_proj_bn(self.lin_proj(f_x))
# Skip connection
if self.has_skip:
# Drop connect
if self.training and cfg.EN.DC_RATIO > 0.0:
f_x = nu.drop_connect(f_x, cfg.EN.DC_RATIO)
f_x = x + f_x
return f_x
class EffStage(nn.Module):
"""EfficientNet stage."""
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
super(EffStage, self).__init__()
self._construct(w_in, exp_r, kernel, stride, se_r, w_out, d)
def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
# Construct the blocks
for i in range(d):
# Stride and input width apply to the first block of the stage
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
# Construct the block
self.add_module(
"b{}".format(i + 1),
MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out),
)
def forward(self, x):
for block in self.children():
x = block(x)
return x
class StemIN(nn.Module):
"""EfficientNet stem for ImageNet."""
def __init__(self, w_in, w_out):
super(StemIN, self).__init__()
self._construct(w_in, w_out)
def _construct(self, w_in, w_out):
# 3x3, BN, Swish
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.swish = Swish()
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class EffNet(nn.Module):
"""EfficientNet model."""
def __init__(self):
assert cfg.TRAIN.DATASET in [
"imagenet"
], "Training on {} is not supported".format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in [
"imagenet"
], "Testing on {} is not supported".format(cfg.TEST.DATASET)
super(EffNet, self).__init__()
self._construct(
stem_w=cfg.EN.STEM_W,
ds=cfg.EN.DEPTHS,
ws=cfg.EN.WIDTHS,
exp_rs=cfg.EN.EXP_RATIOS,
se_r=cfg.EN.SE_R,
ss=cfg.EN.STRIDES,
ks=cfg.EN.KERNELS,
head_w=cfg.EN.HEAD_W,
nc=cfg.MODEL.NUM_CLASSES,
)
self.apply(nu.init_weights)
def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
# Group params by stage
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
logger.info("Constructing: EfficientNet-{}".format(stage_params))
# Construct the stem
self.stem = StemIN(3, stem_w)
prev_w = stem_w
# Construct the stages
for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
self.add_module(
"s{}".format(i + 1), EffStage(prev_w, exp_r, kernel, stride, se_r, w, d)
)
prev_w = w
# Construct the head
self.head = EffHead(prev_w, head_w, nc)
def forward(self, x):
for module in self.children():
x = module(x)
return x
================================================
FILE: pycls/models/regnet.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""RegNet models."""
import numpy as np
import pycls.utils.logging as lu
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
logger = lu.get_logger(__name__)
def quantize_float(f, q):
"""Converts a float to closest non-zero int divisible by q."""
return int(round(f / q) * q)
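# e.g. quantize_float(30, 8) -> 32 and quantize_float(18, 8) -> 16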
def adjust_ws_gs_comp(ws, bms, gs):
"""Adjusts the compatibility of widths and groups."""
ws_bot = [int(w * b) for w, b in zip(ws, bms)]
gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
return ws, gs
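# e.g. ws=[32], bms=[1.0], gs=[48] -> ws_bot=[32], so the group width is
# capped to 32 and the widths stay [32]; a group can never exceed the
# bottleneck width it divides.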
def get_stages_from_blocks(ws, rs):
"""Gets ws/ds of network at each stage from per block values."""
ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
return s_ws, s_ds
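# e.g. ws = rs = [8, 8, 16, 16, 16] -> s_ws = [8, 16], s_ds = [2, 3]:
# a transition is detected wherever the width or stride differs from the
# previous block, and stage depths are the gaps between transitions.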
def generate_regnet(w_a, w_0, w_m, d, q=8):
"""Generates per block ws from RegNet parameters."""
assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
ws_cont = np.arange(d) * w_a + w_0
ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
ws = w_0 * np.power(w_m, ks)
ws = np.round(np.divide(ws, q)) * q
num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
return ws, num_stages, max_stage, ws_cont
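# A small worked example (illustrative values, not a config from this repo):
# generate_regnet(w_a=8, w_0=16, w_m=2.0, d=4) gives ws_cont = [16, 24, 32, 40],
# ks = round(log2(ws_cont / 16)) = [0, 1, 1, 1], hence ws = [16, 32, 32, 32]
# (already multiples of q=8), num_stages = 2 and max_stage = 2.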
class RegNet(AnyNet):
"""RegNet model."""
def __init__(self):
# Generate RegNet ws per block
b_ws, num_s, _, _ = generate_regnet(
cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
)
# print(cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH, cfg.REGNET.GROUP_W)
# Convert to per stage format
ws, ds = get_stages_from_blocks(b_ws, b_ws)
# Generate group widths and bot muls
gws = [cfg.REGNET.GROUP_W for _ in range(num_s)]
bms = [cfg.REGNET.BOT_MUL for _ in range(num_s)]
# Adjust the compatibility of ws and gws
ws, gws = adjust_ws_gs_comp(ws, bms, gws)
# Use the same stride for each stage
ss = [cfg.REGNET.STRIDE for _ in range(num_s)]
# Use SE for RegNetY
se_r = cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None
# Construct the model
kwargs = {
"stem_type": cfg.REGNET.STEM_TYPE,
"stem_w": cfg.REGNET.STEM_W,
"block_type": cfg.REGNET.BLOCK_TYPE,
"ss": ss,
"ds": ds,
"ws": ws,
"bms": bms,
"gws": gws,
"se_r": se_r,
"nc": cfg.MODEL.NUM_CLASSES,
}
super(RegNet, self).__init__(**kwargs)
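    # For instance, the RegNetY-600MF config in pycls/cfgs sets DEPTH=15,
    # W0=48, WA=32.54, WM=2.32 and GROUP_W=16; generate_regnet expands these
    # into 15 per-block widths, which get_stages_from_blocks then groups into
    # per-stage ws/ds before the AnyNet constructor above builds the model.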
================================================
FILE: pycls/models/resnet.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""ResNe(X)t models."""
import pycls.utils.logging as lu
import pycls.utils.net as nu
import torch.nn as nn
from pycls.core.config import cfg
logger = lu.get_logger(__name__)
# Stage depths for ImageNet models
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
def get_trans_fun(name):
"""Retrieves the transformation function by name."""
trans_funs = {
"basic_transform": BasicTransform,
"bottleneck_transform": BottleneckTransform,
}
assert (
name in trans_funs.keys()
), "Transformation function '{}' not supported".format(name)
return trans_funs[name]
class ResHead(nn.Module):
"""ResNet head."""
def __init__(self, w_in, nc):
super(ResHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(w_in, nc, bias=True)
def forward(self, x):
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
class BasicTransform(nn.Module):
"""Basic transformation: 3x3, 3x3"""
def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
assert (
w_b is None and num_gs == 1
), "Basic transform does not support w_b and num_gs options"
super(BasicTransform, self).__init__()
self._construct(w_in, w_out, stride)
def _construct(self, w_in, w_out, stride):
# 3x3, BN, ReLU
self.a = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN
self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class BottleneckTransform(nn.Module):
"""Bottleneck transformation: 1x1, 3x3, 1x1"""
def __init__(self, w_in, w_out, stride, w_b, num_gs):
super(BottleneckTransform, self).__init__()
self._construct(w_in, w_out, stride, w_b, num_gs)
def _construct(self, w_in, w_out, stride, w_b, num_gs):
# MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
(str1x1, str3x3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
# 1x1, BN, ReLU
self.a = nn.Conv2d(
w_in, w_b, kernel_size=1, stride=str1x1, padding=0, bias=False
)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN, ReLU
self.b = nn.Conv2d(
w_b, w_b, kernel_size=3, stride=str3x3, padding=1, groups=num_gs, bias=False
)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 1x1, BN
self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class ResBlock(nn.Module):
"""Residual block: x + F(x)"""
def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
super(ResBlock, self).__init__()
self._construct(w_in, w_out, stride, trans_fun, w_b, num_gs)
def _add_skip_proj(self, w_in, w_out, stride):
self.proj = nn.Conv2d(
w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
def _construct(self, w_in, w_out, stride, trans_fun, w_b, num_gs):
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self._add_skip_proj(w_in, w_out, stride)
self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
class ResStage(nn.Module):
"""Stage of ResNet."""
def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
super(ResStage, self).__init__()
self._construct(w_in, w_out, stride, d, w_b, num_gs)
def _construct(self, w_in, w_out, stride, d, w_b, num_gs):
# Construct the blocks
for i in range(d):
# Stride and w_in apply to the first block of the stage
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
# Retrieve the transformation function
trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN)
# Construct the block
res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
self.add_module("b{}".format(i + 1), res_block)
def forward(self, x):
for block in self.children():
x = block(x)
return x
class ResStem(nn.Module):
"""Stem of ResNet."""
def __init__(self, w_in, w_out):
assert (
cfg.TRAIN.DATASET == cfg.TEST.DATASET
), "Train and test dataset must be the same for now"
super(ResStem, self).__init__()
if "cifar" in cfg.TRAIN.DATASET:
self._construct_cifar(w_in, w_out)
else:
self._construct_imagenet(w_in, w_out)
def _construct_cifar(self, w_in, w_out):
# 3x3, BN, ReLU
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def _construct_imagenet(self, w_in, w_out):
# 7x7, BN, ReLU, maxpool
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class ResNet(nn.Module):
"""ResNet model."""
def __init__(self):
assert cfg.TRAIN.DATASET in [
"cifar10",
"imagenet",
], "Training ResNet on {} is not supported".format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in [
"cifar10",
"imagenet",
], "Testing ResNet on {} is not supported".format(cfg.TEST.DATASET)
super(ResNet, self).__init__()
if "cifar" in cfg.TRAIN.DATASET:
self._construct_cifar()
else:
self._construct_imagenet()
self.apply(nu.init_weights)
def _construct_cifar(self):
assert (
cfg.MODEL.DEPTH - 2
) % 6 == 0, "Model depth should be of the format 6n + 2 for cifar"
logger.info("Constructing: ResNet-{}".format(cfg.MODEL.DEPTH))
# Each stage has the same number of blocks for cifar
d = int((cfg.MODEL.DEPTH - 2) / 6)
# Stem: (N, 3, 32, 32) -> (N, 16, 32, 32)
self.stem = ResStem(w_in=3, w_out=16)
# Stage 1: (N, 16, 32, 32) -> (N, 16, 32, 32)
self.s1 = ResStage(w_in=16, w_out=16, stride=1, d=d)
# Stage 2: (N, 16, 32, 32) -> (N, 32, 16, 16)
self.s2 = ResStage(w_in=16, w_out=32, stride=2, d=d)
# Stage 3: (N, 32, 16, 16) -> (N, 64, 8, 8)
self.s3 = ResStage(w_in=32, w_out=64, stride=2, d=d)
# Head: (N, 64, 8, 8) -> (N, num_classes)
self.head = ResHead(w_in=64, nc=cfg.MODEL.NUM_CLASSES)
def _construct_imagenet(self):
logger.info(
"Constructing: ResNe(X)t-{}-{}x{}, {}".format(
cfg.MODEL.DEPTH,
cfg.RESNET.NUM_GROUPS,
cfg.RESNET.WIDTH_PER_GROUP,
cfg.RESNET.TRANS_FUN,
)
)
# Retrieve the number of blocks per stage
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
# Compute the initial bottleneck width
num_gs = cfg.RESNET.NUM_GROUPS
w_b = cfg.RESNET.WIDTH_PER_GROUP * num_gs
# Stem: (N, 3, 224, 224) -> (N, 64, 56, 56)
self.stem = ResStem(w_in=3, w_out=64)
# Stage 1: (N, 64, 56, 56) -> (N, 256, 56, 56)
self.s1 = ResStage(w_in=64, w_out=256, stride=1, d=d1, w_b=w_b, num_gs=num_gs)
# Stage 2: (N, 256, 56, 56) -> (N, 512, 28, 28)
self.s2 = ResStage(
w_in=256, w_out=512, stride=2, d=d2, w_b=w_b * 2, num_gs=num_gs
)
# Stage 3: (N, 512, 28, 28) -> (N, 1024, 14, 14)
self.s3 = ResStage(
w_in=512, w_out=1024, stride=2, d=d3, w_b=w_b * 4, num_gs=num_gs
)
# Stage 4: (N, 1024, 14, 14) -> (N, 2048, 7, 7)
self.s4 = ResStage(
w_in=1024, w_out=2048, stride=2, d=d4, w_b=w_b * 8, num_gs=num_gs
)
# Head: (N, 2048, 7, 7) -> (N, num_classes)
self.head = ResHead(w_in=2048, nc=cfg.MODEL.NUM_CLASSES)
def forward(self, x):
for module in self.children():
x = module(x)
return x
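A self-contained sketch (not part of this file) of the projection-shortcut rule used by ResBlock above: whenever the channel count or spatial resolution changes, a strided 1x1 convolution brings the skip path to the same shape as F(x).
import torch
import torch.nn as nn
w_in, w_out, stride = 64, 128, 2                 # a shape-changing block
assert (w_in != w_out) or (stride != 1)          # same test as ResBlock.proj_block
proj = nn.Conv2d(w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False)
x = torch.randn(1, w_in, 56, 56)
print(proj(x).shape)                             # torch.Size([1, 128, 28, 28])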
================================================
FILE: pycls/utils/__init__.py
================================================
================================================
FILE: pycls/utils/benchmark.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for benchmarking networks."""
import pycls.utils.logging as lu
import torch
from pycls.core.config import cfg
from pycls.utils.timer import Timer
@torch.no_grad()
def compute_fw_test_time(model, inputs):
"""Computes forward test time (no grad, eval mode)."""
# Use eval mode
model.eval()
# Warm up the caches
for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER):
model(inputs)
# Make sure warmup kernels completed
torch.cuda.synchronize()
# Compute precise forward pass time
timer = Timer()
for _cur_iter in range(cfg.PREC_TIME.NUM_ITER):
timer.tic()
model(inputs)
torch.cuda.synchronize()
timer.toc()
# Make sure forward kernels completed
torch.cuda.synchronize()
return timer.average_time
def compute_fw_bw_time(model, loss_fun, inputs, labels):
"""Computes forward backward time."""
# Use train mode
model.train()
# Warm up the caches
for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER):
preds = model(inputs)
loss = loss_fun(preds, labels)
loss.backward()
# Make sure warmup kernels completed
torch.cuda.synchronize()
# Compute precise forward backward pass time
fw_timer = Timer()
bw_timer = Timer()
for _cur_iter in range(cfg.PREC_TIME.NUM_ITER):
# Forward
fw_timer.tic()
preds = model(inputs)
loss = loss_fun(preds, labels)
torch.cuda.synchronize()
fw_timer.toc()
# Backward
bw_timer.tic()
loss.backward()
torch.cuda.synchronize()
bw_timer.toc()
# Make sure forward backward kernels completed
torch.cuda.synchronize()
return fw_timer.average_time, bw_timer.average_time
def compute_precise_time(model, loss_fun):
"""Computes precise time."""
# Generate a dummy mini-batch
im_size = cfg.TRAIN.IM_SIZE
inputs = torch.rand(cfg.PREC_TIME.BATCH_SIZE, 3, im_size, im_size)
labels = torch.zeros(cfg.PREC_TIME.BATCH_SIZE, dtype=torch.int64)
# Copy the data to the GPU
inputs = inputs.cuda(non_blocking=False)
labels = labels.cuda(non_blocking=False)
# Compute precise time
fw_test_time = compute_fw_test_time(model, inputs)
fw_time, bw_time = compute_fw_bw_time(model, loss_fun, inputs, labels)
# Log precise time
lu.log_json_stats(
{
"prec_test_fw_time": fw_test_time,
"prec_train_fw_time": fw_time,
"prec_train_bw_time": bw_time,
"prec_train_fw_bw_time": fw_time + bw_time,
}
)
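A minimal sketch (not part of this file) of the warmup/synchronize/Timer pattern used above, reduced to a toy module on CPU so it runs without a GPU or a merged config; on GPU, the torch.cuda.synchronize calls above are what make the measured intervals meaningful.
import torch
import torch.nn as nn
from pycls.utils.timer import Timer
model, inputs = nn.Linear(128, 10).eval(), torch.randn(32, 128)
with torch.no_grad():
    for _ in range(3):          # warm up caches/allocator before timing
        model(inputs)
    timer = Timer()
    for _ in range(10):
        timer.tic()
        model(inputs)
        timer.toc()
print(timer.average_time)       # seconds per forward pass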
================================================
FILE: pycls/utils/checkpoint.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions that handle saving and loading of checkpoints."""
import os
import pycls.utils.distributed as du
import torch
from pycls.core.config import cfg
# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"
def get_checkpoint_dir():
"""Retrieves the location for storing checkpoints."""
return os.path.join(cfg.OUT_DIR, _DIR_NAME)
def get_checkpoint(epoch):
"""Retrieves the path to a checkpoint file."""
name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
return os.path.join(get_checkpoint_dir(), name)
def get_last_checkpoint():
"""Retrieves the most recent checkpoint (highest epoch number)."""
checkpoint_dir = get_checkpoint_dir()
# Checkpoint file names are in lexicographic order
checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
last_checkpoint_name = sorted(checkpoints)[-1]
return os.path.join(checkpoint_dir, last_checkpoint_name)
def has_checkpoint():
"""Determines if there are checkpoints available."""
checkpoint_dir = get_checkpoint_dir()
if not os.path.exists(checkpoint_dir):
return False
return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
def is_checkpoint_epoch(cur_epoch):
"""Determines if a checkpoint should be saved on current epoch."""
return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0
def save_checkpoint(model, optimizer, epoch):
"""Saves a checkpoint."""
# Save checkpoints only from the master process
if not du.is_master_proc():
return
# Ensure that the checkpoint dir exists
os.makedirs(get_checkpoint_dir(), exist_ok=True)
# Omit the DDP wrapper in the multi-gpu setting
sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
# Record the state
checkpoint = {
"epoch": epoch,
"model_state": sd,
"optimizer_state": optimizer.state_dict(),
"cfg": cfg.dump(),
}
# Write the checkpoint
checkpoint_file = get_checkpoint(epoch + 1)
torch.save(checkpoint, checkpoint_file)
return checkpoint_file
def load_checkpoint(checkpoint_file, model, optimizer=None):
"""Loads the checkpoint from the given file."""
assert os.path.exists(checkpoint_file), "Checkpoint '{}' not found".format(
checkpoint_file
)
# Load the checkpoint on CPU to avoid GPU mem spike
checkpoint = torch.load(checkpoint_file, map_location="cpu")
# Account for the DDP wrapper in the multi-gpu setting
ms = model.module if cfg.NUM_GPUS > 1 else model
ms.load_state_dict(checkpoint["model_state"])
# Load the optimizer state (commonly not done when fine-tuning)
if optimizer:
optimizer.load_state_dict(checkpoint["optimizer_state"])
return checkpoint["epoch"]
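A minimal sketch (not part of this file) of the checkpoint layout written by save_checkpoint and consumed by load_checkpoint, reproduced with plain torch.save/torch.load on a toy model.
import os
import tempfile
import torch
import torch.nn as nn
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
ckpt = {"epoch": 0, "model_state": model.state_dict(),
        "optimizer_state": opt.state_dict()}
path = os.path.join(tempfile.gettempdir(), "model_epoch_0001.pyth")
torch.save(ckpt, path)
restored = torch.load(path, map_location="cpu")  # CPU load avoids a GPU mem spike
model.load_state_dict(restored["model_state"])
opt.load_state_dict(restored["optimizer_state"])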
================================================
FILE: pycls/utils/distributed.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed helpers."""
import torch
from pycls.core.config import cfg
def is_master_proc():
"""Determines if the current process is the master process.
Master process is responsible for logging, writing and loading checkpoints.
In the multi GPU setting, we assign the master role to the rank 0 process.
When training with a single GPU, there is only one training process,
which is considered the master process.
"""
return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
def init_process_group(proc_rank, world_size):
"""Initializes the default process group."""
# Set the GPU to use
torch.cuda.set_device(proc_rank)
# Initialize the process group
torch.distributed.init_process_group(
backend=cfg.DIST_BACKEND,
init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
world_size=world_size,
rank=proc_rank,
)
def destroy_process_group():
"""Destroys the default process group."""
torch.distributed.destroy_process_group()
def scaled_all_reduce(tensors):
"""Performs the scaled all_reduce operation on the provided tensors.
The input tensors are modified in-place. Currently supports only the sum
reduction operator. The reduced values are scaled by the inverse size of
the process group (equivalent to cfg.NUM_GPUS).
"""
# Queue the reductions
reductions = []
for tensor in tensors:
reduction = torch.distributed.all_reduce(tensor, async_op=True)
reductions.append(reduction)
# Wait for reductions to finish
for reduction in reductions:
reduction.wait()
# Scale the results
for tensor in tensors:
tensor.mul_(1.0 / cfg.NUM_GPUS)
return tensors
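A sketch (not part of this file) of what scaled_all_reduce computes: with one tensor per process, every rank ends up holding the element-wise mean. Simulated here in a single process, without a process group.
import torch
num_gpus = 4
per_rank = [torch.full((2,), float(r)) for r in range(num_gpus)]  # ranks hold 0,1,2,3
reduced = torch.stack(per_rank).sum(dim=0) / num_gpus             # all_reduce + scale
print(reduced)                                                    # tensor([1.5000, 1.5000])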
================================================
FILE: pycls/utils/error_handler.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Multiprocessing error handler."""
import os
import signal
import threading
class ChildException(Exception):
"""Wraps an exception from a child process."""
def __init__(self, child_trace):
super(ChildException, self).__init__(child_trace)
class ErrorHandler(object):
"""Multiprocessing error handler (based on fairseq's).
Listens for errors in child processes and
propagates the tracebacks to the parent process.
"""
def __init__(self, error_queue):
# Shared error queue
self.error_queue = error_queue
# Children processes sharing the error queue
self.children_pids = []
# Start a thread listening to errors
self.error_listener = threading.Thread(target=self.listen, daemon=True)
self.error_listener.start()
# Register the signal handler
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
"""Registers a child process."""
self.children_pids.append(pid)
def listen(self):
"""Listens for errors in the error queue."""
# Wait until there is an error in the queue
child_trace = self.error_queue.get()
# Put the error back for the signal handler
self.error_queue.put(child_trace)
# Invoke the signal handler
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, _sig_num, _stack_frame):
"""Signal handler."""
# Kill children processes
for pid in self.children_pids:
os.kill(pid, signal.SIGINT)
# Propagate the error from the child process
raise ChildException(self.error_queue.get())
================================================
FILE: pycls/utils/io.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""IO utilities (adapted from Detectron)"""
import logging
import os
import re
import sys
from urllib import request as urlrequest
logger = logging.getLogger(__name__)
_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
def cache_url(url_or_file, cache_dir):
"""Download the file specified by the URL to the cache_dir and return the
path to the cached file. If the argument is not a URL, simply return it as
is.
"""
is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
if not is_url:
return url_or_file
url = url_or_file
assert url.startswith(_PYCLS_BASE_URL), (
"pycls only automatically caches URLs in the pycls S3 bucket: {}"
).format(_PYCLS_BASE_URL)
cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
if os.path.exists(cache_file_path):
return cache_file_path
cache_file_dir = os.path.dirname(cache_file_path)
if not os.path.exists(cache_file_dir):
os.makedirs(cache_file_dir)
logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
download_url(url, cache_file_path)
return cache_file_path
def _progress_bar(count, total):
"""Report download progress.
Credit:
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
"""
bar_len = 60
filled_len = int(round(bar_len * count / float(total)))
percents = round(100.0 * count / float(total), 1)
bar = "=" * filled_len + "-" * (bar_len - filled_len)
sys.stdout.write(
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
)
sys.stdout.flush()
if count >= total:
sys.stdout.write("\n")
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
"""Download url and write it to dst_file_path.
Credit:
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
"""
req = urlrequest.Request(url)
response = urlrequest.urlopen(req)
total_size = response.info().get("Content-Length").strip()
total_size = int(total_size)
bytes_so_far = 0
with open(dst_file_path, "wb") as f:
while 1:
chunk = response.read(chunk_size)
bytes_so_far += len(chunk)
if not chunk:
break
if progress_hook:
progress_hook(bytes_so_far, total_size)
f.write(chunk)
return bytes_so_far
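A minimal sketch (not part of this file) of the URL-to-path mapping inside cache_url: the S3 bucket prefix is swapped for the local cache directory. The file name and cache directory here are hypothetical.
url = "https://dl.fbaipublicfiles.com/pycls/some/model.pyth"  # hypothetical file
cache_dir = "/tmp/pycls-cache"                                # illustrative cache dir
local_path = url.replace("https://dl.fbaipublicfiles.com/pycls", cache_dir)
print(local_path)                                             # /tmp/pycls-cache/some/model.pyth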
================================================
FILE: pycls/utils/logging.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Logging."""
import builtins
import decimal
import logging
import os
import sys
import pycls.utils.distributed as du
import simplejson
from pycls.core.config import cfg
# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"
# Printed json stats lines will be tagged w/ this
_TAG = "json_stats: "
def _suppress_print():
"""Suppresses printing from the current process."""
def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
pass
builtins.print = ignore
def setup_logging():
"""Sets up the logging."""
# Enable logging only for the master process
if du.is_master_proc():
# Clear the root logger to prevent any existing logging config
# (e.g. set by another module) from messing with our setup
logging.root.handlers = []
# Construct logging configuration
logging_config = {"level": logging.INFO, "format": _FORMAT}
# Log either to stdout or to a file
if cfg.LOG_DEST == "stdout":
logging_config["stream"] = sys.stdout
else:
logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
# Configure logging
logging.basicConfig(**logging_config)
else:
_suppress_print()
def get_logger(name):
"""Retrieves the logger."""
return logging.getLogger(name)
def log_json_stats(stats):
"""Logs json stats."""
# Decimal + string workaround for having fixed len float vals in logs
stats = {
k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
for k, v in stats.items()
}
json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
logger = get_logger(__name__)
logger.info("{:s}{:s}".format(_TAG, json_stats))
def load_json_stats(log_file):
"""Loads json_stats from a single log file."""
with open(log_file, "r") as f:
lines = f.readlines()
json_lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
json_stats = [simplejson.loads(l) for l in json_lines]
return json_stats
def parse_json_stats(log, row_type, key):
"""Extract values corresponding to row_type/key out of log."""
vals = [row[key] for row in log if row["_type"] == row_type and key in row]
if key == "iter" or key == "epoch":
vals = [int(val.split("/")[0]) for val in vals]
return vals
def get_log_files(log_dir, name_filter=""):
"""Get all log files in directory containing subdirs of trained models."""
names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
files = [os.path.join(log_dir, n, _LOG_FILE) for n in names]
f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
files, names = zip(*f_n_ps)
return files, names
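A minimal round-trip sketch (not part of this file) of the json_stats line format: log_json_stats writes one tagged JSON payload per line, and load_json_stats recovers it by slicing off everything up to and including the tag.
import simplejson
tag = "json_stats: "
line = tag + simplejson.dumps(
    {"_type": "train_epoch", "epoch": "1/100", "top1_err": 23.4}, sort_keys=True)
payload = simplejson.loads(line[line.find(tag) + len(tag):])
print(payload["top1_err"])   # 23.4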
================================================
FILE: pycls/utils/lr_policy.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Learning rate policies."""
import numpy as np
from pycls.core.config import cfg
def lr_fun_steps(cur_epoch):
"""Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)
def lr_fun_exp(cur_epoch):
"""Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)
def lr_fun_cos(cur_epoch):
"""Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))
def get_lr_fun():
"""Retrieves the specified lr policy function"""
lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
if lr_fun not in globals():
raise NotImplementedError("Unknown LR policy: " + cfg.OPTIM.LR_POLICY)
return globals()[lr_fun]
def get_epoch_lr(cur_epoch):
"""Retrieves the lr for the given epoch according to the policy."""
lr = get_lr_fun()(cur_epoch)
# Linear warmup
if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
lr *= warmup_factor
return lr
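A standalone sketch (not part of this file) of the cosine schedule with linear warmup implemented above; the hyperparameter values are illustrative.
import numpy as np
base_lr, max_epoch = 0.1, 100
warmup_epochs, warmup_factor = 5, 0.1
def lr_at(epoch):
    lr = 0.5 * base_lr * (1.0 + np.cos(np.pi * epoch / max_epoch))  # cosine decay
    if epoch < warmup_epochs:                                       # linear warmup
        alpha = epoch / warmup_epochs
        lr *= warmup_factor * (1.0 - alpha) + alpha
    return lr
print([round(lr_at(e), 5) for e in (0, 1, 5, 50, 99)])
# ramps from 0.01 toward base_lr over the warmup, then decays to ~0 by the last epoch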
================================================
FILE: pycls/utils/meters.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Meters."""
import datetime
from collections import deque
import numpy as np
import pycls.utils.logging as lu
import pycls.utils.metrics as metrics
from pycls.core.config import cfg
from pycls.utils.timer import Timer
def eta_str(eta_td):
"""Converts an eta timedelta to a fixed-width string format."""
days = eta_td.days
hrs, rem = divmod(eta_td.seconds, 3600)
mins, secs = divmod(rem, 60)
return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)
class ScalarMeter(object):
"""Measures a scalar value (adapted from Detectron)."""
def __init__(self, window_size):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
def reset(self):
self.deque.clear()
self.total = 0.0
self.count = 0
def add_value(self, value):
self.deque.append(value)
self.count += 1
self.total += value
def get_win_median(self):
return np.median(self.deque)
def get_win_avg(self):
return np.mean(self.deque)
def get_global_avg(self):
return self.total / self.count
class TrainMeter(object):
"""Measures training stats."""
def __init__(self, epoch_iters):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
self.lr = None
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def reset(self, timer=False):
if timer:
self.iter_timer.reset()
self.loss.reset()
self.loss_total = 0.0
self.lr = None
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
# Current minibatch stats
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.loss.add_value(loss)
self.lr = lr
# Aggregate stats
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.loss_total += loss * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
eta_sec = self.iter_timer.average_time * (
self.max_iter - (cur_epoch * self.epoch_iters + cur_iter + 1)
)
eta_td = datetime.timedelta(seconds=int(eta_sec))
mem_usage = metrics.gpu_mem_usage()
stats = {
"_type": "train_iter",
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"eta": eta_str(eta_td),
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"loss": self.loss.get_win_median(),
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
lu.log_json_stats(stats)
def get_epoch_stats(self, cur_epoch):
eta_sec = self.iter_timer.average_time * (
self.max_iter - (cur_epoch + 1) * self.epoch_iters
)
eta_td = datetime.timedelta(seconds=int(eta_sec))
mem_usage = metrics.gpu_mem_usage()
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
avg_loss = self.loss_total / self.num_samples
stats = {
"_type": "train_epoch",
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"eta": eta_str(eta_td),
"top1_err": top1_err,
"top5_err": top5_err,
"loss": avg_loss,
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
lu.log_json_stats(stats)
class TestMeter(object):
"""Measures testing stats."""
def __init__(self, max_iter):
self.max_iter = max_iter
self.iter_timer = Timer()
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Min errors (over the full test set)
self.min_top1_err = 100.0
self.min_top5_err = 100.0
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def reset(self, min_errs=False):
if min_errs:
self.min_top1_err = 100.0
self.min_top5_err = 100.0
self.iter_timer.reset()
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, top1_err, top5_err, mb_size):
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
mem_usage = metrics.gpu_mem_usage()
iter_stats = {
"_type": "test_iter",
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"mem": int(np.ceil(mem_usage)),
}
return iter_stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
lu.log_json_stats(stats)
def get_epoch_stats(self, cur_epoch):
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
self.min_top1_err = min(self.min_top1_err, top1_err)
self.min_top5_err = min(self.min_top5_err, top5_err)
mem_usage = metrics.gpu_mem_usage()
stats = {
"_type": "test_epoch",
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"top1_err": top1_err,
"top5_err": top5_err,
"min_top1_err": self.min_top1_err,
"min_top5_err": self.min_top5_err,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
lu.log_json_stats(stats)
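A minimal sketch (not part of this file) of the window-vs-global distinction in ScalarMeter: the window median smooths noisy per-iteration values for logging, while the running totals give exact averages over everything added since the last reset.
from pycls.utils.meters import ScalarMeter
m = ScalarMeter(window_size=3)
for v in (10.0, 2.0, 3.0, 4.0):
    m.add_value(v)
print(m.get_win_median())   # 3.0  -> median of the last 3 values [2, 3, 4]
print(m.get_global_avg())   # 4.75 -> mean of all 4 values ever added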
================================================
FILE: pycls/utils/metrics.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for computing metrics."""
import numpy as np
import torch
import torch.nn as nn
from pycls.core.config import cfg
# Number of bytes in a megabyte
_B_IN_MB = 1024 * 1024
def topks_correct(preds, labels, ks):
"""Computes the number of top-k correct predictions for each k."""
assert preds.size(0) == labels.size(
0
), "Batch dim of predictions and labels must match"
# Find the top max_k predictions for each sample
_top_max_k_vals, top_max_k_inds = torch.topk(
preds, max(ks), dim=1, largest=True, sorted=True
)
# (batch_size, max_k) -> (max_k, batch_size)
top_max_k_inds = top_max_k_inds.t()
# (batch_size, ) -> (max_k, batch_size)
rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
# (i, j) = 1 if top i-th prediction for the j-th sample is correct
top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
# Compute the number of topk correct predictions for each k
topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
return topks_correct
def topk_errors(preds, labels, ks):
"""Computes the top-k error for each k."""
num_topks_correct = topks_correct(preds, labels, ks)
return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]
def topk_accuracies(preds, labels, ks):
"""Computes the top-k accuracy for each k."""
num_topks_correct = topks_correct(preds, labels, ks)
return [(x / preds.size(0)) * 100.0 for x in num_topks_correct]
def params_count(model):
"""Computes the number of parameters."""
return np.sum([p.numel() for p in model.parameters()]).item()
def flops_count(model):
"""Computes the number of flops statically."""
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
count = 0
for n, m in model.named_modules():
if isinstance(m, nn.Conv2d):
if "se." in n:
count += m.in_channels * m.out_channels + m.bias.numel()
continue
h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
count += np.prod([m.weight.numel(), h_out, w_out])
if ".proj" not in n:
h, w = h_out, w_out
elif isinstance(m, nn.MaxPool2d):
h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
elif isinstance(m, nn.Linear):
count += m.in_features * m.out_features + m.bias.numel()
return count.item()
def acts_count(model):
"""Computes the number of activations statically."""
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
count = 0
for n, m in model.named_modules():
if isinstance(m, nn.Conv2d):
if "se." in n:
count += m.out_channels
continue
h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
count += np.prod([m.out_channels, h_out, w_out])
if ".proj" not in n:
h, w = h_out, w_out
elif isinstance(m, nn.MaxPool2d):
h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
elif isinstance(m, nn.Linear):
count += m.out_features
return count.item()
def gpu_mem_usage():
"""Computes the GPU memory usage for the current device (MB)."""
mem_usage_bytes = torch.cuda.max_memory_allocated()
return mem_usage_bytes / _B_IN_MB
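A tiny worked example (not part of this file) of topks_correct on a 2-sample, 4-class batch: the first sample is correct at rank 1, the second only at rank 2.
import torch
from pycls.utils.metrics import topks_correct
preds = torch.tensor([[0.1, 0.7, 0.1, 0.1],    # top-1 prediction: class 1
                      [0.4, 0.3, 0.2, 0.1]])   # top-1 prediction: class 0
labels = torch.tensor([1, 1])
top1, top2 = topks_correct(preds, labels, ks=[1, 2])
print(top1.item(), top2.item())                # 1.0 2.0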
================================================
FILE: pycls/utils/multiprocessing.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Multiprocessing helpers."""
import multiprocessing as mp
import traceback
import pycls.utils.distributed as du
from pycls.utils.error_handler import ErrorHandler
def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
"""Runs a function from a child process."""
try:
# Initialize the process group
du.init_process_group(proc_rank, world_size)
# Run the function
fun(*fun_args, **fun_kwargs)
except KeyboardInterrupt:
# Killed by the parent process
pass
except Exception:
# Propagate exception to the parent process
error_queue.put(traceback.format_exc())
finally:
# Destroy the process group
du.destroy_process_group()
def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
"""Runs a function in a multi-proc setting."""
if fun_kwargs is None:
fun_kwargs = {}
# Handle errors from training subprocesses
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Run each training subprocess
ps = []
for i in range(num_proc):
p_i = mp.Process(
target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
)
ps.append(p_i)
p_i.start()
error_handler.add_child(p_i.pid)
# Wait for each subprocess to finish
for p in ps:
p.join()
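A minimal sketch (not part of this file) of the spawn/join/error-queue pattern used by multi_proc_run, shown without the distributed process group or the SIGUSR1-based ErrorHandler; the worker function is a hypothetical stand-in.
import multiprocessing as mp
import traceback
def _worker(rank, error_queue):               # hypothetical toy worker
    try:
        if rank == 1:
            raise RuntimeError("boom")        # simulate a failing child
    except Exception:
        error_queue.put(traceback.format_exc())
if __name__ == "__main__":
    q = mp.SimpleQueue()
    procs = [mp.Process(target=_worker, args=(i, q)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    if not q.empty():
        print("child failed:\n" + q.get())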
================================================
FILE: pycls/utils/net.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for manipulating networks."""
import itertools
import math
import torch
import torch.nn as nn
from pycls.core.config import cfg
def init_weights(m):
"""Performs ResNet-style weight initialization."""
if isinstance(m, nn.Conv2d):
# Note that there is no bias due to BN
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
elif isinstance(m, nn.BatchNorm2d):
zero_init_gamma = (
hasattr(m, "final_bn") and m.final_bn and cfg.BN.ZERO_INIT_FINAL_GAMMA
)
m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(mean=0.0, std=0.01)
m.bias.data.zero_()
@torch.no_grad()
def compute_precise_bn_stats(model, loader):
"""Computes precise BN stats on training data."""
# Compute the number of minibatches to use
num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
# Retrieve the BN layers
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
# Initialize stats storage
mus = [torch.zeros_like(bn.running_mean) for bn in bns]
sqs = [torch.zeros_like(bn.running_var) for bn in bns]
# Remember momentum values
moms = [bn.momentum for bn in bns]
# Disable momentum
for bn in bns:
bn.momentum = 1.0
# Accumulate the stats across the data samples
for inputs, _labels in itertools.islice(loader, num_iter):
model(inputs.cuda())
# Accumulate the stats for each BN layer
for i, bn in enumerate(bns):
m, v = bn.running_mean, bn.running_var
sqs[i] += (v + m * m) / num_iter
mus[i] += m / num_iter
# Set the stats and restore momentum values
for i, bn in enumerate(bns):
bn.running_var = sqs[i] - mus[i] * mus[i]
bn.running_mean = mus[i]
bn.momentum = moms[i]
def reset_bn_stats(model):
"""Resets running BN stats."""
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.reset_running_stats()
def drop_connect(x, drop_ratio):
"""Drop connect (adapted from DARTS)."""
keep_ratio = 1.0 - drop_ratio
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
mask.bernoulli_(keep_ratio)
x.div_(keep_ratio)
x.mul_(mask)
return x
def get_flat_weights(model):
"""Gets all model weights as a single flat vector."""
return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
def set_flat_weights(model, flat_weights):
"""Sets all model weights from a single flat vector."""
k = 0
for p in model.parameters():
n = p.data.numel()
p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
k += n
assert k == flat_weights.numel()
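A minimal round-trip sketch (not part of this file): get_flat_weights and set_flat_weights are inverses over the flattened parameter vector, which is useful for weight-space experiments.
import torch
import torch.nn as nn
from pycls.utils.net import get_flat_weights, set_flat_weights
model = nn.Linear(3, 2)
flat = get_flat_weights(model)          # shape: (num_params, 1)
set_flat_weights(model, flat * 0.5)     # halve every parameter in place
assert torch.allclose(get_flat_weights(model), flat * 0.5)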
================================================
FILE: pycls/utils/plotting.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Plotting functions."""
import colorlover as cl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as offline
import pycls.utils.logging as lu
def get_plot_colors(max_colors, color_format="pyplot"):
"""Generate colors for plotting."""
colors = cl.scales["11"]["qual"]["Paired"]
if max_colors > len(colors):
colors = cl.to_rgb(cl.interp(colors, max_colors))
if color_format == "pyplot":
return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
return colors
def prepare_plot_data(log_files, names, key="top1_err"):
"""Load logs and extract data for plotting error curves."""
plot_data = []
for file, name in zip(log_files, names):
d, log = {}, lu.load_json_stats(file)
for phase in ["train", "test"]:
x = lu.parse_json_stats(log, phase + "_epoch", "epoch")
y = lu.parse_json_stats(log, phase + "_epoch", key)
d["x_" + phase], d["y_" + phase] = x, y
d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
plot_data.append(d)
assert len(plot_data) > 0, "No data to plot"
return plot_data
def plot_error_curves_plotly(log_files, names, filename, key="top1_err"):
"""Plot error curves using plotly and save to file."""
plot_data = prepare_plot_data(log_files, names, key)
colors = get_plot_colors(len(plot_data), "plotly")
# Prepare data for plots (3 sets, train duplicated w and w/o legend)
data = []
for i, d in enumerate(plot_data):
s = str(i)
line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=True,
showlegend=False,
)
)
data.append(
go.Scatter(
x=d["x_test"],
y=d["y_test"],
mode="lines",
name=d["test_label"],
line=line_test,
legendgroup=s,
visible=True,
showlegend=True,
)
)
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=False,
showlegend=True,
)
)
# Prepare layout w ability to toggle 'all', 'train', 'test'
titlefont = {"size": 18, "color": "#7f7f7f"}
vis = [[True, True, False], [False, False, True], [False, True, False]]
buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
buttons = [{"label": l, "args": v, "method": "update"} for l, v in buttons]
layout = go.Layout(
title=key + " vs. epoch
[dash=train, solid=test]",
xaxis={"title": "epoch", "titlefont": titlefont},
yaxis={"title": key, "titlefont": titlefont},
showlegend=True,
hoverlabel={"namelength": -1},
updatemenus=[
{
"buttons": buttons,
"direction": "down",
"showactive": True,
"x": 1.02,
"xanchor": "left",
"y": 1.08,
"yanchor": "top",
}
],
)
# Create plotly plot
offline.plot({"data": data, "layout": layout}, filename=filename)
def plot_error_curves_pyplot(log_files, names, filename=None, key="top1_err"):
"""Plot error curves using matplotlib.pyplot and save to file."""
plot_data = prepare_plot_data(log_files, names, key)
colors = get_plot_colors(len(names))
for ind, d in enumerate(plot_data):
c, lbl = colors[ind], d["test_label"]
plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
plt.title(key + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
plt.xlabel("epoch", fontsize=14)
plt.ylabel(key, fontsize=14)
plt.grid(alpha=0.4)
plt.legend()
if filename:
plt.savefig(filename)
plt.clf()
else:
plt.show()
================================================
FILE: pycls/utils/timer.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Timer."""
import time
class Timer(object):
"""A simple timer (adapted from Detectron)."""
def __init__(self):
self.reset()
def tic(self):
# using time.time instead of time.clock because time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
def reset(self):
self.total_time = 0.0
self.calls = 0
self.start_time = 0.0
self.diff = 0.0
self.average_time = 0.0
================================================
FILE: simplejson/__init__.py
================================================
r"""JSON (JavaScript Object Notation) is a subset of
JavaScript syntax (ECMA-262 3rd edition) used as a lightweight data
interchange format.
:mod:`simplejson` exposes an API familiar to users of the standard library
:mod:`marshal` and :mod:`pickle` modules. It is the externally maintained
version of the :mod:`json` library contained in Python 2.6, but maintains
compatibility back to Python 2.5 and (currently) has significant performance
advantages, even without using the optional C extension for speedups.
Encoding basic Python object hierarchies::
>>> import simplejson as json
>>> json.dumps(['foo', {'bar': ('baz', None, 1.0, 2)}])
'["foo", {"bar": ["baz", null, 1.0, 2]}]'
>>> print(json.dumps("\"foo\bar"))
"\"foo\bar"
>>> print(json.dumps(u'\u1234'))
"\u1234"
>>> print(json.dumps('\\'))
"\\"
>>> print(json.dumps({"c": 0, "b": 0, "a": 0}, sort_keys=True))
{"a": 0, "b": 0, "c": 0}
>>> from simplejson.compat import StringIO
>>> io = StringIO()
>>> json.dump(['streaming API'], io)
>>> io.getvalue()
'["streaming API"]'
Compact encoding::
>>> import simplejson as json
>>> obj = [1,2,3,{'4': 5, '6': 7}]
>>> json.dumps(obj, separators=(',',':'), sort_keys=True)
'[1,2,3,{"4":5,"6":7}]'
Pretty printing::
>>> import simplejson as json
>>> print(json.dumps({'4': 5, '6': 7}, sort_keys=True, indent=' '))
{
"4": 5,
"6": 7
}
Decoding JSON::
>>> import simplejson as json
>>> obj = [u'foo', {u'bar': [u'baz', None, 1.0, 2]}]
>>> json.loads('["foo", {"bar":["baz", null, 1.0, 2]}]') == obj
True
>>> json.loads('"\\"foo\\bar"') == u'"foo\x08ar'
True
>>> from simplejson.compat import StringIO
>>> io = StringIO('["streaming API"]')
>>> json.load(io)[0] == 'streaming API'
True
Specializing JSON object decoding::
>>> import simplejson as json
>>> def as_complex(dct):
... if '__complex__' in dct:
... return complex(dct['real'], dct['imag'])
... return dct
...
>>> json.loads('{"__complex__": true, "real": 1, "imag": 2}',
... object_hook=as_complex)
(1+2j)
>>> from decimal import Decimal
>>> json.loads('1.1', parse_float=Decimal) == Decimal('1.1')
True
Specializing JSON object encoding::
>>> import simplejson as json
>>> def encode_complex(obj):
... if isinstance(obj, complex):
... return [obj.real, obj.imag]
... raise TypeError('Object of type %s is not JSON serializable' %
... obj.__class__.__name__)
...
>>> json.dumps(2 + 1j, default=encode_complex)
'[2.0, 1.0]'
>>> json.JSONEncoder(default=encode_complex).encode(2 + 1j)
'[2.0, 1.0]'
>>> ''.join(json.JSONEncoder(default=encode_complex).iterencode(2 + 1j))
'[2.0, 1.0]'
Using simplejson.tool from the shell to validate and pretty-print::
$ echo '{"json":"obj"}' | python -m simplejson.tool
{
"json": "obj"
}
$ echo '{ 1.2:3.4}' | python -m simplejson.tool
Expecting property name: line 1 column 3 (char 2)
Parsing multiple documents serialized as JSON lines (newline-delimited JSON)::
>>> import simplejson as json
>>> def loads_lines(docs):
... for doc in docs.splitlines():
... yield json.loads(doc)
...
>>> sum(doc["count"] for doc in loads_lines('{"count":1}\n{"count":2}\n{"count":3}\n'))
6
Serializing multiple objects to JSON lines (newline-delimited JSON)::
>>> import simplejson as json
>>> def dumps_lines(objs):
... for obj in objs:
... yield json.dumps(obj, separators=(',',':')) + '\n'
...
>>> ''.join(dumps_lines([{'count': 1}, {'count': 2}, {'count': 3}]))
'{"count":1}\n{"count":2}\n{"count":3}\n'
"""
from __future__ import absolute_import
__version__ = '3.17.2'
__all__ = [
'dump', 'dumps', 'load', 'loads',
'JSONDecoder', 'JSONDecodeError', 'JSONEncoder',
'OrderedDict', 'simple_first', 'RawJSON'
]
__author__ = 'Bob Ippolito <bob@redivi.com>'
from decimal import Decimal
from .errors import JSONDecodeError
from .raw_json import RawJSON
from .decoder import JSONDecoder
from .encoder import JSONEncoder, JSONEncoderForHTML
def _import_OrderedDict():
import collections
try:
return collections.OrderedDict
except AttributeError:
from . import ordered_dict
return ordered_dict.OrderedDict
OrderedDict = _import_OrderedDict()
def _import_c_make_encoder():
try:
from ._speedups import make_encoder
return make_encoder
except ImportError:
return None
_default_encoder = JSONEncoder(
skipkeys=False,
ensure_ascii=True,
check_circular=True,
allow_nan=True,
indent=None,
separators=None,
encoding='utf-8',
default=None,
use_decimal=True,
namedtuple_as_object=True,
tuple_as_array=True,
iterable_as_array=False,
bigint_as_string=False,
item_sort_key=None,
for_json=False,
ignore_nan=False,
int_as_string_bitcount=None,
)
def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
allow_nan=True, cls=None, indent=None, separators=None,
encoding='utf-8', default=None, use_decimal=True,
namedtuple_as_object=True, tuple_as_array=True,
bigint_as_string=False, sort_keys=False, item_sort_key=None,
for_json=False, ignore_nan=False, int_as_string_bitcount=None,
iterable_as_array=False, **kw):
"""Serialize ``obj`` as a JSON formatted stream to ``fp`` (a
``.write()``-supporting file-like object).
If *skipkeys* is true then ``dict`` keys that are not basic types
(``str``, ``int``, ``long``, ``float``, ``bool``, ``None``)
will be skipped instead of raising a ``TypeError``.
If *ensure_ascii* is false (default: ``True``), then the output may
contain non-ASCII characters, so long as they do not need to be escaped
by JSON. When it is true, all non-ASCII characters are escaped.
If *allow_nan* is false, then it will be a ``ValueError`` to
serialize out of range ``float`` values (``nan``, ``inf``, ``-inf``)
in strict compliance of the original JSON specification, instead of using
the JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``). See
*ignore_nan* for ECMA-262 compliant behavior.
If *indent* is a string, then JSON array elements and object members
will be pretty-printed with a newline followed by that string repeated
for each level of nesting. ``None`` (the default) selects the most compact
representation without any newlines.
If specified, *separators* should be an
``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')``
if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most
compact JSON representation, you should specify ``(',', ':')`` to eliminate
whitespace.
*encoding* is the character encoding for str instances, default is UTF-8.
*default(obj)* is a function that should return a serializable version
of obj or raise ``TypeError``. The default simply raises ``TypeError``.
If *use_decimal* is true (default: ``True``) then decimal.Decimal
will be natively serialized to JSON with full precision.
If *namedtuple_as_object* is true (default: ``True``),
:class:`tuple` subclasses with ``_asdict()`` methods will be encoded
as JSON objects.
If *tuple_as_array* is true (default: ``True``),
:class:`tuple` (and subclasses) will be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in Javascript otherwise. Note that this is still a
lossy operation that will not round-trip correctly and should be used
sparingly.
If *int_as_string_bitcount* is a positive number (n), then int of size
greater than or equal to 2**n or lower than or equal to -2**n will be
encoded as strings.
If specified, *item_sort_key* is a callable used to sort the items in
each dictionary. This is useful if you want to sort items other than
in alphabetical order by key. This option takes precedence over
*sort_keys*.
If *sort_keys* is true (default: ``False``), the output of dictionaries
will be sorted by item.
If *for_json* is true (default: ``False``), objects with a ``for_json()``
method will use the return value of that method for encoding as JSON
instead of the object.
If *ignore_nan* is true (default: ``False``), then out of range
:class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as
``null`` in compliance with the ECMA-262 specification. If true, this will
override *allow_nan*.
To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the
``.default()`` method to serialize additional types), specify it with
the ``cls`` kwarg. NOTE: You should use *default* or *for_json* instead
of subclassing whenever possible.
"""
# cached encoder
if (not skipkeys and ensure_ascii and
check_circular and allow_nan and
cls is None and indent is None and separators is None and
encoding == 'utf-8' and default is None and use_decimal
and namedtuple_as_object and tuple_as_array and not iterable_as_array
and not bigint_as_string and not sort_keys
and not item_sort_key and not for_json
and not ignore_nan and int_as_string_bitcount is None
and not kw
):
iterable = _default_encoder.iterencode(obj)
else:
if cls is None:
cls = JSONEncoder
iterable = cls(skipkeys=skipkeys, ensure_ascii=ensure_ascii,
check_circular=check_circular, allow_nan=allow_nan, indent=indent,
separators=separators, encoding=encoding,
default=default, use_decimal=use_decimal,
namedtuple_as_object=namedtuple_as_object,
tuple_as_array=tuple_as_array,
iterable_as_array=iterable_as_array,
bigint_as_string=bigint_as_string,
sort_keys=sort_keys,
item_sort_key=item_sort_key,
for_json=for_json,
ignore_nan=ignore_nan,
int_as_string_bitcount=int_as_string_bitcount,
**kw).iterencode(obj)
# could accelerate with writelines in some versions of Python, at
# a debuggability cost
for chunk in iterable:
fp.write(chunk)
def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
allow_nan=True, cls=None, indent=None, separators=None,
encoding='utf-8', default=None, use_decimal=True,
namedtuple_as_object=True, tuple_as_array=True,
bigint_as_string=False, sort_keys=False, item_sort_key=None,
for_json=False, ignore_nan=False, int_as_string_bitcount=None,
iterable_as_array=False, **kw):
"""Serialize ``obj`` to a JSON formatted ``str``.
If ``skipkeys`` is true then ``dict`` keys that are not basic types
(``str``, ``int``, ``long``, ``float``, ``bool``, ``None``)
will be skipped instead of raising a ``TypeError``.
If *ensure_ascii* is false (default: ``True``), then the output may
contain non-ASCII characters, so long as they do not need to be escaped
by JSON. When it is true, all non-ASCII characters are escaped.
If ``check_circular`` is false, then the circular reference check
for container types will be skipped and a circular reference will
result in an ``OverflowError`` (or worse).
If ``allow_nan`` is false, then it will be a ``ValueError`` to
serialize out of range ``float`` values (``nan``, ``inf``, ``-inf``) in
strict compliance of the JSON specification, instead of using the
JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``).
If ``indent`` is a string, then JSON array elements and object members
will be pretty-printed with a newline followed by that string repeated
for each level of nesting. ``None`` (the default) selects the most compact
representation without any newlines. For backwards compatibility with
versions of simplejson earlier than 2.1.0, an integer is also accepted
and is converted to a string with that many spaces.
If specified, ``separators`` should be an
``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')``
if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most
compact JSON representation, you should specify ``(',', ':')`` to eliminate
whitespace.
``encoding`` is the character encoding for bytes instances, default is
UTF-8.
``default(obj)`` is a function that should return a serializable version
of obj or raise TypeError. The default simply raises TypeError.
If *use_decimal* is true (default: ``True``) then decimal.Decimal
will be natively serialized to JSON with full precision.
If *namedtuple_as_object* is true (default: ``True``),
:class:`tuple` subclasses with ``_asdict()`` methods will be encoded
as JSON objects.
If *tuple_as_array* is true (default: ``True``),
:class:`tuple` (and subclasses) will be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If *bigint_as_string* is true (not the default), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in Javascript otherwise.
If *int_as_string_bitcount* is a positive number (n), then int of size
greater than or equal to 2**n or lower than or equal to -2**n will be
encoded as strings.
If specified, *item_sort_key* is a callable used to sort the items in
each dictionary. This is useful if you want to sort items other than
in alphabetical order by key. This option takes precedence over
*sort_keys*.
If *sort_keys* is true (default: ``False``), the output of dictionaries
will be sorted by item.
If *for_json* is true (default: ``False``), objects with a ``for_json()``
method will use the return value of that method for encoding as JSON
instead of the object.
If *ignore_nan* is true (default: ``False``), then out of range
:class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as
``null`` in compliance with the ECMA-262 specification. If true, this will
override *allow_nan*.
To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the
``.default()`` method to serialize additional types), specify it with
the ``cls`` kwarg. NOTE: You should use *default* instead of subclassing
whenever possible.
"""
# cached encoder
if (not skipkeys and ensure_ascii and
check_circular and allow_nan and
cls is None and indent is None and separators is None and
encoding == 'utf-8' and default is None and use_decimal
and namedtuple_as_object and tuple_as_array and not iterable_as_array
and not bigint_as_string and not sort_keys
and not item_sort_key and not for_json
and not ignore_nan and int_as_string_bitcount is None
and not kw
):
return _default_encoder.encode(obj)
if cls is None:
cls = JSONEncoder
return cls(
skipkeys=skipkeys, ensure_ascii=ensure_ascii,
check_circular=check_circular, allow_nan=allow_nan, indent=indent,
separators=separators, encoding=encoding, default=default,
use_decimal=use_decimal,
namedtuple_as_object=namedtuple_as_object,
tuple_as_array=tuple_as_array,
iterable_as_array=iterable_as_array,
bigint_as_string=bigint_as_string,
sort_keys=sort_keys,
item_sort_key=item_sort_key,
for_json=for_json,
ignore_nan=ignore_nan,
int_as_string_bitcount=int_as_string_bitcount,
**kw).encode(obj)
_default_decoder = JSONDecoder(encoding=None, object_hook=None,
object_pairs_hook=None)
def load(fp, encoding=None, cls=None, object_hook=None, parse_float=None,
parse_int=None, parse_constant=None, object_pairs_hook=None,
use_decimal=False, namedtuple_as_object=True, tuple_as_array=True,
**kw):
"""Deserialize ``fp`` (a ``.read()``-supporting file-like object containing
a JSON document as `str` or `bytes`) to a Python object.
*encoding* determines the encoding used to interpret any
`bytes` objects decoded by this instance (``'utf-8'`` by
default). It has no effect when decoding `str` objects.
*object_hook*, if specified, will be called with the result of every
JSON object decoded and its return value will be used in place of the
given :class:`dict`. This can be used to provide custom
deserializations (e.g. to support JSON-RPC class hinting).
*object_pairs_hook* is an optional function that will be called with
the result of any object literal decode with an ordered list of pairs.
The return value of *object_pairs_hook* will be used instead of the
:class:`dict`. This feature can be used to implement custom decoders
that rely on the order that the key and value pairs are decoded (for
example, :func:`collections.OrderedDict` will remember the order of
insertion). If *object_hook* is also defined, the *object_pairs_hook*
takes priority.
*parse_float*, if specified, will be called with the string of every
JSON float to be decoded. By default, this is equivalent to
``float(num_str)``. This can be used to parse JSON floats with another
datatype or parser (e.g. :class:`decimal.Decimal`).
*parse_int*, if specified, will be called with the string of every
JSON int to be decoded. By default, this is equivalent to
``int(num_str)``. This can be used to parse JSON integers with another
datatype or parser (e.g. :class:`float`).
*parse_constant*, if specified, will be called with one of the
following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This
can be used to raise an exception if invalid JSON numbers are
encountered.
If *use_decimal* is true (default: ``False``) then it implies
parse_float=decimal.Decimal for parity with ``dump``.
To use a custom ``JSONDecoder`` subclass, specify it with the ``cls``
kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead
of subclassing whenever possible.
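A small illustrative sketch (``io.StringIO`` stands in for a real file;
Python 3 repr shown):

>>> import io
>>> import simplejson as json
>>> json.load(io.StringIO('{"x": 1.5}'), use_decimal=True)
{'x': Decimal('1.5')}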
"""
return loads(fp.read(),
encoding=encoding, cls=cls, object_hook=object_hook,
parse_float=parse_float, parse_int=parse_int,
parse_constant=parse_constant, object_pairs_hook=object_pairs_hook,
use_decimal=use_decimal, **kw)
def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None,
parse_int=None, parse_constant=None, object_pairs_hook=None,
use_decimal=False, **kw):
"""Deserialize ``s`` (a ``str`` or ``unicode`` instance containing a JSON
document) to a Python object.
*encoding* determines the encoding used to interpret any
:class:`bytes` objects decoded by this instance (``'utf-8'`` by
default). It has no effect when decoding :class:`unicode` objects.
*object_hook*, if specified, will be called with the result of every
JSON object decoded and its return value will be used in place of the
given :class:`dict`. This can be used to provide custom
deserializations (e.g. to support JSON-RPC class hinting).
*object_pairs_hook* is an optional function that will be called with
the result of any object literal decode with an ordered list of pairs.
The return value of *object_pairs_hook* will be used instead of the
:class:`dict`. This feature can be used to implement custom decoders
that rely on the order that the key and value pairs are decoded (for
example, :func:`collections.OrderedDict` will remember the order of
insertion). If *object_hook* is also defined, the *object_pairs_hook*
takes priority.
*parse_float*, if specified, will be called with the string of every
JSON float to be decoded. By default, this is equivalent to
``float(num_str)``. This can be used to parse JSON floats with another
datatype or parser (e.g. :class:`decimal.Decimal`).
*parse_int*, if specified, will be called with the string of every
JSON int to be decoded. By default, this is equivalent to
``int(num_str)``. This can be used to parse JSON integers with another
datatype or parser (e.g. :class:`float`).
*parse_constant*, if specified, will be called with one of the
following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This
can be used to raise an exception if invalid JSON numbers are
encountered.
If *use_decimal* is true (default: ``False``) then it implies
parse_float=decimal.Decimal for parity with ``dump``.
To use a custom ``JSONDecoder`` subclass, specify it with the ``cls``
kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead
of subclassing whenever possible.
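For example, *parse_float* can route every JSON float through
:class:`decimal.Decimal` (illustrative):

>>> from decimal import Decimal
>>> import simplejson as json
>>> json.loads('1.1', parse_float=Decimal)
Decimal('1.1')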
"""
if (cls is None and encoding is None and object_hook is None and
parse_int is None and parse_float is None and
parse_constant is None and object_pairs_hook is None
and not use_decimal and not kw):
return _default_decoder.decode(s)
if cls is None:
cls = JSONDecoder
if object_hook is not None:
kw['object_hook'] = object_hook
if object_pairs_hook is not None:
kw['object_pairs_hook'] = object_pairs_hook
if parse_float is not None:
kw['parse_float'] = parse_float
if parse_int is not None:
kw['parse_int'] = parse_int
if parse_constant is not None:
kw['parse_constant'] = parse_constant
if use_decimal:
if parse_float is not None:
raise TypeError("use_decimal=True implies parse_float=Decimal")
kw['parse_float'] = Decimal
return cls(encoding=encoding, **kw).decode(s)
def _toggle_speedups(enabled):
from . import decoder as dec
from . import encoder as enc
from . import scanner as scan
c_make_encoder = _import_c_make_encoder()
if enabled:
dec.scanstring = dec.c_scanstring or dec.py_scanstring
enc.c_make_encoder = c_make_encoder
enc.encode_basestring_ascii = (enc.c_encode_basestring_ascii or
enc.py_encode_basestring_ascii)
scan.make_scanner = scan.c_make_scanner or scan.py_make_scanner
else:
dec.scanstring = dec.py_scanstring
enc.c_make_encoder = None
enc.encode_basestring_ascii = enc.py_encode_basestring_ascii
scan.make_scanner = scan.py_make_scanner
dec.make_scanner = scan.make_scanner
global _default_decoder
_default_decoder = JSONDecoder(
encoding=None,
object_hook=None,
object_pairs_hook=None,
)
global _default_encoder
_default_encoder = JSONEncoder(
skipkeys=False,
ensure_ascii=True,
check_circular=True,
allow_nan=True,
indent=None,
separators=None,
encoding='utf-8',
default=None,
)
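# A hedged note: this toggle lets callers (e.g. a test suite) switch between
# the C speedups and the pure-Python implementations at runtime; the
# module-level default encoder/decoder are rebuilt above so that plain
# dumps()/loads() calls pick up the change.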
def simple_first(kv):
"""Helper function to pass to item_sort_key to sort simple
elements to the top, then container elements.
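Illustrative use (Python 3 repr shown): keys with scalar values sort
ahead of keys with container values.

>>> import simplejson as json
>>> json.dumps({'a': [1], 'b': 2}, item_sort_key=json.simple_first)
'{"b": 2, "a": [1]}'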
"""
return (isinstance(kv[1], (list, dict, tuple)), kv[0])
================================================
FILE: simplejson/_speedups.c
================================================
/* -*- mode: C; c-file-style: "python"; c-basic-offset: 4 -*- */
#include "Python.h"
#include "structmember.h"
#if PY_MAJOR_VERSION >= 3
#define PyInt_FromSsize_t PyLong_FromSsize_t
#define PyInt_AsSsize_t PyLong_AsSsize_t
#define PyInt_Check(obj) 0
#define PyInt_CheckExact(obj) 0
#define JSON_UNICHR Py_UCS4
#define JSON_InternFromString PyUnicode_InternFromString
#define PyString_GET_SIZE PyUnicode_GET_LENGTH
#define PY2_UNUSED
#define PY3_UNUSED UNUSED
#else /* PY_MAJOR_VERSION >= 3 */
#define PY2_UNUSED UNUSED
#define PY3_UNUSED
#define PyBytes_Check PyString_Check
#define PyUnicode_READY(obj) 0
#define PyUnicode_KIND(obj) (sizeof(Py_UNICODE))
#define PyUnicode_DATA(obj) ((void *)(PyUnicode_AS_UNICODE(obj)))
#define PyUnicode_READ(kind, data, index) ((JSON_UNICHR)((const Py_UNICODE *)(data))[(index)])
#define PyUnicode_GET_LENGTH PyUnicode_GET_SIZE
#define JSON_UNICHR Py_UNICODE
#define JSON_InternFromString PyString_InternFromString
#endif /* PY_MAJOR_VERSION < 3 */
#if PY_VERSION_HEX < 0x02070000
#if !defined(PyOS_string_to_double)
#define PyOS_string_to_double json_PyOS_string_to_double
static double
json_PyOS_string_to_double(const char *s, char **endptr, PyObject *overflow_exception);
static double
json_PyOS_string_to_double(const char *s, char **endptr, PyObject *overflow_exception)
{
double x;
assert(endptr == NULL);
assert(overflow_exception == NULL);
PyFPE_START_PROTECT("json_PyOS_string_to_double", return -1.0;)
x = PyOS_ascii_atof(s);
PyFPE_END_PROTECT(x)
return x;
}
#endif
#endif /* PY_VERSION_HEX < 0x02070000 */
#if PY_VERSION_HEX < 0x02060000
#if !defined(Py_TYPE)
#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
#endif
#if !defined(Py_SIZE)
#define Py_SIZE(ob) (((PyVarObject*)(ob))->ob_size)
#endif
#if !defined(PyVarObject_HEAD_INIT)
#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
#endif
#endif /* PY_VERSION_HEX < 0x02060000 */
#ifdef __GNUC__
#define UNUSED __attribute__((__unused__))
#else
#define UNUSED
#endif
#define DEFAULT_ENCODING "utf-8"
#define PyScanner_Check(op) PyObject_TypeCheck(op, &PyScannerType)
#define PyScanner_CheckExact(op) (Py_TYPE(op) == &PyScannerType)
#define PyEncoder_Check(op) PyObject_TypeCheck(op, &PyEncoderType)
#define PyEncoder_CheckExact(op) (Py_TYPE(op) == &PyEncoderType)
#define JSON_ALLOW_NAN 1
#define JSON_IGNORE_NAN 2
static PyObject *JSON_Infinity = NULL;
static PyObject *JSON_NegInfinity = NULL;
static PyObject *JSON_NaN = NULL;
static PyObject *JSON_EmptyUnicode = NULL;
#if PY_MAJOR_VERSION < 3
static PyObject *JSON_EmptyStr = NULL;
#endif
static PyTypeObject PyScannerType;
static PyTypeObject PyEncoderType;
typedef struct {
PyObject *large_strings; /* A list of previously accumulated large strings */
PyObject *small_strings; /* Pending small strings */
} JSON_Accu;
static int
JSON_Accu_Init(JSON_Accu *acc);
static int
JSON_Accu_Accumulate(JSON_Accu *acc, PyObject *unicode);
static PyObject *
JSON_Accu_FinishAsList(JSON_Accu *acc);
static void
JSON_Accu_Destroy(JSON_Accu *acc);
#define ERR_EXPECTING_VALUE "Expecting value"
#define ERR_ARRAY_DELIMITER "Expecting ',' delimiter or ']'"
#define ERR_ARRAY_VALUE_FIRST "Expecting value or ']'"
#define ERR_OBJECT_DELIMITER "Expecting ',' delimiter or '}'"
#define ERR_OBJECT_PROPERTY "Expecting property name enclosed in double quotes"
#define ERR_OBJECT_PROPERTY_FIRST "Expecting property name enclosed in double quotes or '}'"
#define ERR_OBJECT_PROPERTY_DELIMITER "Expecting ':' delimiter"
#define ERR_STRING_UNTERMINATED "Unterminated string starting at"
#define ERR_STRING_CONTROL "Invalid control character %r at"
#define ERR_STRING_ESC1 "Invalid \\X escape sequence %r"
#define ERR_STRING_ESC4 "Invalid \\uXXXX escape sequence"
typedef struct _PyScannerObject {
PyObject_HEAD
PyObject *encoding;
PyObject *strict_bool;
int strict;
PyObject *object_hook;
PyObject *pairs_hook;
PyObject *parse_float;
PyObject *parse_int;
PyObject *parse_constant;
PyObject *memo;
} PyScannerObject;
static PyMemberDef scanner_members[] = {
{"encoding", T_OBJECT, offsetof(PyScannerObject, encoding), READONLY, "encoding"},
{"strict", T_OBJECT, offsetof(PyScannerObject, strict_bool), READONLY, "strict"},
{"object_hook", T_OBJECT, offsetof(PyScannerObject, object_hook), READONLY, "object_hook"},
{"object_pairs_hook", T_OBJECT, offsetof(PyScannerObject, pairs_hook), READONLY, "object_pairs_hook"},
{"parse_float", T_OBJECT, offsetof(PyScannerObject, parse_float), READONLY, "parse_float"},
{"parse_int", T_OBJECT, offsetof(PyScannerObject, parse_int), READONLY, "parse_int"},
{"parse_constant", T_OBJECT, offsetof(PyScannerObject, parse_constant), READONLY, "parse_constant"},
{NULL}
};
typedef struct _PyEncoderObject {
PyObject_HEAD
PyObject *markers;
PyObject *defaultfn;
PyObject *encoder;
PyObject *indent;
PyObject *key_separator;
PyObject *item_separator;
PyObject *sort_keys;
PyObject *key_memo;
PyObject *encoding;
PyObject *Decimal;
PyObject *skipkeys_bool;
int skipkeys;
int fast_encode;
/* 0, JSON_ALLOW_NAN, JSON_IGNORE_NAN */
int allow_or_ignore_nan;
int use_decimal;
int namedtuple_as_object;
int tuple_as_array;
int iterable_as_array;
PyObject *max_long_size;
PyObject *min_long_size;
PyObject *item_sort_key;
PyObject *item_sort_kw;
int for_json;
} PyEncoderObject;
static PyMemberDef encoder_members[] = {
{"markers", T_OBJECT, offsetof(PyEncoderObject, markers), READONLY, "markers"},
{"default", T_OBJECT, offsetof(PyEncoderObject, defaultfn), READONLY, "default"},
{"encoder", T_OBJECT, offsetof(PyEncoderObject, encoder), READONLY, "encoder"},
{"encoding", T_OBJECT, offsetof(PyEncoderObject, encoder), READONLY, "encoding"},
{"indent", T_OBJECT, offsetof(PyEncoderObject, indent), READONLY, "indent"},
{"key_separator", T_OBJECT, offsetof(PyEncoderObject, key_separator), READONLY, "key_separator"},
{"item_separator", T_OBJECT, offsetof(PyEncoderObject, item_separator), READONLY, "item_separator"},
{"sort_keys", T_OBJECT, offsetof(PyEncoderObject, sort_keys), READONLY, "sort_keys"},
/* Python 2.5 does not support T_BOOL */
{"skipkeys", T_OBJECT, offsetof(PyEncoderObject, skipkeys_bool), READONLY, "skipkeys"},
{"key_memo", T_OBJECT, offsetof(PyEncoderObject, key_memo), READONLY, "key_memo"},
{"item_sort_key", T_OBJECT, offsetof(PyEncoderObject, item_sort_key), READONLY, "item_sort_key"},
{"max_long_size", T_OBJECT, offsetof(PyEncoderObject, max_long_size), READONLY, "max_long_size"},
{"min_long_size", T_OBJECT, offsetof(PyEncoderObject, min_long_size), READONLY, "min_long_size"},
{NULL}
};
static PyObject *
join_list_unicode(PyObject *lst);
static PyObject *
JSON_ParseEncoding(PyObject *encoding);
static PyObject *
maybe_quote_bigint(PyEncoderObject* s, PyObject *encoded, PyObject *obj);
static Py_ssize_t
ascii_char_size(JSON_UNICHR c);
static Py_ssize_t
ascii_escape_char(JSON_UNICHR c, char *output, Py_ssize_t chars);
static PyObject *
ascii_escape_unicode(PyObject *pystr);
static PyObject *
ascii_escape_str(PyObject *pystr);
static PyObject *
py_encode_basestring_ascii(PyObject* self UNUSED, PyObject *pystr);
#if PY_MAJOR_VERSION < 3
static PyObject *
join_list_string(PyObject *lst);
static PyObject *
scan_once_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr);
static PyObject *
scanstring_str(PyObject *pystr, Py_ssize_t end, char *encoding, int strict, Py_ssize_t *next_end_ptr);
static PyObject *
_parse_object_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr);
#endif
static PyObject *
scanstring_unicode(PyObject *pystr, Py_ssize_t end, int strict, Py_ssize_t *next_end_ptr);
static PyObject *
scan_once_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr);
static PyObject *
_build_rval_index_tuple(PyObject *rval, Py_ssize_t idx);
static PyObject *
scanner_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
static void
scanner_dealloc(PyObject *self);
static int
scanner_clear(PyObject *self);
static PyObject *
encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
static void
encoder_dealloc(PyObject *self);
static int
encoder_clear(PyObject *self);
static int
is_raw_json(PyObject *obj);
static PyObject *
encoder_stringify_key(PyEncoderObject *s, PyObject *key);
static int
encoder_listencode_list(PyEncoderObject *s, JSON_Accu *rval, PyObject *seq, Py_ssize_t indent_level);
static int
encoder_listencode_obj(PyEncoderObject *s, JSON_Accu *rval, PyObject *obj, Py_ssize_t indent_level);
static int
encoder_listencode_dict(PyEncoderObject *s, JSON_Accu *rval, PyObject *dct, Py_ssize_t indent_level);
static PyObject *
_encoded_const(PyObject *obj);
static void
raise_errmsg(char *msg, PyObject *s, Py_ssize_t end);
static PyObject *
encoder_encode_string(PyEncoderObject *s, PyObject *obj);
static int
_convertPyInt_AsSsize_t(PyObject *o, Py_ssize_t *size_ptr);
static PyObject *
_convertPyInt_FromSsize_t(Py_ssize_t *size_ptr);
static PyObject *
encoder_encode_float(PyEncoderObject *s, PyObject *obj);
static int
_is_namedtuple(PyObject *obj);
static int
_has_for_json_hook(PyObject *obj);
static PyObject *
moduleinit(void);
#define S_CHAR(c) (c >= ' ' && c <= '~' && c != '\\' && c != '"')
#define IS_WHITESPACE(c) (((c) == ' ') || ((c) == '\t') || ((c) == '\n') || ((c) == '\r'))
#define MIN_EXPANSION 6
static PyObject* RawJSONType = NULL;
static int
is_raw_json(PyObject *obj)
{
return PyObject_IsInstance(obj, RawJSONType) ? 1 : 0;
}
static int
JSON_Accu_Init(JSON_Accu *acc)
{
/* Lazily allocated */
acc->large_strings = NULL;
acc->small_strings = PyList_New(0);
if (acc->small_strings == NULL)
return -1;
return 0;
}
static int
flush_accumulator(JSON_Accu *acc)
{
Py_ssize_t nsmall = PyList_GET_SIZE(acc->small_strings);
if (nsmall) {
int ret;
PyObject *joined;
if (acc->large_strings == NULL) {
acc->large_strings = PyList_New(0);
if (acc->large_strings == NULL)
return -1;
}
#if PY_MAJOR_VERSION >= 3
joined = join_list_unicode(acc->small_strings);
#else /* PY_MAJOR_VERSION >= 3 */
joined = join_list_string(acc->small_strings);
#endif /* PY_MAJOR_VERSION < 3 */
if (joined == NULL)
return -1;
if (PyList_SetSlice(acc->small_strings, 0, nsmall, NULL)) {
Py_DECREF(joined);
return -1;
}
ret = PyList_Append(acc->large_strings, joined);
Py_DECREF(joined);
return ret;
}
return 0;
}
static int
JSON_Accu_Accumulate(JSON_Accu *acc, PyObject *unicode)
{
Py_ssize_t nsmall;
#if PY_MAJOR_VERSION >= 3
assert(PyUnicode_Check(unicode));
#else /* PY_MAJOR_VERSION >= 3 */
assert(PyString_Check(unicode) || PyUnicode_Check(unicode));
#endif /* PY_MAJOR_VERSION < 3 */
if (PyList_Append(acc->small_strings, unicode))
return -1;
nsmall = PyList_GET_SIZE(acc->small_strings);
/* Each item in a list of unicode objects has an overhead (in 64-bit
* builds) of:
* - 8 bytes for the list slot
* - 56 bytes for the header of the unicode object
* that is, 64 bytes. 100000 such objects waste more than 6MB
* compared to a single concatenated string.
*/
if (nsmall < 100000)
return 0;
return flush_accumulator(acc);
}
static PyObject *
JSON_Accu_FinishAsList(JSON_Accu *acc)
{
int ret;
PyObject *res;
ret = flush_accumulator(acc);
Py_CLEAR(acc->small_strings);
if (ret) {
Py_CLEAR(acc->large_strings);
return NULL;
}
res = acc->large_strings;
acc->large_strings = NULL;
if (res == NULL)
return PyList_New(0);
return res;
}
static void
JSON_Accu_Destroy(JSON_Accu *acc)
{
Py_CLEAR(acc->small_strings);
Py_CLEAR(acc->large_strings);
}
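/* Illustrative usage of the accumulator (a sketch, not a public API):
 *
 *     JSON_Accu acc;
 *     if (JSON_Accu_Init(&acc))
 *         return NULL;
 *     while (more_chunks) {
 *         if (JSON_Accu_Accumulate(&acc, chunk)) {
 *             JSON_Accu_Destroy(&acc);
 *             return NULL;
 *         }
 *     }
 *     result = JSON_Accu_FinishAsList(&acc);
 *
 * Small strings are joined into one run once 100000 of them accumulate,
 * bounding the per-object overhead described above while avoiding
 * repeated quadratic concatenation.
 */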
static int
IS_DIGIT(JSON_UNICHR c)
{
return c >= '0' && c <= '9';
}
static PyObject *
maybe_quote_bigint(PyEncoderObject* s, PyObject *encoded, PyObject *obj)
{
if (s->max_long_size != Py_None && s->min_long_size != Py_None) {
if (PyObject_RichCompareBool(obj, s->max_long_size, Py_GE) ||
PyObject_RichCompareBool(obj, s->min_long_size, Py_LE)) {
#if PY_MAJOR_VERSION >= 3
PyObject* quoted = PyUnicode_FromFormat("\"%U\"", encoded);
#else
PyObject* quoted = PyString_FromFormat("\"%s\"",
PyString_AsString(encoded));
#endif
Py_DECREF(encoded);
encoded = quoted;
}
}
return encoded;
}
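/* Illustrative: with the Python-side default bigint_as_string bounds of
 * +/- 2**53, the already-encoded text for 9007199254740992 (== 2**53)
 * is wrapped in quotes, while 9007199254740991 passes through unchanged.
 * The actual bounds are whatever max_long_size/min_long_size were
 * configured on the encoder. */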
static int
_is_namedtuple(PyObject *obj)
{
int rval = 0;
PyObject *_asdict = PyObject_GetAttrString(obj, "_asdict");
if (_asdict == NULL) {
PyErr_Clear();
return 0;
}
rval = PyCallable_Check(_asdict);
Py_DECREF(_asdict);
return rval;
}
static int
_has_for_json_hook(PyObject *obj)
{
int rval = 0;
PyObject *for_json = PyObject_GetAttrString(obj, "for_json");
if (for_json == NULL) {
PyErr_Clear();
return 0;
}
rval = PyCallable_Check(for_json);
Py_DECREF(for_json);
return rval;
}
static int
_convertPyInt_AsSsize_t(PyObject *o, Py_ssize_t *size_ptr)
{
/* PyObject to Py_ssize_t converter */
*size_ptr = PyInt_AsSsize_t(o);
if (*size_ptr == -1 && PyErr_Occurred())
return 0;
return 1;
}
static PyObject *
_convertPyInt_FromSsize_t(Py_ssize_t *size_ptr)
{
/* Py_ssize_t to PyObject converter */
return PyInt_FromSsize_t(*size_ptr);
}
static Py_ssize_t
ascii_escape_char(JSON_UNICHR c, char *output, Py_ssize_t chars)
{
/* Escape unicode code point c to ASCII escape sequences
in char *output. output must have at least 12 bytes unused to
accommodate an escaped surrogate pair "\uXXXX\uXXXX" */
if (S_CHAR(c)) {
output[chars++] = (char)c;
}
else {
output[chars++] = '\\';
switch (c) {
case '\\': output[chars++] = (char)c; break;
case '"': output[chars++] = (char)c; break;
case '\b': output[chars++] = 'b'; break;
case '\f': output[chars++] = 'f'; break;
case '\n': output[chars++] = 'n'; break;
case '\r': output[chars++] = 'r'; break;
case '\t': output[chars++] = 't'; break;
default:
#if PY_MAJOR_VERSION >= 3 || defined(Py_UNICODE_WIDE)
if (c >= 0x10000) {
/* UTF-16 surrogate pair */
JSON_UNICHR v = c - 0x10000;
c = 0xd800 | ((v >> 10) & 0x3ff);
output[chars++] = 'u';
output[chars++] = "0123456789abcdef"[(c >> 12) & 0xf];
output[chars++] = "0123456789abcdef"[(c >> 8) & 0xf];
output[chars++] = "0123456789abcdef"[(c >> 4) & 0xf];
output[chars++] = "0123456789abcdef"[(c ) & 0xf];
c = 0xdc00 | (v & 0x3ff);
output[chars++] = '\\';
}
#endif
output[chars++] = 'u';
output[chars++] = "0123456789abcdef"[(c >> 12) & 0xf];
output[chars++] = "0123456789abcdef"[(c >> 8) & 0xf];
output[chars++] = "0123456789abcdef"[(c >> 4) & 0xf];
output[chars++] = "0123456789abcdef"[(c ) & 0xf];
}
}
return chars;
}
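/* Worked example: U+1F600 lies outside the BMP, so it is emitted as the
 * surrogate pair \ud83d\ude00 (12 characters); '\n' is emitted as the two
 * characters backslash + 'n'; printable ASCII such as 'a' is copied
 * through unescaped. */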
static Py_ssize_t
ascii_char_size(JSON_UNICHR c)
{
if (S_CHAR(c)) {
return 1;
}
else if (c == '\\' ||
c == '"' ||
c == '\b' ||
c == '\f' ||
c == '\n' ||
c == '\r' ||
c == '\t') {
return 2;
}
#if PY_MAJOR_VERSION >= 3 || defined(Py_UNICODE_WIDE)
else if (c >= 0x10000U) {
return 2 * MIN_EXPANSION;
}
#endif
else {
return MIN_EXPANSION;
}
}
static PyObject *
ascii_escape_unicode(PyObject *pystr)
{
/* Take a PyUnicode pystr and return a new ASCII-only escaped PyString */
Py_ssize_t i;
Py_ssize_t input_chars = PyUnicode_GET_LENGTH(pystr);
Py_ssize_t output_size = 2;
Py_ssize_t chars;
PY2_UNUSED int kind = PyUnicode_KIND(pystr);
void *data = PyUnicode_DATA(pystr);
PyObject *rval;
char *output;
output_size = 2;
for (i = 0; i < input_chars; i++) {
output_size += ascii_char_size(PyUnicode_READ(kind, data, i));
}
#if PY_MAJOR_VERSION >= 3
rval = PyUnicode_New(output_size, 127);
if (rval == NULL) {
return NULL;
}
assert(PyUnicode_KIND(rval) == PyUnicode_1BYTE_KIND);
output = (char *)PyUnicode_DATA(rval);
#else
rval = PyString_FromStringAndSize(NULL, output_size);
if (rval == NULL) {
return NULL;
}
output = PyString_AS_STRING(rval);
#endif
chars = 0;
output[chars++] = '"';
for (i = 0; i < input_chars; i++) {
chars = ascii_escape_char(PyUnicode_READ(kind, data, i), output, chars);
}
output[chars++] = '"';
assert(chars == output_size);
return rval;
}
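/* The two passes above are deliberate: the first pass sums
 * ascii_char_size() for every code point (plus 2 for the surrounding
 * quotes), so the second pass can write into one exactly-sized buffer
 * with no reallocation; the final assert checks that the passes agree. */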
#if PY_MAJOR_VERSION >= 3
static PyObject *
ascii_escape_str(PyObject *pystr)
{
PyObject *rval;
PyObject *input = PyUnicode_DecodeUTF8(PyBytes_AS_STRING(pystr), PyBytes_GET_SIZE(pystr), NULL);
if (input == NULL)
return NULL;
rval = ascii_escape_unicode(input);
Py_DECREF(input);
return rval;
}
#else /* PY_MAJOR_VERSION >= 3 */
static PyObject *
ascii_escape_str(PyObject *pystr)
{
/* Take a PyString pystr and return a new ASCII-only escaped PyString */
Py_ssize_t i;
Py_ssize_t input_chars;
Py_ssize_t output_size;
Py_ssize_t chars;
PyObject *rval;
char *output;
char *input_str;
input_chars = PyString_GET_SIZE(pystr);
input_str = PyString_AS_STRING(pystr);
output_size = 2;
/* Fast path for a string that's already ASCII */
for (i = 0; i < input_chars; i++) {
JSON_UNICHR c = (JSON_UNICHR)input_str[i];
if (c > 0x7f) {
/* We hit a non-ASCII character, bail to unicode mode */
PyObject *uni;
uni = PyUnicode_DecodeUTF8(input_str, input_chars, "strict");
if (uni == NULL) {
return NULL;
}
rval = ascii_escape_unicode(uni);
Py_DECREF(uni);
return rval;
}
output_size += ascii_char_size(c);
}
rval = PyString_FromStringAndSize(NULL, output_size);
if (rval == NULL) {
return NULL;
}
chars = 0;
output = PyString_AS_STRING(rval);
output[chars++] = '"';
for (i = 0; i < input_chars; i++) {
chars = ascii_escape_char((JSON_UNICHR)input_str[i], output, chars);
}
output[chars++] = '"';
assert(chars == output_size);
return rval;
}
#endif /* PY_MAJOR_VERSION < 3 */
static PyObject *
encoder_stringify_key(PyEncoderObject *s, PyObject *key)
{
if (PyUnicode_Check(key)) {
Py_INCREF(key);
return key;
}
#if PY_MAJOR_VERSION >= 3
else if (PyBytes_Check(key) && s->encoding != NULL) {
const char *encoding = PyUnicode_AsUTF8(s->encoding);
if (encoding == NULL)
return NULL;
return PyUnicode_Decode(
PyBytes_AS_STRING(key),
PyBytes_GET_SIZE(key),
encoding,
NULL);
}
#else /* PY_MAJOR_VERSION >= 3 */
else if (PyString_Check(key)) {
Py_INCREF(key);
return key;
}
#endif /* PY_MAJOR_VERSION < 3 */
else if (PyFloat_Check(key)) {
return encoder_encode_float(s, key);
}
else if (key == Py_True || key == Py_False || key == Py_None) {
/* This must come before the PyInt_Check because
True and False are also 1 and 0.*/
return _encoded_const(key);
}
else if (PyInt_Check(key) || PyLong_Check(key)) {
if (!(PyInt_CheckExact(key) || PyLong_CheckExact(key))) {
/* See #118, do not trust custom str/repr */
PyObject *res;
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyLong_Type, key, NULL);
if (tmp == NULL) {
return NULL;
}
res = PyObject_Str(tmp);
Py_DECREF(tmp);
return res;
}
else {
return PyObject_Str(key);
}
}
else if (s->use_decimal && PyObject_TypeCheck(key, (PyTypeObject *)s->Decimal)) {
return PyObject_Str(key);
}
if (s->skipkeys) {
Py_INCREF(Py_None);
return Py_None;
}
PyErr_Format(PyExc_TypeError,
"keys must be str, int, float, bool or None, "
"not %.100s", key->ob_type->tp_name);
return NULL;
}
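/* Key coercion examples (illustrative): a unicode key is returned as-is;
 * 3 becomes "3"; 1.5 goes through encoder_encode_float; True/False/None
 * become "true"/"false"/"null"; anything else raises TypeError, unless
 * skipkeys is set, in which case Py_None is returned to mean "drop this
 * pair". */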
static PyObject *
encoder_dict_iteritems(PyEncoderObject *s, PyObject *dct)
{
PyObject *items;
PyObject *iter = NULL;
PyObject *lst = NULL;
PyObject *item = NULL;
PyObject *kstr = NULL;
PyObject *sortfun = NULL;
PyObject *sortres;
static PyObject *sortargs = NULL;
if (sortargs == NULL) {
sortargs = PyTuple_New(0);
if (sortargs == NULL)
return NULL;
}
if (PyDict_CheckExact(dct))
items = PyDict_Items(dct);
else
items = PyMapping_Items(dct);
if (items == NULL)
return NULL;
iter = PyObject_GetIter(items);
Py_DECREF(items);
if (iter == NULL)
return NULL;
if (s->item_sort_kw == Py_None)
return iter;
lst = PyList_New(0);
if (lst == NULL)
goto bail;
while ((item = PyIter_Next(iter))) {
PyObject *key, *value;
if (!PyTuple_Check(item) || Py_SIZE(item) != 2) {
PyErr_SetString(PyExc_ValueError, "items must return 2-tuples");
goto bail;
}
key = PyTuple_GET_ITEM(item, 0);
if (key == NULL)
goto bail;
#if PY_MAJOR_VERSION < 3
else if (PyString_Check(key)) {
/* item can be added as-is */
}
#endif /* PY_MAJOR_VERSION < 3 */
else if (PyUnicode_Check(key)) {
/* item can be added as-is */
}
else {
PyObject *tpl;
kstr = encoder_stringify_key(s, key);
if (kstr == NULL)
goto bail;
else if (kstr == Py_None) {
/* skipkeys */
Py_DECREF(kstr);
continue;
}
value = PyTuple_GET_ITEM(item, 1);
if (value == NULL)
goto bail;
tpl = PyTuple_Pack(2, kstr, value);
if (tpl == NULL)
goto bail;
Py_CLEAR(kstr);
Py_DECREF(item);
item = tpl;
}
if (PyList_Append(lst, item))
goto bail;
Py_DECREF(item);
}
Py_CLEAR(iter);
if (PyErr_Occurred())
goto bail;
sortfun = PyObject_GetAttrString(lst, "sort");
if (sortfun == NULL)
goto bail;
sortres = PyObject_Call(sortfun, sortargs, s->item_sort_kw);
if (!sortres)
goto bail;
Py_DECREF(sortres);
Py_CLEAR(sortfun);
iter = PyObject_GetIter(lst);
Py_CLEAR(lst);
return iter;
bail:
Py_XDECREF(sortfun);
Py_XDECREF(kstr);
Py_XDECREF(item);
Py_XDECREF(lst);
Py_XDECREF(iter);
return NULL;
}
/* Use JSONDecodeError exception to raise a nice looking ValueError subclass */
static PyObject *JSONDecodeError = NULL;
static void
raise_errmsg(char *msg, PyObject *s, Py_ssize_t end)
{
PyObject *exc = PyObject_CallFunction(JSONDecodeError, "(zOO&)", msg, s, _convertPyInt_FromSsize_t, &end);
if (exc) {
PyErr_SetObject(JSONDecodeError, exc);
Py_DECREF(exc);
}
}
static PyObject *
join_list_unicode(PyObject *lst)
{
/* return u''.join(lst) */
return PyUnicode_Join(JSON_EmptyUnicode, lst);
}
#if PY_MAJOR_VERSION >= 3
#define join_list_string join_list_unicode
#else /* PY_MAJOR_VERSION >= 3 */
static PyObject *
join_list_string(PyObject *lst)
{
/* return ''.join(lst) */
static PyObject *joinfn = NULL;
if (joinfn == NULL) {
joinfn = PyObject_GetAttrString(JSON_EmptyStr, "join");
if (joinfn == NULL)
return NULL;
}
return PyObject_CallFunctionObjArgs(joinfn, lst, NULL);
}
#endif /* PY_MAJOR_VERSION < 3 */
static PyObject *
_build_rval_index_tuple(PyObject *rval, Py_ssize_t idx)
{
/* return (rval, idx) tuple, stealing reference to rval */
PyObject *tpl;
PyObject *pyidx;
/*
steal a reference to rval, returns (rval, idx)
*/
if (rval == NULL) {
assert(PyErr_Occurred());
return NULL;
}
pyidx = PyInt_FromSsize_t(idx);
if (pyidx == NULL) {
Py_DECREF(rval);
return NULL;
}
tpl = PyTuple_New(2);
if (tpl == NULL) {
Py_DECREF(pyidx);
Py_DECREF(rval);
return NULL;
}
PyTuple_SET_ITEM(tpl, 0, rval);
PyTuple_SET_ITEM(tpl, 1, pyidx);
return tpl;
}
#define APPEND_OLD_CHUNK \
if (chunk != NULL) { \
if (chunks == NULL) { \
chunks = PyList_New(0); \
if (chunks == NULL) { \
goto bail; \
} \
} \
if (PyList_Append(chunks, chunk)) { \
goto bail; \
} \
Py_CLEAR(chunk); \
}
#if PY_MAJOR_VERSION < 3
static PyObject *
scanstring_str(PyObject *pystr, Py_ssize_t end, char *encoding, int strict, Py_ssize_t *next_end_ptr)
{
/* Read the JSON string from PyString pystr.
end is the index of the first character after the quote.
encoding is the encoding of pystr (must be an ASCII superset)
if strict is zero then literal control characters are allowed
*next_end_ptr is a return-by-reference index of the character
after the end quote
Return value is a new PyString (if ASCII-only) or PyUnicode
*/
PyObject *rval;
Py_ssize_t len = PyString_GET_SIZE(pystr);
Py_ssize_t begin = end - 1;
Py_ssize_t next = begin;
int has_unicode = 0;
char *buf = PyString_AS_STRING(pystr);
PyObject *chunks = NULL;
PyObject *chunk = NULL;
PyObject *strchunk = NULL;
if (len == end) {
raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin);
goto bail;
}
else if (end < 0 || len < end) {
PyErr_SetString(PyExc_ValueError, "end is out of bounds");
goto bail;
}
while (1) {
/* Find the end of the string or the next escape */
Py_UNICODE c = 0;
for (next = end; next < len; next++) {
c = (unsigned char)buf[next];
if (c == '"' || c == '\\') {
break;
}
else if (strict && c <= 0x1f) {
raise_errmsg(ERR_STRING_CONTROL, pystr, next);
goto bail;
}
else if (c > 0x7f) {
has_unicode = 1;
}
}
if (!(c == '"' || c == '\\')) {
raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin);
goto bail;
}
/* Pick up this chunk if it's not zero length */
if (next != end) {
APPEND_OLD_CHUNK
strchunk = PyString_FromStringAndSize(&buf[end], next - end);
if (strchunk == NULL) {
goto bail;
}
if (has_unicode) {
chunk = PyUnicode_FromEncodedObject(strchunk, encoding, NULL);
Py_DECREF(strchunk);
if (chunk == NULL) {
goto bail;
}
}
else {
chunk = strchunk;
}
}
next++;
if (c == '"') {
end = next;
break;
}
if (next == len) {
raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin);
goto bail;
}
c = buf[next];
if (c != 'u') {
/* Non-unicode backslash escapes */
end = next + 1;
switch (c) {
case '"': break;
case '\\': break;
case '/': break;
case 'b': c = '\b'; break;
case 'f': c = '\f'; break;
case 'n': c = '\n'; break;
case 'r': c = '\r'; break;
case 't': c = '\t'; break;
default: c = 0;
}
if (c == 0) {
raise_errmsg(ERR_STRING_ESC1, pystr, end - 2);
goto bail;
}
}
else {
c = 0;
next++;
end = next + 4;
if (end >= len) {
raise_errmsg(ERR_STRING_ESC4, pystr, next - 1);
goto bail;
}
/* Decode 4 hex digits */
for (; next < end; next++) {
JSON_UNICHR digit = (JSON_UNICHR)buf[next];
c <<= 4;
switch (digit) {
case '0': case '1': case '2': case '3': case '4':
case '5': case '6': case '7': case '8': case '9':
c |= (digit - '0'); break;
case 'a': case 'b': case 'c': case 'd': case 'e':
case 'f':
c |= (digit - 'a' + 10); break;
case 'A': case 'B': case 'C': case 'D': case 'E':
case 'F':
c |= (digit - 'A' + 10); break;
default:
raise_errmsg(ERR_STRING_ESC4, pystr, end - 5);
goto bail;
}
}
#if defined(Py_UNICODE_WIDE)
/* Surrogate pair */
if ((c & 0xfc00) == 0xd800) {
if (end + 6 < len && buf[next] == '\\' && buf[next+1] == 'u') {
JSON_UNICHR c2 = 0;
end += 6;
/* Decode 4 hex digits */
for (next += 2; next < end; next++) {
c2 <<= 4;
JSON_UNICHR digit = buf[next];
switch (digit) {
case '0': case '1': case '2': case '3': case '4':
case '5': case '6': case '7': case '8': case '9':
c2 |= (digit - '0'); break;
case 'a': case 'b': case 'c': case 'd': case 'e':
case 'f':
c2 |= (digit - 'a' + 10); break;
case 'A': case 'B': case 'C': case 'D': case 'E':
case 'F':
c2 |= (digit - 'A' + 10); break;
default:
raise_errmsg(ERR_STRING_ESC4, pystr, end - 5);
goto bail;
}
}
if ((c2 & 0xfc00) != 0xdc00) {
/* not a low surrogate, rewind */
end -= 6;
next = end;
}
else {
c = 0x10000 + (((c - 0xd800) << 10) | (c2 - 0xdc00));
}
}
}
#endif /* Py_UNICODE_WIDE */
}
if (c > 0x7f) {
has_unicode = 1;
}
APPEND_OLD_CHUNK
if (has_unicode) {
chunk = PyUnicode_FromOrdinal(c);
if (chunk == NULL) {
goto bail;
}
}
else {
char c_char = Py_CHARMASK(c);
chunk = PyString_FromStringAndSize(&c_char, 1);
if (chunk == NULL) {
goto bail;
}
}
}
if (chunks == NULL) {
if (chunk != NULL)
rval = chunk;
else {
rval = JSON_EmptyStr;
Py_INCREF(rval);
}
}
else {
APPEND_OLD_CHUNK
rval = join_list_string(chunks);
if (rval == NULL) {
goto bail;
}
Py_CLEAR(chunks);
}
*next_end_ptr = end;
return rval;
bail:
*next_end_ptr = -1;
Py_XDECREF(chunk);
Py_XDECREF(chunks);
return NULL;
}
#endif /* PY_MAJOR_VERSION < 3 */
static PyObject *
scanstring_unicode(PyObject *pystr, Py_ssize_t end, int strict, Py_ssize_t *next_end_ptr)
{
/* Read the JSON string from PyUnicode pystr.
end is the index of the first character after the quote.
if strict is zero then literal control characters are allowed
*next_end_ptr is a return-by-reference index of the character
after the end quote
Return value is a new PyUnicode
*/
PyObject *rval;
Py_ssize_t begin = end - 1;
Py_ssize_t next = begin;
PY2_UNUSED int kind = PyUnicode_KIND(pystr);
Py_ssize_t len = PyUnicode_GET_LENGTH(pystr);
void *buf = PyUnicode_DATA(pystr);
PyObject *chunks = NULL;
PyObject *chunk = NULL;
if (len == end) {
raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin);
goto bail;
}
else if (end < 0 || len < end) {
PyErr_SetString(PyExc_ValueError, "end is out of bounds");
goto bail;
}
while (1) {
/* Find the end of the string or the next escape */
JSON_UNICHR c = 0;
for (next = end; next < len; next++) {
c = PyUnicode_READ(kind, buf, next);
if (c == '"' || c == '\\') {
break;
}
else if (strict && c <= 0x1f) {
raise_errmsg(ERR_STRING_CONTROL, pystr, next);
goto bail;
}
}
if (!(c == '"' || c == '\\')) {
raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin);
goto bail;
}
/* Pick up this chunk if it's not zero length */
if (next != end) {
APPEND_OLD_CHUNK
#if PY_MAJOR_VERSION < 3
chunk = PyUnicode_FromUnicode(&((const Py_UNICODE *)buf)[end], next - end);
#else
chunk = PyUnicode_Substring(pystr, end, next);
#endif
if (chunk == NULL) {
goto bail;
}
}
next++;
if (c == '"') {
end = next;
break;
}
if (next == len) {
raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin);
goto bail;
}
c = PyUnicode_READ(kind, buf, next);
if (c != 'u') {
/* Non-unicode backslash escapes */
end = next + 1;
switch (c) {
case '"': break;
case '\\': break;
case '/': break;
case 'b': c = '\b'; break;
case 'f': c = '\f'; break;
case 'n': c = '\n'; break;
case 'r': c = '\r'; break;
case 't': c = '\t'; break;
default: c = 0;
}
if (c == 0) {
raise_errmsg(ERR_STRING_ESC1, pystr, end - 2);
goto bail;
}
}
else {
c = 0;
next++;
end = next + 4;
if (end >= len) {
raise_errmsg(ERR_STRING_ESC4, pystr, next - 1);
goto bail;
}
/* Decode 4 hex digits */
for (; next < end; next++) {
JSON_UNICHR digit = PyUnicode_READ(kind, buf, next);
c <<= 4;
switch (digit) {
case '0': case '1': case '2': case '3': case '4':
case '5': case '6': case '7': case '8': case '9':
c |= (digit - '0'); break;
case 'a': case 'b': case 'c': case 'd': case 'e':
case 'f':
c |= (digit - 'a' + 10); break;
case 'A': case 'B': case 'C': case 'D': case 'E':
case 'F':
c |= (digit - 'A' + 10); break;
default:
raise_errmsg(ERR_STRING_ESC4, pystr, end - 5);
goto bail;
}
}
#if PY_MAJOR_VERSION >= 3 || defined(Py_UNICODE_WIDE)
/* Surrogate pair */
if ((c & 0xfc00) == 0xd800) {
JSON_UNICHR c2 = 0;
if (end + 6 < len &&
PyUnicode_READ(kind, buf, next) == '\\' &&
PyUnicode_READ(kind, buf, next + 1) == 'u') {
end += 6;
/* Decode 4 hex digits */
for (next += 2; next < end; next++) {
JSON_UNICHR digit = PyUnicode_READ(kind, buf, next);
c2 <<= 4;
switch (digit) {
case '0': case '1': case '2': case '3': case '4':
case '5': case '6': case '7': case '8': case '9':
c2 |= (digit - '0'); break;
case 'a': case 'b': case 'c': case 'd': case 'e':
case 'f':
c2 |= (digit - 'a' + 10); break;
case 'A': case 'B': case 'C': case 'D': case 'E':
case 'F':
c2 |= (digit - 'A' + 10); break;
default:
raise_errmsg(ERR_STRING_ESC4, pystr, end - 5);
goto bail;
}
}
if ((c2 & 0xfc00) != 0xdc00) {
/* not a low surrogate, rewind */
end -= 6;
next = end;
}
else {
c = 0x10000 + (((c - 0xd800) << 10) | (c2 - 0xdc00));
}
}
}
#endif
}
APPEND_OLD_CHUNK
chunk = PyUnicode_FromOrdinal(c);
if (chunk == NULL) {
goto bail;
}
}
if (chunks == NULL) {
if (chunk != NULL)
rval = chunk;
else {
rval = JSON_EmptyUnicode;
Py_INCREF(rval);
}
}
else {
APPEND_OLD_CHUNK
rval = join_list_unicode(chunks);
if (rval == NULL) {
goto bail;
}
Py_CLEAR(chunks);
}
*next_end_ptr = end;
return rval;
bail:
*next_end_ptr = -1;
Py_XDECREF(chunk);
Py_XDECREF(chunks);
return NULL;
}
PyDoc_STRVAR(pydoc_scanstring,
"scanstring(basestring, end, encoding, strict=True) -> (str, end)\n"
"\n"
"Scan the string s for a JSON string. End is the index of the\n"
"character in s after the quote that started the JSON string.\n"
"Unescapes all valid JSON string escape sequences and raises ValueError\n"
"on attempt to decode an invalid string. If strict is False then literal\n"
"control characters are allowed in the string.\n"
"\n"
"Returns a tuple of the decoded string and the index of the character in s\n"
"after the end quote."
);
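/* Illustrative call: scanstring('"foo"', 1) returns ('foo', 5): index 1
 * is the first character after the opening quote and 5 is the index just
 * past the closing quote. */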
static PyObject *
py_scanstring(PyObject* self UNUSED, PyObject *args)
{
PyObject *pystr;
PyObject *rval;
Py_ssize_t end;
Py_ssize_t next_end = -1;
char *encoding = NULL;
int strict = 1;
if (!PyArg_ParseTuple(args, "OO&|zi:scanstring", &pystr, _convertPyInt_AsSsize_t, &end, &encoding, &strict)) {
return NULL;
}
if (encoding == NULL) {
encoding = DEFAULT_ENCODING;
}
if (PyUnicode_Check(pystr)) {
if (PyUnicode_READY(pystr))
return NULL;
rval = scanstring_unicode(pystr, end, strict, &next_end);
}
#if PY_MAJOR_VERSION < 3
/* Using a bytes input is unsupported for scanning in Python 3.
It is coerced to str in the decoder before it gets here. */
else if (PyString_Check(pystr)) {
rval = scanstring_str(pystr, end, encoding, strict, &next_end);
}
#endif
else {
PyErr_Format(PyExc_TypeError,
"first argument must be a string, not %.80s",
Py_TYPE(pystr)->tp_name);
return NULL;
}
return _build_rval_index_tuple(rval, next_end);
}
PyDoc_STRVAR(pydoc_encode_basestring_ascii,
"encode_basestring_ascii(basestring) -> str\n"
"\n"
"Return an ASCII-only JSON representation of a Python string"
);
static PyObject *
py_encode_basestring_ascii(PyObject* self UNUSED, PyObject *pystr)
{
/* Return an ASCII-only JSON representation of a Python string */
/* METH_O */
if (PyBytes_Check(pystr)) {
return ascii_escape_str(pystr);
}
else if (PyUnicode_Check(pystr)) {
if (PyUnicode_READY(pystr))
return NULL;
return ascii_escape_unicode(pystr);
}
else {
PyErr_Format(PyExc_TypeError,
"first argument must be a string, not %.80s",
Py_TYPE(pystr)->tp_name);
return NULL;
}
}
static void
scanner_dealloc(PyObject *self)
{
/* bpo-31095: UnTrack is needed before calling any callbacks */
PyObject_GC_UnTrack(self);
scanner_clear(self);
Py_TYPE(self)->tp_free(self);
}
static int
scanner_traverse(PyObject *self, visitproc visit, void *arg)
{
PyScannerObject *s;
assert(PyScanner_Check(self));
s = (PyScannerObject *)self;
Py_VISIT(s->encoding);
Py_VISIT(s->strict_bool);
Py_VISIT(s->object_hook);
Py_VISIT(s->pairs_hook);
Py_VISIT(s->parse_float);
Py_VISIT(s->parse_int);
Py_VISIT(s->parse_constant);
Py_VISIT(s->memo);
return 0;
}
static int
scanner_clear(PyObject *self)
{
PyScannerObject *s;
assert(PyScanner_Check(self));
s = (PyScannerObject *)self;
Py_CLEAR(s->encoding);
Py_CLEAR(s->strict_bool);
Py_CLEAR(s->object_hook);
Py_CLEAR(s->pairs_hook);
Py_CLEAR(s->parse_float);
Py_CLEAR(s->parse_int);
Py_CLEAR(s->parse_constant);
Py_CLEAR(s->memo);
return 0;
}
#if PY_MAJOR_VERSION < 3
static PyObject *
_parse_object_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr)
{
/* Read a JSON object from PyString pystr.
idx is the index of the first character after the opening curly brace.
*next_idx_ptr is a return-by-reference index to the first character after
the closing curly brace.
Returns a new PyObject (usually a dict, but object_hook or
object_pairs_hook can change that)
*/
char *str = PyString_AS_STRING(pystr);
Py_ssize_t end_idx = PyString_GET_SIZE(pystr) - 1;
PyObject *rval = NULL;
PyObject *pairs = NULL;
PyObject *item;
PyObject *key = NULL;
PyObject *val = NULL;
char *encoding = PyString_AS_STRING(s->encoding);
int has_pairs_hook = (s->pairs_hook != Py_None);
int did_parse = 0;
Py_ssize_t next_idx;
if (has_pairs_hook) {
pairs = PyList_New(0);
if (pairs == NULL)
return NULL;
}
else {
rval = PyDict_New();
if (rval == NULL)
return NULL;
}
/* skip whitespace after { */
while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++;
/* only loop if the object is non-empty */
if (idx <= end_idx && str[idx] != '}') {
int trailing_delimiter = 0;
while (idx <= end_idx) {
PyObject *memokey;
trailing_delimiter = 0;
/* read key */
if (str[idx] != '"') {
raise_errmsg(ERR_OBJECT_PROPERTY, pystr, idx);
goto bail;
}
key = scanstring_str(pystr, idx + 1, encoding, s->strict, &next_idx);
if (key == NULL)
goto bail;
memokey = PyDict_GetItem(s->memo, key);
if (memokey != NULL) {
Py_INCREF(memokey);
Py_DECREF(key);
key = memokey;
}
else {
if (PyDict_SetItem(s->memo, key, key) < 0)
goto bail;
}
idx = next_idx;
/* skip whitespace between key and : delimiter, read :, skip whitespace */
while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++;
if (idx > end_idx || str[idx] != ':') {
raise_errmsg(ERR_OBJECT_PROPERTY_DELIMITER, pystr, idx);
goto bail;
}
idx++;
while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++;
/* read any JSON data type */
val = scan_once_str(s, pystr, idx, &next_idx);
if (val == NULL)
goto bail;
if (has_pairs_hook) {
item = PyTuple_Pack(2, key, val);
if (item == NULL)
goto bail;
Py_CLEAR(key);
Py_CLEAR(val);
if (PyList_Append(pairs, item) == -1) {
Py_DECREF(item);
goto bail;
}
Py_DECREF(item);
}
else {
if (PyDict_SetItem(rval, key, val) < 0)
goto bail;
Py_CLEAR(key);
Py_CLEAR(val);
}
idx = next_idx;
/* skip whitespace before } or , */
while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++;
/* bail if the object is closed or we didn't get the , delimiter */
did_parse = 1;
if (idx > end_idx) break;
if (str[idx] == '}') {
break;
}
else if (str[idx] != ',') {
raise_errmsg(ERR_OBJECT_DELIMITER, pystr, idx);
goto bail;
}
idx++;
/* skip whitespace after , delimiter */
while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++;
trailing_delimiter = 1;
}
if (trailing_delimiter) {
raise_errmsg(ERR_OBJECT_PROPERTY, pystr, idx);
goto bail;
}
}
/* verify that idx <= end_idx and str[idx] is '}' */
if (idx > end_idx || str[idx] != '}') {
if (did_parse) {
raise_errmsg(ERR_OBJECT_DELIMITER, pystr, idx);
} else {
raise_errmsg(ERR_OBJECT_PROPERTY_FIRST, pystr, idx);
}
goto bail;
}
/* if pairs_hook is not None: rval = object_pairs_hook(pairs) */
if (s->pairs_hook != Py_None) {
val = PyObject_CallFunctionObjArgs(s->pairs_hook, pairs, NULL);
if (val == NULL)
goto bail;
Py_DECREF(pairs);
*next_idx_ptr = idx + 1;
return val;
}
/* if object_hook is not None: rval = object_hook(rval) */
if (s->object_hook != Py_None) {
val = PyObject_CallFunctionObjArgs(s->object_hook, rval, NULL);
if (val == NULL)
goto bail;
Py_DECREF(rval);
rval = val;
val = NULL;
}
*next_idx_ptr = idx + 1;
return rval;
bail:
Py_XDECREF(rval);
Py_XDECREF(key);
Py_XDECREF(val);
Py_XDECREF(pairs);
return NULL;
}
#endif /* PY_MAJOR_VERSION < 3 */
static PyObject *
_parse_object_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr)
{
/* Read a JSON object from PyUnicode pystr.
idx is the index of the first character after the opening curly brace.
*next_idx_ptr is a return-by-reference index to the first character after
the closing curly brace.
Returns a new PyObject (usually a dict, but object_hook can change that)
*/
void *str = PyUnicode_DATA(pystr);
Py_ssize_t end_idx = PyUnicode_GET_LENGTH(pystr) - 1;
PY2_UNUSED int kind = PyUnicode_KIND(pystr);
PyObject *rval = NULL;
PyObject *pairs = NULL;
PyObject *item;
PyObject *key = NULL;
PyObject *val = NULL;
int has_pairs_hook = (s->pairs_hook != Py_None);
int did_parse = 0;
Py_ssize_t next_idx;
if (has_pairs_hook) {
pairs = PyList_New(0);
if (pairs == NULL)
return NULL;
}
else {
rval = PyDict_New();
if (rval == NULL)
return NULL;
}
/* skip whitespace after { */
while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++;
/* only loop if the object is non-empty */
if (idx <= end_idx && PyUnicode_READ(kind, str, idx) != '}') {
int trailing_delimiter = 0;
while (idx <= end_idx) {
PyObject *memokey;
trailing_delimiter = 0;
/* read key */
if (PyUnicode_READ(kind, str, idx) != '"') {
raise_errmsg(ERR_OBJECT_PROPERTY, pystr, idx);
goto bail;
}
key = scanstring_unicode(pystr, idx + 1, s->strict, &next_idx);
if (key == NULL)
goto bail;
memokey = PyDict_GetItem(s->memo, key);
if (memokey != NULL) {
Py_INCREF(memokey);
Py_DECREF(key);
key = memokey;
}
else {
if (PyDict_SetItem(s->memo, key, key) < 0)
goto bail;
}
idx = next_idx;
/* skip whitespace between key and : delimiter, read :, skip
whitespace */
while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++;
if (idx > end_idx || PyUnicode_READ(kind, str, idx) != ':') {
raise_errmsg(ERR_OBJECT_PROPERTY_DELIMITER, pystr, idx);
goto bail;
}
idx++;
while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++;
/* read any JSON term */
val = scan_once_unicode(s, pystr, idx, &next_idx);
if (val == NULL)
goto bail;
if (has_pairs_hook) {
item = PyTuple_Pack(2, key, val);
if (item == NULL)
goto bail;
Py_CLEAR(key);
Py_CLEAR(val);
if (PyList_Append(pairs, item) == -1) {
Py_DECREF(item);
goto bail;
}
Py_DECREF(item);
}
else {
if (PyDict_SetItem(rval, key, val) < 0)
goto bail;
Py_CLEAR(key);
Py_CLEAR(val);
}
idx = next_idx;
/* skip whitespace before } or , */
while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++;
/* bail if the object is closed or we didn't get the ,
delimiter */
did_parse = 1;
if (idx > end_idx) break;
if (PyUnicode_READ(kind, str, idx) == '}') {
break;
}
else if (PyUnicode_READ(kind, str, idx) != ',') {
raise_errmsg(ERR_OBJECT_DELIMITER, pystr, idx);
goto bail;
}
idx++;
/* skip whitespace after , delimiter */
while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++;
trailing_delimiter = 1;
}
if (trailing_delimiter) {
raise_errmsg(ERR_OBJECT_PROPERTY, pystr, idx);
goto bail;
}
}
/* verify that idx <= end_idx and str[idx] is '}' */
if (idx > end_idx || PyUnicode_READ(kind, str, idx) != '}') {
if (did_parse) {
raise_errmsg(ERR_OBJECT_DELIMITER, pystr, idx);
} else {
raise_errmsg(ERR_OBJECT_PROPERTY_FIRST, pystr, idx);
}
goto bail;
}
/* if pairs_hook is not None: rval = object_pairs_hook(pairs) */
if (s->pairs_hook != Py_None) {
val = PyObject_CallFunctionObjArgs(s->pairs_hook, pairs, NULL);
if (val == NULL)
goto bail;
Py_DECREF(pairs);
*next_idx_ptr = idx + 1;
return val;
}
/* if object_hook is not None: rval = object_hook(rval) */
if (s->object_hook != Py_None) {
val = PyObject_CallFunctionObjArgs(s->object_hook, rval, NULL);
if (val == NULL)
goto bail;
Py_DECREF(rval);
rval = val;
val = NULL;
}
*next_idx_ptr = idx + 1;
return rval;
bail:
Py_XDECREF(rval);
Py_XDECREF(key);
Py_XDECREF(val);
Py_XDECREF(pairs);
return NULL;
}
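/* Note on s->memo above (illustrative): object keys are interned per
 * decode call, so a key such as "id" repeated across many objects is
 * stored once and shared rather than allocated per occurrence. */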
#if PY_MAJOR_VERSION < 3
static PyObject *
_parse_array_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr)
{
/* Read a JSON array from PyString pystr.
idx is the index of the first character after the opening bracket.
*next_idx_ptr is a return-by-reference index to the first character after
the closing bracket.
Returns a new PyList
*/
char *str = PyString_AS_STRING(pystr);
Py_ssize_t end_idx = PyString_GET_SIZE(pystr) - 1;
PyObject *val = NULL;
PyObject *rval = PyList_New(0);
Py_ssize_t next_idx;
if (rval == NULL)
return NULL;
/* skip whitespace after [ */
while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++;
/* only loop if the array is non-empty */
if (idx <= end_idx && str[idx] != ']') {
int trailing_delimiter = 0;
while (idx <= end_idx) {
trailing_delimiter = 0;
/* read any JSON term and de-tuplefy the (rval, idx) */
val = scan_once_str(s, pystr, idx, &next_idx);
if (val == NULL) {
goto bail;
}
if (PyList_Append(rval, val) == -1)
goto bail;
Py_CLEAR(val);
idx = next_idx;
/* skip whitespace between term and , */
while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++;
/* bail if the array is closed or we didn't get the , delimiter */
if (idx > end_idx) break;
if (str[idx] == ']') {
break;
}
else if (str[idx] != ',') {
raise_errmsg(ERR_ARRAY_DELIMITER, pystr, idx);
goto bail;
}
idx++;
/* skip whitespace after , */
while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++;
trailing_delimiter = 1;
}
if (trailing_delimiter) {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
goto bail;
}
}
/* verify that idx <= end_idx and str[idx] is ']' */
if (idx > end_idx || str[idx] != ']') {
if (PyList_GET_SIZE(rval)) {
raise_errmsg(ERR_ARRAY_DELIMITER, pystr, idx);
} else {
raise_errmsg(ERR_ARRAY_VALUE_FIRST, pystr, idx);
}
goto bail;
}
*next_idx_ptr = idx + 1;
return rval;
bail:
Py_XDECREF(val);
Py_DECREF(rval);
return NULL;
}
#endif /* PY_MAJOR_VERSION < 3 */
static PyObject *
_parse_array_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr)
{
/* Read a JSON array from PyUnicode pystr.
idx is the index of the first character after the opening bracket.
*next_idx_ptr is a return-by-reference index to the first character after
the closing bracket.
Returns a new PyList
*/
PY2_UNUSED int kind = PyUnicode_KIND(pystr);
void *str = PyUnicode_DATA(pystr);
Py_ssize_t end_idx = PyUnicode_GET_LENGTH(pystr) - 1;
PyObject *val = NULL;
PyObject *rval = PyList_New(0);
Py_ssize_t next_idx;
if (rval == NULL)
return NULL;
/* skip whitespace after [ */
while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++;
/* only loop if the array is non-empty */
if (idx <= end_idx && PyUnicode_READ(kind, str, idx) != ']') {
int trailing_delimiter = 0;
while (idx <= end_idx) {
trailing_delimiter = 0;
/* read any JSON term */
val = scan_once_unicode(s, pystr, idx, &next_idx);
if (val == NULL) {
goto bail;
}
if (PyList_Append(rval, val) == -1)
goto bail;
Py_CLEAR(val);
idx = next_idx;
/* skip whitespace between term and , */
while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++;
/* bail if the array is closed or we didn't get the , delimiter */
if (idx > end_idx) break;
if (PyUnicode_READ(kind, str, idx) == ']') {
break;
}
else if (PyUnicode_READ(kind, str, idx) != ',') {
raise_errmsg(ERR_ARRAY_DELIMITER, pystr, idx);
goto bail;
}
idx++;
/* skip whitespace after , */
while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++;
trailing_delimiter = 1;
}
if (trailing_delimiter) {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
goto bail;
}
}
/* verify that idx <= end_idx and str[idx] is ']' */
if (idx > end_idx || PyUnicode_READ(kind, str, idx) != ']') {
if (PyList_GET_SIZE(rval)) {
raise_errmsg(ERR_ARRAY_DELIMITER, pystr, idx);
} else {
raise_errmsg(ERR_ARRAY_VALUE_FIRST, pystr, idx);
}
goto bail;
}
*next_idx_ptr = idx + 1;
return rval;
bail:
Py_XDECREF(val);
Py_DECREF(rval);
return NULL;
}
static PyObject *
_parse_constant(PyScannerObject *s, PyObject *constant, Py_ssize_t idx, Py_ssize_t *next_idx_ptr)
{
/* Handle a JSON constant.
constant is the Python string that was found
("NaN", "Infinity", "-Infinity").
idx is the index of the first character of the constant
*next_idx_ptr is a return-by-reference index to the first character after
the constant.
Returns the result of parse_constant
*/
PyObject *rval;
/* rval = parse_constant(constant) */
rval = PyObject_CallFunctionObjArgs(s->parse_constant, constant, NULL);
idx += PyString_GET_SIZE(constant);
*next_idx_ptr = idx;
return rval;
}
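/* Illustrative: when the scanner has matched the 8 characters of
 * "Infinity" at idx, this calls parse_constant(u"Infinity") (which the
 * default decoder maps to float('inf')) and advances *next_idx_ptr
 * past the constant. */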
#if PY_MAJOR_VERSION < 3
static PyObject *
_match_number_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t start, Py_ssize_t *next_idx_ptr)
{
/* Read a JSON number from PyString pystr.
idx is the index of the first character of the number
*next_idx_ptr is a return-by-reference index to the first character after
the number.
Returns a new PyObject representation of that number:
PyInt, PyLong, or PyFloat.
May return other types if parse_int or parse_float are set
*/
char *str = PyString_AS_STRING(pystr);
Py_ssize_t end_idx = PyString_GET_SIZE(pystr) - 1;
Py_ssize_t idx = start;
int is_float = 0;
PyObject *rval;
PyObject *numstr;
/* read a sign if it's there, make sure it's not the end of the string */
if (str[idx] == '-') {
if (idx >= end_idx) {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
return NULL;
}
idx++;
}
/* read as many integer digits as we find as long as it doesn't start with 0 */
if (str[idx] >= '1' && str[idx] <= '9') {
idx++;
while (idx <= end_idx && str[idx] >= '0' && str[idx] <= '9') idx++;
}
/* if it starts with 0 we only expect one integer digit */
else if (str[idx] == '0') {
idx++;
}
/* no integer digits, error */
else {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
return NULL;
}
/* if the next char is '.' followed by a digit then read all float digits */
if (idx < end_idx && str[idx] == '.' && str[idx + 1] >= '0' && str[idx + 1] <= '9') {
is_float = 1;
idx += 2;
while (idx <= end_idx && str[idx] >= '0' && str[idx] <= '9') idx++;
}
/* if the next char is 'e' or 'E' then maybe read the exponent (or backtrack) */
if (idx < end_idx && (str[idx] == 'e' || str[idx] == 'E')) {
/* save the index of the 'e' or 'E' just in case we need to backtrack */
Py_ssize_t e_start = idx;
idx++;
/* read an exponent sign if present */
if (idx < end_idx && (str[idx] == '-' || str[idx] == '+')) idx++;
/* read all digits */
while (idx <= end_idx && str[idx] >= '0' && str[idx] <= '9') idx++;
/* if we got a digit, then parse as float. if not, backtrack */
if (str[idx - 1] >= '0' && str[idx - 1] <= '9') {
is_float = 1;
}
else {
idx = e_start;
}
}
/* copy the section we determined to be a number */
numstr = PyString_FromStringAndSize(&str[start], idx - start);
if (numstr == NULL)
return NULL;
if (is_float) {
/* parse as a float using a fast path if available, otherwise call user defined method */
if (s->parse_float != (PyObject *)&PyFloat_Type) {
rval = PyObject_CallFunctionObjArgs(s->parse_float, numstr, NULL);
}
else {
/* rval = PyFloat_FromDouble(PyOS_ascii_atof(PyString_AS_STRING(numstr))); */
double d = PyOS_string_to_double(PyString_AS_STRING(numstr),
NULL, NULL);
if (d == -1.0 && PyErr_Occurred())
return NULL;
rval = PyFloat_FromDouble(d);
}
}
else {
/* parse as an int using a fast path if available, otherwise call user defined method */
if (s->parse_int != (PyObject *)&PyInt_Type) {
rval = PyObject_CallFunctionObjArgs(s->parse_int, numstr, NULL);
}
else {
rval = PyInt_FromString(PyString_AS_STRING(numstr), NULL, 10);
}
}
Py_DECREF(numstr);
*next_idx_ptr = idx;
return rval;
}
#endif /* PY_MAJOR_VERSION < 3 */
static PyObject *
_match_number_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t start, Py_ssize_t *next_idx_ptr)
{
/* Read a JSON number from PyUnicode pystr.
idx is the index of the first character of the number
*next_idx_ptr is a return-by-reference index to the first character after
the number.
Returns a new PyObject representation of that number:
PyInt, PyLong, or PyFloat.
May return other types if parse_int or parse_float are set
*/
PY2_UNUSED int kind = PyUnicode_KIND(pystr);
void *str = PyUnicode_DATA(pystr);
Py_ssize_t end_idx = PyUnicode_GET_LENGTH(pystr) - 1;
Py_ssize_t idx = start;
int is_float = 0;
JSON_UNICHR c;
PyObject *rval;
PyObject *numstr;
/* read a sign if it's there, make sure it's not the end of the string */
if (PyUnicode_READ(kind, str, idx) == '-') {
if (idx >= end_idx) {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
return NULL;
}
idx++;
}
/* read as many integer digits as we find as long as it doesn't start with 0 */
c = PyUnicode_READ(kind, str, idx);
if (c == '0') {
/* if it starts with 0 we only expect one integer digit */
idx++;
}
else if (IS_DIGIT(c)) {
idx++;
while (idx <= end_idx && IS_DIGIT(PyUnicode_READ(kind, str, idx))) {
idx++;
}
}
else {
/* no integer digits, error */
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
return NULL;
}
/* if the next char is '.' followed by a digit then read all float digits */
if (idx < end_idx &&
PyUnicode_READ(kind, str, idx) == '.' &&
IS_DIGIT(PyUnicode_READ(kind, str, idx + 1))) {
is_float = 1;
idx += 2;
while (idx <= end_idx && IS_DIGIT(PyUnicode_READ(kind, str, idx))) idx++;
}
/* if the next char is 'e' or 'E' then maybe read the exponent (or backtrack) */
if (idx < end_idx &&
(PyUnicode_READ(kind, str, idx) == 'e' ||
PyUnicode_READ(kind, str, idx) == 'E')) {
Py_ssize_t e_start = idx;
idx++;
/* read an exponent sign if present */
if (idx < end_idx &&
(PyUnicode_READ(kind, str, idx) == '-' ||
PyUnicode_READ(kind, str, idx) == '+')) idx++;
/* read all digits */
while (idx <= end_idx && IS_DIGIT(PyUnicode_READ(kind, str, idx))) idx++;
/* if we got a digit, then parse as float. if not, backtrack */
if (IS_DIGIT(PyUnicode_READ(kind, str, idx - 1))) {
is_float = 1;
}
else {
idx = e_start;
}
}
/* copy the section we determined to be a number */
#if PY_MAJOR_VERSION >= 3
numstr = PyUnicode_Substring(pystr, start, idx);
#else
numstr = PyUnicode_FromUnicode(&((Py_UNICODE *)str)[start], idx - start);
#endif
if (numstr == NULL)
return NULL;
if (is_float) {
/* parse as a float using a fast path if available, otherwise call user defined method */
if (s->parse_float != (PyObject *)&PyFloat_Type) {
rval = PyObject_CallFunctionObjArgs(s->parse_float, numstr, NULL);
}
else {
#if PY_MAJOR_VERSION >= 3
rval = PyFloat_FromString(numstr);
#else
rval = PyFloat_FromString(numstr, NULL);
#endif
}
}
else {
/* no fast path for unicode -> int, just call */
rval = PyObject_CallFunctionObjArgs(s->parse_int, numstr, NULL);
}
Py_DECREF(numstr);
*next_idx_ptr = idx;
return rval;
}
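/* Editor's illustration of the grammar above (values worked out by hand):
 * scanning "1.5e3," consumes the integer digit '1', the fraction ".5" and
 * the exponent "e3", so numstr is "1.5e3", is_float is 1 and *next_idx_ptr
 * lands on the ','. Scanning "1e+" finds no exponent digit, so idx
 * backtracks to e_start: the result is the integer 1 with *next_idx_ptr
 * left pointing at the 'e'. */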
#if PY_MAJOR_VERSION < 3
static PyObject *
scan_once_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr)
{
/* Read one JSON term (of any kind) from PyString pystr.
idx is the index of the first character of the term
*next_idx_ptr is a return-by-reference index to the first character after
the term.
Returns a new PyObject representation of the term.
*/
char *str = PyString_AS_STRING(pystr);
Py_ssize_t length = PyString_GET_SIZE(pystr);
PyObject *rval = NULL;
int fallthrough = 0;
if (idx < 0 || idx >= length) {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
return NULL;
}
switch (str[idx]) {
case '"':
/* string */
rval = scanstring_str(pystr, idx + 1,
PyString_AS_STRING(s->encoding),
s->strict,
next_idx_ptr);
break;
case '{':
/* object */
if (Py_EnterRecursiveCall(" while decoding a JSON object "
"from a string"))
return NULL;
rval = _parse_object_str(s, pystr, idx + 1, next_idx_ptr);
Py_LeaveRecursiveCall();
break;
case '[':
/* array */
if (Py_EnterRecursiveCall(" while decoding a JSON array "
"from a string"))
return NULL;
rval = _parse_array_str(s, pystr, idx + 1, next_idx_ptr);
Py_LeaveRecursiveCall();
break;
case 'n':
/* null */
if ((idx + 3 < length) && str[idx + 1] == 'u' && str[idx + 2] == 'l' && str[idx + 3] == 'l') {
Py_INCREF(Py_None);
*next_idx_ptr = idx + 4;
rval = Py_None;
}
else
fallthrough = 1;
break;
case 't':
/* true */
if ((idx + 3 < length) && str[idx + 1] == 'r' && str[idx + 2] == 'u' && str[idx + 3] == 'e') {
Py_INCREF(Py_True);
*next_idx_ptr = idx + 4;
rval = Py_True;
}
else
fallthrough = 1;
break;
case 'f':
/* false */
if ((idx + 4 < length) && str[idx + 1] == 'a' && str[idx + 2] == 'l' && str[idx + 3] == 's' && str[idx + 4] == 'e') {
Py_INCREF(Py_False);
*next_idx_ptr = idx + 5;
rval = Py_False;
}
else
fallthrough = 1;
break;
case 'N':
/* NaN */
if ((idx + 2 < length) && str[idx + 1] == 'a' && str[idx + 2] == 'N') {
rval = _parse_constant(s, JSON_NaN, idx, next_idx_ptr);
}
else
fallthrough = 1;
break;
case 'I':
/* Infinity */
if ((idx + 7 < length) && str[idx + 1] == 'n' && str[idx + 2] == 'f' && str[idx + 3] == 'i' && str[idx + 4] == 'n' && str[idx + 5] == 'i' && str[idx + 6] == 't' && str[idx + 7] == 'y') {
rval = _parse_constant(s, JSON_Infinity, idx, next_idx_ptr);
}
else
fallthrough = 1;
break;
case '-':
/* -Infinity */
if ((idx + 8 < length) && str[idx + 1] == 'I' && str[idx + 2] == 'n' && str[idx + 3] == 'f' && str[idx + 4] == 'i' && str[idx + 5] == 'n' && str[idx + 6] == 'i' && str[idx + 7] == 't' && str[idx + 8] == 'y') {
rval = _parse_constant(s, JSON_NegInfinity, idx, next_idx_ptr);
}
else
fallthrough = 1;
break;
default:
fallthrough = 1;
}
/* Didn't find a string, object, array, or named constant. Look for a number. */
if (fallthrough)
rval = _match_number_str(s, pystr, idx, next_idx_ptr);
return rval;
}
#endif /* PY_MAJOR_VERSION < 3 */
static PyObject *
scan_once_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr)
{
/* Read one JSON term (of any kind) from PyUnicode pystr.
idx is the index of the first character of the term
*next_idx_ptr is a return-by-reference index to the first character after
the term.
Returns a new PyObject representation of the term.
*/
PY2_UNUSED int kind = PyUnicode_KIND(pystr);
void *str = PyUnicode_DATA(pystr);
Py_ssize_t length = PyUnicode_GET_LENGTH(pystr);
PyObject *rval = NULL;
int fallthrough = 0;
if (idx < 0 || idx >= length) {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
return NULL;
}
switch (PyUnicode_READ(kind, str, idx)) {
case '"':
/* string */
rval = scanstring_unicode(pystr, idx + 1,
s->strict,
next_idx_ptr);
break;
case '{':
/* object */
if (Py_EnterRecursiveCall(" while decoding a JSON object "
"from a unicode string"))
return NULL;
rval = _parse_object_unicode(s, pystr, idx + 1, next_idx_ptr);
Py_LeaveRecursiveCall();
break;
case '[':
/* array */
if (Py_EnterRecursiveCall(" while decoding a JSON array "
"from a unicode string"))
return NULL;
rval = _parse_array_unicode(s, pystr, idx + 1, next_idx_ptr);
Py_LeaveRecursiveCall();
break;
case 'n':
/* null */
if ((idx + 3 < length) &&
PyUnicode_READ(kind, str, idx + 1) == 'u' &&
PyUnicode_READ(kind, str, idx + 2) == 'l' &&
PyUnicode_READ(kind, str, idx + 3) == 'l') {
Py_INCREF(Py_None);
*next_idx_ptr = idx + 4;
rval = Py_None;
}
else
fallthrough = 1;
break;
case 't':
/* true */
if ((idx + 3 < length) &&
PyUnicode_READ(kind, str, idx + 1) == 'r' &&
PyUnicode_READ(kind, str, idx + 2) == 'u' &&
PyUnicode_READ(kind, str, idx + 3) == 'e') {
Py_INCREF(Py_True);
*next_idx_ptr = idx + 4;
rval = Py_True;
}
else
fallthrough = 1;
break;
case 'f':
/* false */
if ((idx + 4 < length) &&
PyUnicode_READ(kind, str, idx + 1) == 'a' &&
PyUnicode_READ(kind, str, idx + 2) == 'l' &&
PyUnicode_READ(kind, str, idx + 3) == 's' &&
PyUnicode_READ(kind, str, idx + 4) == 'e') {
Py_INCREF(Py_False);
*next_idx_ptr = idx + 5;
rval = Py_False;
}
else
fallthrough = 1;
break;
case 'N':
/* NaN */
if ((idx + 2 < length) &&
PyUnicode_READ(kind, str, idx + 1) == 'a' &&
PyUnicode_READ(kind, str, idx + 2) == 'N') {
rval = _parse_constant(s, JSON_NaN, idx, next_idx_ptr);
}
else
fallthrough = 1;
break;
case 'I':
/* Infinity */
if ((idx + 7 < length) &&
PyUnicode_READ(kind, str, idx + 1) == 'n' &&
PyUnicode_READ(kind, str, idx + 2) == 'f' &&
PyUnicode_READ(kind, str, idx + 3) == 'i' &&
PyUnicode_READ(kind, str, idx + 4) == 'n' &&
PyUnicode_READ(kind, str, idx + 5) == 'i' &&
PyUnicode_READ(kind, str, idx + 6) == 't' &&
PyUnicode_READ(kind, str, idx + 7) == 'y') {
rval = _parse_constant(s, JSON_Infinity, idx, next_idx_ptr);
}
else
fallthrough = 1;
break;
case '-':
/* -Infinity */
if ((idx + 8 < length) &&
PyUnicode_READ(kind, str, idx + 1) == 'I' &&
PyUnicode_READ(kind, str, idx + 2) == 'n' &&
PyUnicode_READ(kind, str, idx + 3) == 'f' &&
PyUnicode_READ(kind, str, idx + 4) == 'i' &&
PyUnicode_READ(kind, str, idx + 5) == 'n' &&
PyUnicode_READ(kind, str, idx + 6) == 'i' &&
PyUnicode_READ(kind, str, idx + 7) == 't' &&
PyUnicode_READ(kind, str, idx + 8) == 'y') {
rval = _parse_constant(s, JSON_NegInfinity, idx, next_idx_ptr);
}
else
fallthrough = 1;
break;
default:
fallthrough = 1;
}
/* Didn't find a string, object, array, or named constant. Look for a number. */
if (fallthrough)
rval = _match_number_unicode(s, pystr, idx, next_idx_ptr);
return rval;
}
static PyObject *
scanner_call(PyObject *self, PyObject *args, PyObject *kwds)
{
/* Python callable interface to scan_once_{str,unicode} */
PyObject *pystr;
PyObject *rval;
Py_ssize_t idx;
Py_ssize_t next_idx = -1;
static char *kwlist[] = {"string", "idx", NULL};
PyScannerObject *s;
assert(PyScanner_Check(self));
s = (PyScannerObject *)self;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO&:scan_once", kwlist, &pystr, _convertPyInt_AsSsize_t, &idx))
return NULL;
if (PyUnicode_Check(pystr)) {
if (PyUnicode_READY(pystr))
return NULL;
rval = scan_once_unicode(s, pystr, idx, &next_idx);
}
#if PY_MAJOR_VERSION < 3
else if (PyString_Check(pystr)) {
rval = scan_once_str(s, pystr, idx, &next_idx);
}
#endif /* PY_MAJOR_VERSION < 3 */
else {
PyErr_Format(PyExc_TypeError,
"first argument must be a string, not %.80s",
Py_TYPE(pystr)->tp_name);
return NULL;
}
PyDict_Clear(s->memo);
return _build_rval_index_tuple(rval, next_idx);
}
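/* Editor's note: this Scanner is the callable that decoder.py binds as
 * self.scan_once; calling it returns the (obj, end_index) pair that
 * JSONObject/JSONArray below rely on, e.g. scan_once('[1, 2]', 0)
 * -> ([1, 2], 6). The memo dict is cleared after every top-level call. */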
static PyObject *
JSON_ParseEncoding(PyObject *encoding)
{
if (encoding == Py_None)
return JSON_InternFromString(DEFAULT_ENCODING);
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(encoding)) {
if (PyUnicode_AsUTF8(encoding) == NULL) {
return NULL;
}
Py_INCREF(encoding);
return encoding;
}
#else /* PY_MAJOR_VERSION >= 3 */
if (PyString_Check(encoding)) {
Py_INCREF(encoding);
return encoding;
}
if (PyUnicode_Check(encoding))
return PyUnicode_AsEncodedString(encoding, NULL, NULL);
#endif /* PY_MAJOR_VERSION >= 3 */
PyErr_SetString(PyExc_TypeError, "encoding must be a string");
return NULL;
}
static PyObject *
scanner_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
/* Initialize Scanner object */
PyObject *ctx;
static char *kwlist[] = {"context", NULL};
PyScannerObject *s;
PyObject *encoding;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:make_scanner", kwlist, &ctx))
return NULL;
s = (PyScannerObject *)type->tp_alloc(type, 0);
if (s == NULL)
return NULL;
if (s->memo == NULL) {
s->memo = PyDict_New();
if (s->memo == NULL)
goto bail;
}
encoding = PyObject_GetAttrString(ctx, "encoding");
if (encoding == NULL)
goto bail;
s->encoding = JSON_ParseEncoding(encoding);
Py_XDECREF(encoding);
if (s->encoding == NULL)
goto bail;
/* Each of the following lookups fails "gracefully" (NULL with an exception set), so we just bail on NULL */
s->strict_bool = PyObject_GetAttrString(ctx, "strict");
if (s->strict_bool == NULL)
goto bail;
s->strict = PyObject_IsTrue(s->strict_bool);
if (s->strict < 0)
goto bail;
s->object_hook = PyObject_GetAttrString(ctx, "object_hook");
if (s->object_hook == NULL)
goto bail;
s->pairs_hook = PyObject_GetAttrString(ctx, "object_pairs_hook");
if (s->pairs_hook == NULL)
goto bail;
s->parse_float = PyObject_GetAttrString(ctx, "parse_float");
if (s->parse_float == NULL)
goto bail;
s->parse_int = PyObject_GetAttrString(ctx, "parse_int");
if (s->parse_int == NULL)
goto bail;
s->parse_constant = PyObject_GetAttrString(ctx, "parse_constant");
if (s->parse_constant == NULL)
goto bail;
return (PyObject *)s;
bail:
Py_DECREF(s);
return NULL;
}
PyDoc_STRVAR(scanner_doc, "JSON scanner object");
static
PyTypeObject PyScannerType = {
PyVarObject_HEAD_INIT(NULL, 0)
"simplejson._speedups.Scanner", /* tp_name */
sizeof(PyScannerObject), /* tp_basicsize */
0, /* tp_itemsize */
scanner_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
scanner_call, /* tp_call */
0, /* tp_str */
0,/* PyObject_GenericGetAttr, */ /* tp_getattro */
0,/* PyObject_GenericSetAttr, */ /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
scanner_doc, /* tp_doc */
scanner_traverse, /* tp_traverse */
scanner_clear, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
scanner_members, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0,/* PyType_GenericAlloc, */ /* tp_alloc */
scanner_new, /* tp_new */
0,/* PyObject_GC_Del, */ /* tp_free */
};
static PyObject *
encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {
"markers",
"default",
"encoder",
"indent",
"key_separator",
"item_separator",
"sort_keys",
"skipkeys",
"allow_nan",
"key_memo",
"use_decimal",
"namedtuple_as_object",
"tuple_as_array",
"int_as_string_bitcount",
"item_sort_key",
"encoding",
"for_json",
"ignore_nan",
"Decimal",
"iterable_as_array",
NULL};
PyEncoderObject *s;
PyObject *markers, *defaultfn, *encoder, *indent, *key_separator;
PyObject *item_separator, *sort_keys, *skipkeys, *allow_nan, *key_memo;
PyObject *use_decimal, *namedtuple_as_object, *tuple_as_array, *iterable_as_array;
PyObject *int_as_string_bitcount, *item_sort_key, *encoding, *for_json;
PyObject *ignore_nan, *Decimal;
int is_true;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOOOOOOO:make_encoder", kwlist,
&markers, &defaultfn, &encoder, &indent, &key_separator, &item_separator,
&sort_keys, &skipkeys, &allow_nan, &key_memo, &use_decimal,
&namedtuple_as_object, &tuple_as_array,
&int_as_string_bitcount, &item_sort_key, &encoding, &for_json,
&ignore_nan, &Decimal, &iterable_as_array))
return NULL;
s = (PyEncoderObject *)type->tp_alloc(type, 0);
if (s == NULL)
return NULL;
Py_INCREF(markers);
s->markers = markers;
Py_INCREF(defaultfn);
s->defaultfn = defaultfn;
Py_INCREF(encoder);
s->encoder = encoder;
#if PY_MAJOR_VERSION >= 3
if (encoding == Py_None) {
s->encoding = NULL;
}
else
#endif /* PY_MAJOR_VERSION >= 3 */
{
s->encoding = JSON_ParseEncoding(encoding);
if (s->encoding == NULL)
goto bail;
}
Py_INCREF(indent);
s->indent = indent;
Py_INCREF(key_separator);
s->key_separator = key_separator;
Py_INCREF(item_separator);
s->item_separator = item_separator;
Py_INCREF(skipkeys);
s->skipkeys_bool = skipkeys;
s->skipkeys = PyObject_IsTrue(skipkeys);
if (s->skipkeys < 0)
goto bail;
Py_INCREF(key_memo);
s->key_memo = key_memo;
s->fast_encode = (PyCFunction_Check(s->encoder) && PyCFunction_GetFunction(s->encoder) == (PyCFunction)py_encode_basestring_ascii);
is_true = PyObject_IsTrue(ignore_nan);
if (is_true < 0)
goto bail;
s->allow_or_ignore_nan = is_true ? JSON_IGNORE_NAN : 0;
is_true = PyObject_IsTrue(allow_nan);
if (is_true < 0)
goto bail;
s->allow_or_ignore_nan |= is_true ? JSON_ALLOW_NAN : 0;
s->use_decimal = PyObject_IsTrue(use_decimal);
if (s->use_decimal < 0)
goto bail;
s->namedtuple_as_object = PyObject_IsTrue(namedtuple_as_object);
if (s->namedtuple_as_object < 0)
goto bail;
s->tuple_as_array = PyObject_IsTrue(tuple_as_array);
if (s->tuple_as_array < 0)
goto bail;
s->iterable_as_array = PyObject_IsTrue(iterable_as_array);
if (s->iterable_as_array < 0)
goto bail;
if (PyInt_Check(int_as_string_bitcount) || PyLong_Check(int_as_string_bitcount)) {
static const unsigned long long_long_bitsize = SIZEOF_LONG_LONG * 8;
long int_as_string_bitcount_val = PyLong_AsLong(int_as_string_bitcount);
if (int_as_string_bitcount_val > 0 && int_as_string_bitcount_val < (long)long_long_bitsize) {
s->max_long_size = PyLong_FromUnsignedLongLong(1ULL << (int)int_as_string_bitcount_val);
s->min_long_size = PyLong_FromLongLong(-1LL << (int)int_as_string_bitcount_val);
if (s->min_long_size == NULL || s->max_long_size == NULL) {
goto bail;
}
}
else {
PyErr_Format(PyExc_TypeError,
"int_as_string_bitcount (%ld) must be greater than 0 and less than the number of bits of a `long long` type (%lu bits)",
int_as_string_bitcount_val, long_long_bitsize);
goto bail;
}
}
else if (int_as_string_bitcount == Py_None) {
Py_INCREF(Py_None);
s->max_long_size = Py_None;
Py_INCREF(Py_None);
s->min_long_size = Py_None;
}
else {
PyErr_SetString(PyExc_TypeError, "int_as_string_bitcount must be None or an integer");
goto bail;
}
if (item_sort_key != Py_None) {
if (!PyCallable_Check(item_sort_key)) {
PyErr_SetString(PyExc_TypeError, "item_sort_key must be None or callable");
goto bail;
}
}
else {
is_true = PyObject_IsTrue(sort_keys);
if (is_true < 0)
goto bail;
if (is_true) {
static PyObject *itemgetter0 = NULL;
if (!itemgetter0) {
PyObject *operator = PyImport_ImportModule("operator");
if (!operator)
goto bail;
itemgetter0 = PyObject_CallMethod(operator, "itemgetter", "i", 0);
Py_DECREF(operator);
}
item_sort_key = itemgetter0;
if (!item_sort_key)
goto bail;
}
}
if (item_sort_key == Py_None) {
Py_INCREF(Py_None);
s->item_sort_kw = Py_None;
}
else {
s->item_sort_kw = PyDict_New();
if (s->item_sort_kw == NULL)
goto bail;
if (PyDict_SetItemString(s->item_sort_kw, "key", item_sort_key))
goto bail;
}
Py_INCREF(sort_keys);
s->sort_keys = sort_keys;
Py_INCREF(item_sort_key);
s->item_sort_key = item_sort_key;
Py_INCREF(Decimal);
s->Decimal = Decimal;
s->for_json = PyObject_IsTrue(for_json);
if (s->for_json < 0)
goto bail;
return (PyObject *)s;
bail:
Py_DECREF(s);
return NULL;
}
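/* Editor's note on the bitcount validation above: with
 * int_as_string_bitcount = 53, max_long_size is 2**53 and min_long_size
 * is -2**53, so maybe_quote_bigint() renders ints outside the open
 * interval (-2**53, 2**53) as JSON strings, the same cutoff that
 * bigint_as_string uses to sidestep precision loss in JavaScript. */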
static PyObject *
encoder_call(PyObject *self, PyObject *args, PyObject *kwds)
{
/* Python callable interface to encode_listencode_obj */
static char *kwlist[] = {"obj", "_current_indent_level", NULL};
PyObject *obj;
Py_ssize_t indent_level;
PyEncoderObject *s;
JSON_Accu rval;
assert(PyEncoder_Check(self));
s = (PyEncoderObject *)self;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO&:_iterencode", kwlist,
&obj, _convertPyInt_AsSsize_t, &indent_level))
return NULL;
if (JSON_Accu_Init(&rval))
return NULL;
if (encoder_listencode_obj(s, &rval, obj, indent_level)) {
JSON_Accu_Destroy(&rval);
return NULL;
}
return JSON_Accu_FinishAsList(&rval);
}
static PyObject *
_encoded_const(PyObject *obj)
{
/* Return the JSON string representation of None, True, False */
if (obj == Py_None) {
static PyObject *s_null = NULL;
if (s_null == NULL) {
s_null = JSON_InternFromString("null");
}
Py_INCREF(s_null);
return s_null;
}
else if (obj == Py_True) {
static PyObject *s_true = NULL;
if (s_true == NULL) {
s_true = JSON_InternFromString("true");
}
Py_INCREF(s_true);
return s_true;
}
else if (obj == Py_False) {
static PyObject *s_false = NULL;
if (s_false == NULL) {
s_false = JSON_InternFromString("false");
}
Py_INCREF(s_false);
return s_false;
}
else {
PyErr_SetString(PyExc_ValueError, "not a const");
return NULL;
}
}
static PyObject *
encoder_encode_float(PyEncoderObject *s, PyObject *obj)
{
/* Return the JSON representation of a PyFloat */
double i = PyFloat_AS_DOUBLE(obj);
if (!Py_IS_FINITE(i)) {
if (!s->allow_or_ignore_nan) {
PyErr_SetString(PyExc_ValueError, "Out of range float values are not JSON compliant");
return NULL;
}
if (s->allow_or_ignore_nan & JSON_IGNORE_NAN) {
return _encoded_const(Py_None);
}
/* JSON_ALLOW_NAN is set */
else if (i > 0) {
Py_INCREF(JSON_Infinity);
return JSON_Infinity;
}
else if (i < 0) {
Py_INCREF(JSON_NegInfinity);
return JSON_NegInfinity;
}
else {
Py_INCREF(JSON_NaN);
return JSON_NaN;
}
}
/* Use a better float format here? */
if (PyFloat_CheckExact(obj)) {
return PyObject_Repr(obj);
}
else {
/* See #118, do not trust custom str/repr */
PyObject *res;
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyFloat_Type, obj, NULL);
if (tmp == NULL) {
return NULL;
}
res = PyObject_Repr(tmp);
Py_DECREF(tmp);
return res;
}
}
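/* Editor's summary of the non-finite branch above: float('nan') encodes
 * as null when JSON_IGNORE_NAN is set, as NaN/Infinity/-Infinity when
 * only JSON_ALLOW_NAN is set, and raises ValueError when neither flag
 * is set. */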
static PyObject *
encoder_encode_string(PyEncoderObject *s, PyObject *obj)
{
/* Return the JSON representation of a string */
PyObject *encoded;
if (s->fast_encode) {
return py_encode_basestring_ascii(NULL, obj);
}
encoded = PyObject_CallFunctionObjArgs(s->encoder, obj, NULL);
if (encoded != NULL &&
#if PY_MAJOR_VERSION < 3
!PyString_Check(encoded) &&
#endif /* PY_MAJOR_VERSION < 3 */
!PyUnicode_Check(encoded))
{
PyErr_Format(PyExc_TypeError,
"encoder() must return a string, not %.80s",
Py_TYPE(encoded)->tp_name);
Py_DECREF(encoded);
return NULL;
}
return encoded;
}
static int
_steal_accumulate(JSON_Accu *accu, PyObject *stolen)
{
/* Append stolen and then decrement its reference count */
int rval = JSON_Accu_Accumulate(accu, stolen);
Py_DECREF(stolen);
return rval;
}
static int
encoder_listencode_obj(PyEncoderObject *s, JSON_Accu *rval, PyObject *obj, Py_ssize_t indent_level)
{
/* Encode Python object obj to a JSON term, rval is a PyList */
int rv = -1;
do {
if (obj == Py_None || obj == Py_True || obj == Py_False) {
PyObject *cstr = _encoded_const(obj);
if (cstr != NULL)
rv = _steal_accumulate(rval, cstr);
}
else if ((PyBytes_Check(obj) && s->encoding != NULL) ||
PyUnicode_Check(obj))
{
PyObject *encoded = encoder_encode_string(s, obj);
if (encoded != NULL)
rv = _steal_accumulate(rval, encoded);
}
else if (PyInt_Check(obj) || PyLong_Check(obj)) {
PyObject *encoded;
if (PyInt_CheckExact(obj) || PyLong_CheckExact(obj)) {
encoded = PyObject_Str(obj);
}
else {
/* See #118, do not trust custom str/repr */
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyLong_Type, obj, NULL);
if (tmp == NULL) {
encoded = NULL;
}
else {
encoded = PyObject_Str(tmp);
Py_DECREF(tmp);
}
}
if (encoded != NULL) {
encoded = maybe_quote_bigint(s, encoded, obj);
if (encoded == NULL)
break;
rv = _steal_accumulate(rval, encoded);
}
}
else if (PyFloat_Check(obj)) {
PyObject *encoded = encoder_encode_float(s, obj);
if (encoded != NULL)
rv = _steal_accumulate(rval, encoded);
}
else if (s->for_json && _has_for_json_hook(obj)) {
PyObject *newobj;
if (Py_EnterRecursiveCall(" while encoding a JSON object"))
return rv;
newobj = PyObject_CallMethod(obj, "for_json", NULL);
if (newobj != NULL) {
rv = encoder_listencode_obj(s, rval, newobj, indent_level);
Py_DECREF(newobj);
}
Py_LeaveRecursiveCall();
}
else if (s->namedtuple_as_object && _is_namedtuple(obj)) {
PyObject *newobj;
if (Py_EnterRecursiveCall(" while encoding a JSON object"))
return rv;
newobj = PyObject_CallMethod(obj, "_asdict", NULL);
if (newobj != NULL) {
rv = encoder_listencode_dict(s, rval, newobj, indent_level);
Py_DECREF(newobj);
}
Py_LeaveRecursiveCall();
}
else if (PyList_Check(obj) || (s->tuple_as_array && PyTuple_Check(obj))) {
if (Py_EnterRecursiveCall(" while encoding a JSON object"))
return rv;
rv = encoder_listencode_list(s, rval, obj, indent_level);
Py_LeaveRecursiveCall();
}
else if (PyDict_Check(obj)) {
if (Py_EnterRecursiveCall(" while encoding a JSON object"))
return rv;
rv = encoder_listencode_dict(s, rval, obj, indent_level);
Py_LeaveRecursiveCall();
}
else if (s->use_decimal && PyObject_TypeCheck(obj, (PyTypeObject *)s->Decimal)) {
PyObject *encoded = PyObject_Str(obj);
if (encoded != NULL)
rv = _steal_accumulate(rval, encoded);
}
else if (is_raw_json(obj))
{
PyObject *encoded = PyObject_GetAttrString(obj, "encoded_json");
if (encoded != NULL)
rv = _steal_accumulate(rval, encoded);
}
else {
PyObject *ident = NULL;
PyObject *newobj;
if (s->iterable_as_array) {
newobj = PyObject_GetIter(obj);
if (newobj == NULL)
PyErr_Clear();
else {
rv = encoder_listencode_list(s, rval, newobj, indent_level);
Py_DECREF(newobj);
break;
}
}
if (s->markers != Py_None) {
int has_key;
ident = PyLong_FromVoidPtr(obj);
if (ident == NULL)
break;
has_key = PyDict_Contains(s->markers, ident);
if (has_key) {
if (has_key != -1)
PyErr_SetString(PyExc_ValueError, "Circular reference detected");
Py_DECREF(ident);
break;
}
if (PyDict_SetItem(s->markers, ident, obj)) {
Py_DECREF(ident);
break;
}
}
if (Py_EnterRecursiveCall(" while encoding a JSON object"))
return rv;
newobj = PyObject_CallFunctionObjArgs(s->defaultfn, obj, NULL);
if (newobj == NULL) {
Py_XDECREF(ident);
Py_LeaveRecursiveCall();
break;
}
rv = encoder_listencode_obj(s, rval, newobj, indent_level);
Py_LeaveRecursiveCall();
Py_DECREF(newobj);
if (rv) {
Py_XDECREF(ident);
rv = -1;
}
else if (ident != NULL) {
if (PyDict_DelItem(s->markers, ident)) {
Py_XDECREF(ident);
rv = -1;
}
Py_XDECREF(ident);
}
}
} while (0);
return rv;
}
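/* Editor's summary of the dispatch above: constants, strings, ints,
 * floats, for_json()/namedtuple hooks, lists/tuples, dicts, Decimal and
 * RawJSON are encoded inline; anything else is passed to s->defaultfn,
 * with s->markers guarding against circular references. */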
static int
encoder_listencode_dict(PyEncoderObject *s, JSON_Accu *rval, PyObject *dct, Py_ssize_t indent_level)
{
/* Encode Python dict dct to a JSON term */
static PyObject *open_dict = NULL;
static PyObject *close_dict = NULL;
static PyObject *empty_dict = NULL;
PyObject *kstr = NULL;
PyObject *ident = NULL;
PyObject *iter = NULL;
PyObject *item = NULL;
PyObject *items = NULL;
PyObject *encoded = NULL;
Py_ssize_t idx;
if (open_dict == NULL || close_dict == NULL || empty_dict == NULL) {
open_dict = JSON_InternFromString("{");
close_dict = JSON_InternFromString("}");
empty_dict = JSON_InternFromString("{}");
if (open_dict == NULL || close_dict == NULL || empty_dict == NULL)
return -1;
}
if (PyDict_Size(dct) == 0)
return JSON_Accu_Accumulate(rval, empty_dict);
if (s->markers != Py_None) {
int has_key;
ident = PyLong_FromVoidPtr(dct);
if (ident == NULL)
goto bail;
has_key = PyDict_Contains(s->markers, ident);
if (has_key) {
if (has_key != -1)
PyErr_SetString(PyExc_ValueError, "Circular reference detected");
goto bail;
}
if (PyDict_SetItem(s->markers, ident, dct)) {
goto bail;
}
}
if (JSON_Accu_Accumulate(rval, open_dict))
goto bail;
if (s->indent != Py_None) {
/* TODO: DOES NOT RUN */
indent_level += 1;
/*
newline_indent = '\n' + (_indent * _current_indent_level)
separator = _item_separator + newline_indent
buf += newline_indent
*/
}
iter = encoder_dict_iteritems(s, dct);
if (iter == NULL)
goto bail;
idx = 0;
while ((item = PyIter_Next(iter))) {
PyObject *encoded, *key, *value;
if (!PyTuple_Check(item) || Py_SIZE(item) != 2) {
PyErr_SetString(PyExc_ValueError, "items must return 2-tuples");
goto bail;
}
key = PyTuple_GET_ITEM(item, 0);
if (key == NULL)
goto bail;
value = PyTuple_GET_ITEM(item, 1);
if (value == NULL)
goto bail;
encoded = PyDict_GetItem(s->key_memo, key);
if (encoded != NULL) {
Py_INCREF(encoded);
} else {
kstr = encoder_stringify_key(s, key);
if (kstr == NULL)
goto bail;
else if (kstr == Py_None) {
/* skipkeys */
Py_DECREF(item);
Py_DECREF(kstr);
continue;
}
}
if (idx) {
if (JSON_Accu_Accumulate(rval, s->item_separator))
goto bail;
}
if (encoded == NULL) {
encoded = encoder_encode_string(s, kstr);
Py_CLEAR(kstr);
if (encoded == NULL)
goto bail;
if (PyDict_SetItem(s->key_memo, key, encoded))
goto bail;
}
if (JSON_Accu_Accumulate(rval, encoded)) {
goto bail;
}
Py_CLEAR(encoded);
if (JSON_Accu_Accumulate(rval, s->key_separator))
goto bail;
if (encoder_listencode_obj(s, rval, value, indent_level))
goto bail;
Py_CLEAR(item);
idx += 1;
}
Py_CLEAR(iter);
if (PyErr_Occurred())
goto bail;
if (ident != NULL) {
if (PyDict_DelItem(s->markers, ident))
goto bail;
Py_CLEAR(ident);
}
if (s->indent != Py_None) {
/* TODO: DOES NOT RUN */
indent_level -= 1;
/*
yield '\n' + (_indent * _current_indent_level)
*/
}
if (JSON_Accu_Accumulate(rval, close_dict))
goto bail;
return 0;
bail:
Py_XDECREF(encoded);
Py_XDECREF(items);
Py_XDECREF(item);
Py_XDECREF(iter);
Py_XDECREF(kstr);
Py_XDECREF(ident);
return -1;
}
static int
encoder_listencode_list(PyEncoderObject *s, JSON_Accu *rval, PyObject *seq, Py_ssize_t indent_level)
{
/* Encode Python list seq to a JSON term */
static PyObject *open_array = NULL;
static PyObject *close_array = NULL;
static PyObject *empty_array = NULL;
PyObject *ident = NULL;
PyObject *iter = NULL;
PyObject *obj = NULL;
int is_true;
int i = 0;
if (open_array == NULL || close_array == NULL || empty_array == NULL) {
open_array = JSON_InternFromString("[");
close_array = JSON_InternFromString("]");
empty_array = JSON_InternFromString("[]");
if (open_array == NULL || close_array == NULL || empty_array == NULL)
return -1;
}
ident = NULL;
is_true = PyObject_IsTrue(seq);
if (is_true == -1)
return -1;
else if (is_true == 0)
return JSON_Accu_Accumulate(rval, empty_array);
if (s->markers != Py_None) {
int has_key;
ident = PyLong_FromVoidPtr(seq);
if (ident == NULL)
goto bail;
has_key = PyDict_Contains(s->markers, ident);
if (has_key) {
if (has_key != -1)
PyErr_SetString(PyExc_ValueError, "Circular reference detected");
goto bail;
}
if (PyDict_SetItem(s->markers, ident, seq)) {
goto bail;
}
}
iter = PyObject_GetIter(seq);
if (iter == NULL)
goto bail;
if (JSON_Accu_Accumulate(rval, open_array))
goto bail;
if (s->indent != Py_None) {
/* TODO: DOES NOT RUN */
indent_level += 1;
/*
newline_indent = '\n' + (_indent * _current_indent_level)
separator = _item_separator + newline_indent
buf += newline_indent
*/
}
while ((obj = PyIter_Next(iter))) {
if (i) {
if (JSON_Accu_Accumulate(rval, s->item_separator))
goto bail;
}
if (encoder_listencode_obj(s, rval, obj, indent_level))
goto bail;
i++;
Py_CLEAR(obj);
}
Py_CLEAR(iter);
if (PyErr_Occurred())
goto bail;
if (ident != NULL) {
if (PyDict_DelItem(s->markers, ident))
goto bail;
Py_CLEAR(ident);
}
if (s->indent != Py_None) {
/* TODO: DOES NOT RUN */
indent_level -= 1;
/*
yield '\n' + (_indent * _current_indent_level)
*/
}
if (JSON_Accu_Accumulate(rval, close_array))
goto bail;
return 0;
bail:
Py_XDECREF(obj);
Py_XDECREF(iter);
Py_XDECREF(ident);
return -1;
}
static void
encoder_dealloc(PyObject *self)
{
/* bpo-31095: UnTrack is needed before calling any callbacks */
PyObject_GC_UnTrack(self);
encoder_clear(self);
Py_TYPE(self)->tp_free(self);
}
static int
encoder_traverse(PyObject *self, visitproc visit, void *arg)
{
PyEncoderObject *s;
assert(PyEncoder_Check(self));
s = (PyEncoderObject *)self;
Py_VISIT(s->markers);
Py_VISIT(s->defaultfn);
Py_VISIT(s->encoder);
Py_VISIT(s->encoding);
Py_VISIT(s->indent);
Py_VISIT(s->key_separator);
Py_VISIT(s->item_separator);
Py_VISIT(s->key_memo);
Py_VISIT(s->sort_keys);
Py_VISIT(s->item_sort_kw);
Py_VISIT(s->item_sort_key);
Py_VISIT(s->max_long_size);
Py_VISIT(s->min_long_size);
Py_VISIT(s->Decimal);
return 0;
}
static int
encoder_clear(PyObject *self)
{
/* Deallocate Encoder */
PyEncoderObject *s;
assert(PyEncoder_Check(self));
s = (PyEncoderObject *)self;
Py_CLEAR(s->markers);
Py_CLEAR(s->defaultfn);
Py_CLEAR(s->encoder);
Py_CLEAR(s->encoding);
Py_CLEAR(s->indent);
Py_CLEAR(s->key_separator);
Py_CLEAR(s->item_separator);
Py_CLEAR(s->key_memo);
Py_CLEAR(s->skipkeys_bool);
Py_CLEAR(s->sort_keys);
Py_CLEAR(s->item_sort_kw);
Py_CLEAR(s->item_sort_key);
Py_CLEAR(s->max_long_size);
Py_CLEAR(s->min_long_size);
Py_CLEAR(s->Decimal);
return 0;
}
PyDoc_STRVAR(encoder_doc, "_iterencode(obj, _current_indent_level) -> iterable");
static
PyTypeObject PyEncoderType = {
PyVarObject_HEAD_INIT(NULL, 0)
"simplejson._speedups.Encoder", /* tp_name */
sizeof(PyEncoderObject), /* tp_basicsize */
0, /* tp_itemsize */
encoder_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
encoder_call, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
encoder_doc, /* tp_doc */
encoder_traverse, /* tp_traverse */
encoder_clear, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
encoder_members, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
encoder_new, /* tp_new */
0, /* tp_free */
};
static PyMethodDef speedups_methods[] = {
{"encode_basestring_ascii",
(PyCFunction)py_encode_basestring_ascii,
METH_O,
pydoc_encode_basestring_ascii},
{"scanstring",
(PyCFunction)py_scanstring,
METH_VARARGS,
pydoc_scanstring},
{NULL, NULL, 0, NULL}
};
PyDoc_STRVAR(module_doc,
"simplejson speedups\n");
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"_speedups", /* m_name */
module_doc, /* m_doc */
-1, /* m_size */
speedups_methods, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear*/
NULL, /* m_free */
};
#endif
PyObject *
import_dependency(char *module_name, char *attr_name)
{
PyObject *rval;
PyObject *module = PyImport_ImportModule(module_name);
if (module == NULL)
return NULL;
rval = PyObject_GetAttrString(module, attr_name);
Py_DECREF(module);
return rval;
}
static int
init_constants(void)
{
JSON_NaN = JSON_InternFromString("NaN");
if (JSON_NaN == NULL)
return 0;
JSON_Infinity = JSON_InternFromString("Infinity");
if (JSON_Infinity == NULL)
return 0;
JSON_NegInfinity = JSON_InternFromString("-Infinity");
if (JSON_NegInfinity == NULL)
return 0;
#if PY_MAJOR_VERSION >= 3
JSON_EmptyUnicode = PyUnicode_New(0, 127);
#else /* PY_MAJOR_VERSION >= 3 */
JSON_EmptyStr = PyString_FromString("");
if (JSON_EmptyStr == NULL)
return 0;
JSON_EmptyUnicode = PyUnicode_FromUnicode(NULL, 0);
#endif /* PY_MAJOR_VERSION >= 3 */
if (JSON_EmptyUnicode == NULL)
return 0;
return 1;
}
static PyObject *
moduleinit(void)
{
PyObject *m;
if (PyType_Ready(&PyScannerType) < 0)
return NULL;
if (PyType_Ready(&PyEncoderType) < 0)
return NULL;
if (!init_constants())
return NULL;
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&moduledef);
#else
m = Py_InitModule3("_speedups", speedups_methods, module_doc);
#endif
Py_INCREF((PyObject*)&PyScannerType);
PyModule_AddObject(m, "make_scanner", (PyObject*)&PyScannerType);
Py_INCREF((PyObject*)&PyEncoderType);
PyModule_AddObject(m, "make_encoder", (PyObject*)&PyEncoderType);
RawJSONType = import_dependency("simplejson.raw_json", "RawJSON");
if (RawJSONType == NULL)
return NULL;
JSONDecodeError = import_dependency("simplejson.errors", "JSONDecodeError");
if (JSONDecodeError == NULL)
return NULL;
return m;
}
#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC
PyInit__speedups(void)
{
return moduleinit();
}
#else
void
init_speedups(void)
{
moduleinit();
}
#endif
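/* Editor's sketch of how the module above is exercised through the public
 * simplejson API (the pure-Python fallbacks in decoder.py and encoder.py
 * below take over when this extension fails to import):
 *
 *   import decimal
 *   import simplejson as json
 *   # parse_float reaches the Scanner built by scanner_new() above
 *   json.loads('{"pi": 3.14159}', parse_float=decimal.Decimal)
 *   # -> {'pi': Decimal('3.14159')}
 *   # int_as_string_bitcount feeds the validation in encoder_new()
 *   json.dumps(2 ** 60, int_as_string_bitcount=53)
 *   # -> '"1152921504606846976"'
 */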
================================================
FILE: simplejson/compat.py
================================================
"""Python 3 compatibility shims
"""
import sys
if sys.version_info[0] < 3:
PY3 = False
def b(s):
return s
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
BytesIO = StringIO
text_type = unicode
binary_type = str
string_types = (basestring,)
integer_types = (int, long)
unichr = unichr
reload_module = reload
else:
PY3 = True
if sys.version_info[:2] >= (3, 4):
from importlib import reload as reload_module
else:
from imp import reload as reload_module
def b(s):
return bytes(s, 'latin1')
from io import StringIO, BytesIO
text_type = str
binary_type = bytes
string_types = (str,)
integer_types = (int,)
unichr = chr
long_type = integer_types[-1]
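# Editor's sketch: the shims above let version-agnostic code be written once.
#   from simplejson.compat import b, text_type, string_types
#   raw = b('{"key": "value"}')           # bytes on Python 3, str on Python 2
#   assert isinstance(u'abc', text_type)  # unicode on 2, str on 3
#   assert isinstance('abc', string_types)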
================================================
FILE: simplejson/decoder.py
================================================
"""Implementation of JSONDecoder
"""
from __future__ import absolute_import
import re
import sys
import struct
from .compat import PY3, unichr
from .scanner import make_scanner, JSONDecodeError
def _import_c_scanstring():
try:
from ._speedups import scanstring
return scanstring
except ImportError:
return None
c_scanstring = _import_c_scanstring()
# NOTE (3.1.0): JSONDecodeError may still be imported from this module for
# compatibility, but it was never in the __all__
__all__ = ['JSONDecoder']
FLAGS = re.VERBOSE | re.MULTILINE | re.DOTALL
def _floatconstants():
if sys.version_info < (2, 6):
_BYTES = '7FF80000000000007FF0000000000000'.decode('hex')
nan, inf = struct.unpack('>dd', _BYTES)
else:
nan = float('nan')
inf = float('inf')
return nan, inf, -inf
NaN, PosInf, NegInf = _floatconstants()
_CONSTANTS = {
'-Infinity': NegInf,
'Infinity': PosInf,
'NaN': NaN,
}
STRINGCHUNK = re.compile(r'(.*?)(["\\\x00-\x1f])', FLAGS)
BACKSLASH = {
'"': u'"', '\\': u'\\', '/': u'/',
'b': u'\b', 'f': u'\f', 'n': u'\n', 'r': u'\r', 't': u'\t',
}
DEFAULT_ENCODING = "utf-8"
def py_scanstring(s, end, encoding=None, strict=True,
_b=BACKSLASH, _m=STRINGCHUNK.match, _join=u''.join,
_PY3=PY3, _maxunicode=sys.maxunicode):
"""Scan the string s for a JSON string. End is the index of the
character in s after the quote that started the JSON string.
Unescapes all valid JSON string escape sequences and raises ValueError
on attempt to decode an invalid string. If strict is False then literal
control characters are allowed in the string.
Returns a tuple of the decoded string and the index of the character in s
after the end quote."""
if encoding is None:
encoding = DEFAULT_ENCODING
chunks = []
_append = chunks.append
begin = end - 1
while 1:
chunk = _m(s, end)
if chunk is None:
raise JSONDecodeError(
"Unterminated string starting at", s, begin)
end = chunk.end()
content, terminator = chunk.groups()
# Content contains zero or more unescaped string characters
if content:
if not _PY3 and not isinstance(content, unicode):
content = unicode(content, encoding)
_append(content)
# Terminator is the end of string, a literal control character,
# or a backslash denoting that an escape sequence follows
if terminator == '"':
break
elif terminator != '\\':
if strict:
msg = "Invalid control character %r at"
raise JSONDecodeError(msg, s, end)
else:
_append(terminator)
continue
try:
esc = s[end]
except IndexError:
raise JSONDecodeError(
"Unterminated string starting at", s, begin)
# If not a unicode escape sequence, must be in the lookup table
if esc != 'u':
try:
char = _b[esc]
except KeyError:
msg = "Invalid \\X escape sequence %r"
raise JSONDecodeError(msg, s, end)
end += 1
else:
# Unicode escape sequence
msg = "Invalid \\uXXXX escape sequence"
esc = s[end + 1:end + 5]
escX = esc[1:2]
if len(esc) != 4 or escX == 'x' or escX == 'X':
raise JSONDecodeError(msg, s, end - 1)
try:
uni = int(esc, 16)
except ValueError:
raise JSONDecodeError(msg, s, end - 1)
end += 5
# Check for surrogate pair on UCS-4 systems
# Note that this will join high/low surrogate pairs
# but will also pass unpaired surrogates through
if (_maxunicode > 65535 and
uni & 0xfc00 == 0xd800 and
s[end:end + 2] == '\\u'):
esc2 = s[end + 2:end + 6]
escX = esc2[1:2]
if len(esc2) == 4 and not (escX == 'x' or escX == 'X'):
try:
uni2 = int(esc2, 16)
except ValueError:
raise JSONDecodeError(msg, s, end)
if uni2 & 0xfc00 == 0xdc00:
uni = 0x10000 + (((uni - 0xd800) << 10) |
(uni2 - 0xdc00))
end += 6
char = unichr(uni)
# Append the unescaped character
_append(char)
return _join(chunks), end
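# Editor's illustration of the contract above (values worked out by hand):
#   py_scanstring('"hello"', 1)      -> (u'hello', 7)
#   py_scanstring('"py\\u00e9"', 1)  -> (u'py\xe9', 10)
# i.e. the decoded text plus the index one past the closing quote.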
# Use speedup if available
scanstring = c_scanstring or py_scanstring
WHITESPACE = re.compile(r'[ \t\n\r]*', FLAGS)
WHITESPACE_STR = ' \t\n\r'
def JSONObject(state, encoding, strict, scan_once, object_hook,
object_pairs_hook, memo=None,
_w=WHITESPACE.match, _ws=WHITESPACE_STR):
(s, end) = state
# Backwards compatibility
if memo is None:
memo = {}
memo_get = memo.setdefault
pairs = []
# Use a slice to prevent IndexError from being raised, the following
# check will raise a more specific ValueError if the string is empty
nextchar = s[end:end + 1]
# Normally we expect nextchar == '"'
if nextchar != '"':
if nextchar in _ws:
end = _w(s, end).end()
nextchar = s[end:end + 1]
# Trivial empty object
if nextchar == '}':
if object_pairs_hook is not None:
result = object_pairs_hook(pairs)
return result, end + 1
pairs = {}
if object_hook is not None:
pairs = object_hook(pairs)
return pairs, end + 1
elif nextchar != '"':
raise JSONDecodeError(
"Expecting property name enclosed in double quotes",
s, end)
end += 1
while True:
key, end = scanstring(s, end, encoding, strict)
key = memo_get(key, key)
# To skip some function call overhead we optimize the fast paths where
# the JSON key separator is ": " or just ":".
if s[end:end + 1] != ':':
end = _w(s, end).end()
if s[end:end + 1] != ':':
raise JSONDecodeError("Expecting ':' delimiter", s, end)
end += 1
try:
if s[end] in _ws:
end += 1
if s[end] in _ws:
end = _w(s, end + 1).end()
except IndexError:
pass
value, end = scan_once(s, end)
pairs.append((key, value))
try:
nextchar = s[end]
if nextchar in _ws:
end = _w(s, end + 1).end()
nextchar = s[end]
except IndexError:
nextchar = ''
end += 1
if nextchar == '}':
break
elif nextchar != ',':
raise JSONDecodeError("Expecting ',' delimiter or '}'", s, end - 1)
try:
nextchar = s[end]
if nextchar in _ws:
end += 1
nextchar = s[end]
if nextchar in _ws:
end = _w(s, end + 1).end()
nextchar = s[end]
except IndexError:
nextchar = ''
end += 1
if nextchar != '"':
raise JSONDecodeError(
"Expecting property name enclosed in double quotes",
s, end - 1)
if object_pairs_hook is not None:
result = object_pairs_hook(pairs)
return result, end
pairs = dict(pairs)
if object_hook is not None:
pairs = object_hook(pairs)
return pairs, end
def JSONArray(state, scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR):
(s, end) = state
values = []
nextchar = s[end:end + 1]
if nextchar in _ws:
end = _w(s, end + 1).end()
nextchar = s[end:end + 1]
# Look-ahead for trivial empty array
if nextchar == ']':
return values, end + 1
elif nextchar == '':
raise JSONDecodeError("Expecting value or ']'", s, end)
_append = values.append
while True:
value, end = scan_once(s, end)
_append(value)
nextchar = s[end:end + 1]
if nextchar in _ws:
end = _w(s, end + 1).end()
nextchar = s[end:end + 1]
end += 1
if nextchar == ']':
break
elif nextchar != ',':
raise JSONDecodeError("Expecting ',' delimiter or ']'", s, end - 1)
try:
if s[end] in _ws:
end += 1
if s[end] in _ws:
end = _w(s, end + 1).end()
except IndexError:
pass
return values, end
class JSONDecoder(object):
"""Simple JSON decoder
Performs the following translations in decoding by default:
+---------------+-------------------+
| JSON | Python |
+===============+===================+
| object | dict |
+---------------+-------------------+
| array | list |
+---------------+-------------------+
| string | str, unicode |
+---------------+-------------------+
| number (int) | int, long |
+---------------+-------------------+
| number (real) | float |
+---------------+-------------------+
| true | True |
+---------------+-------------------+
| false | False |
+---------------+-------------------+
| null | None |
+---------------+-------------------+
It also understands ``NaN``, ``Infinity``, and ``-Infinity`` as
their corresponding ``float`` values, which is outside the JSON spec.
"""
def __init__(self, encoding=None, object_hook=None, parse_float=None,
parse_int=None, parse_constant=None, strict=True,
object_pairs_hook=None):
"""
*encoding* determines the encoding used to interpret any
:class:`str` objects decoded by this instance (``'utf-8'`` by
default). It has no effect when decoding :class:`unicode` objects.
Note that currently only encodings that are a superset of ASCII work;
strings of other encodings should be passed in as :class:`unicode`.
*object_hook*, if specified, will be called with the result of every
JSON object decoded and its return value will be used in place of the
given :class:`dict`. This can be used to provide custom
deserializations (e.g. to support JSON-RPC class hinting).
*object_pairs_hook* is an optional function that will be called with
the result of any object literal decode with an ordered list of pairs.
The return value of *object_pairs_hook* will be used instead of the
:class:`dict`. This feature can be used to implement custom decoders
that rely on the order that the key and value pairs are decoded (for
example, :func:`collections.OrderedDict` will remember the order of
insertion). If *object_hook* is also defined, the *object_pairs_hook*
takes priority.
*parse_float*, if specified, will be called with the string of every
JSON float to be decoded. By default, this is equivalent to
``float(num_str)``. This can be used to use another datatype or parser
for JSON floats (e.g. :class:`decimal.Decimal`).
*parse_int*, if specified, will be called with the string of every
JSON int to be decoded. By default, this is equivalent to
``int(num_str)``. This can be used to use another datatype or parser
for JSON integers (e.g. :class:`float`).
*parse_constant*, if specified, will be called with one of the
following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This
can be used to raise an exception if invalid JSON numbers are
encountered.
*strict* controls the parser's behavior when it encounters an
invalid control character in a string. The default setting of
``True`` means that unescaped control characters are parse errors; if
``False``, control characters will be allowed in strings.
"""
if encoding is None:
encoding = DEFAULT_ENCODING
self.encoding = encoding
self.object_hook = object_hook
self.object_pairs_hook = object_pairs_hook
self.parse_float = parse_float or float
self.parse_int = parse_int or int
self.parse_constant = parse_constant or _CONSTANTS.__getitem__
self.strict = strict
self.parse_object = JSONObject
self.parse_array = JSONArray
self.parse_string = scanstring
self.memo = {}
self.scan_once = make_scanner(self)
def decode(self, s, _w=WHITESPACE.match, _PY3=PY3):
"""Return the Python representation of ``s`` (a ``str`` or ``unicode``
instance containing a JSON document)
"""
if _PY3 and isinstance(s, bytes):
s = str(s, self.encoding)
obj, end = self.raw_decode(s)
end = _w(s, end).end()
if end != len(s):
raise JSONDecodeError("Extra data", s, end, len(s))
return obj
def raw_decode(self, s, idx=0, _w=WHITESPACE.match, _PY3=PY3):
"""Decode a JSON document from ``s`` (a ``str`` or ``unicode``
beginning with a JSON document) and return a 2-tuple of the Python
representation and the index in ``s`` where the document ended.
Optionally, ``idx`` can be used to specify an offset in ``s`` where
the JSON document begins.
This can be used to decode a JSON document from a string that may
have extraneous data at the end.
"""
if idx < 0:
# Ensure that raw_decode bails on negative indexes, the regex
# would otherwise mask this behavior. #98
raise JSONDecodeError('Expecting value', s, idx)
if _PY3 and not isinstance(s, str):
raise TypeError("Input string must be text, not bytes")
# strip UTF-8 bom
if len(s) > idx:
ord0 = ord(s[idx])
if ord0 == 0xfeff:
idx += 1
elif ord0 == 0xef and s[idx:idx + 3] == '\xef\xbb\xbf':
idx += 3
return self.scan_once(s, idx=_w(s, idx).end())
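# Editor's example: because raw_decode returns the end index instead of
# requiring the whole string to be one document, it can pull several
# documents out of a single buffer:
#
#   import simplejson as json
#   decoder = json.JSONDecoder()
#   buf = '{"a": 1} [2, 3] null'
#   idx = 0
#   while idx < len(buf):
#       obj, idx = decoder.raw_decode(buf, idx)  # leading whitespace is skipped
#       print(obj)
#   # prints {'a': 1}, then [2, 3], then None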
================================================
FILE: simplejson/encoder.py
================================================
"""Implementation of JSONEncoder
"""
from __future__ import absolute_import
import re
from operator import itemgetter
# Do not import Decimal directly to avoid reload issues
import decimal
from .compat import unichr, binary_type, text_type, string_types, integer_types, PY3
def _import_speedups():
try:
from . import _speedups
return _speedups.encode_basestring_ascii, _speedups.make_encoder
except ImportError:
return None, None
c_encode_basestring_ascii, c_make_encoder = _import_speedups()
from .decoder import PosInf
from .raw_json import RawJSON
ESCAPE = re.compile(r'[\x00-\x1f\\"]')
ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])')
HAS_UTF8 = re.compile(r'[\x80-\xff]')
ESCAPE_DCT = {
'\\': '\\\\',
'"': '\\"',
'\b': '\\b',
'\f': '\\f',
'\n': '\\n',
'\r': '\\r',
'\t': '\\t',
}
for i in range(0x20):
#ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i))
ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,))
FLOAT_REPR = repr
def encode_basestring(s, _PY3=PY3, _q=u'"'):
"""Return a JSON representation of a Python string
"""
if _PY3:
if isinstance(s, bytes):
s = str(s, 'utf-8')
elif type(s) is not str:
# convert an str subclass instance to exact str
# raise a TypeError otherwise
s = str.__str__(s)
else:
if isinstance(s, str) and HAS_UTF8.search(s) is not None:
s = unicode(s, 'utf-8')
elif type(s) not in (str, unicode):
# convert an str subclass instance to exact str
# convert a unicode subclass instance to exact unicode
# raise a TypeError otherwise
if isinstance(s, str):
s = str.__str__(s)
else:
s = unicode.__getnewargs__(s)[0]
def replace(match):
return ESCAPE_DCT[match.group(0)]
return _q + ESCAPE.sub(replace, s) + _q
def py_encode_basestring_ascii(s, _PY3=PY3):
"""Return an ASCII-only JSON representation of a Python string
"""
if _PY3:
if isinstance(s, bytes):
s = str(s, 'utf-8')
elif type(s) is not str:
# convert an str subclass instance to exact str
# raise a TypeError otherwise
s = str.__str__(s)
else:
if isinstance(s, str) and HAS_UTF8.search(s) is not None:
s = unicode(s, 'utf-8')
elif type(s) not in (str, unicode):
# convert an str subclass instance to exact str
# convert a unicode subclass instance to exact unicode
# raise a TypeError otherwise
if isinstance(s, str):
s = str.__str__(s)
else:
s = unicode.__getnewargs__(s)[0]
def replace(match):
s = match.group(0)
try:
return ESCAPE_DCT[s]
except KeyError:
n = ord(s)
if n < 0x10000:
#return '\\u{0:04x}'.format(n)
return '\\u%04x' % (n,)
else:
# surrogate pair
n -= 0x10000
s1 = 0xd800 | ((n >> 10) & 0x3ff)
s2 = 0xdc00 | (n & 0x3ff)
#return '\\u{0:04x}\\u{1:04x}'.format(s1, s2)
return '\\u%04x\\u%04x' % (s1, s2)
return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"'
encode_basestring_ascii = (
c_encode_basestring_ascii or py_encode_basestring_ascii)
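# Editor's illustration of the surrogate-pair branch above: a non-BMP
# character is split into a \uXXXX\uXXXX pair, while BMP characters get a
# single escape (values worked out by hand):
#   py_encode_basestring_ascii(u'\U0001f600') -> '"\\ud83d\\ude00"'
#   py_encode_basestring_ascii(u'\xe9')       -> '"\\u00e9"'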
class JSONEncoder(object):
"""Extensible JSON encoder for Python data structures.
Supports the following objects and types by default:
+-------------------+---------------+
| Python | JSON |
+===================+===============+
| dict, namedtuple | object |
+-------------------+---------------+
| list, tuple | array |
+-------------------+---------------+
| str, unicode | string |
+-------------------+---------------+
| int, long, float | number |
+-------------------+---------------+
| True | true |
+-------------------+---------------+
| False | false |
+-------------------+---------------+
| None | null |
+-------------------+---------------+
To extend this to recognize other objects, subclass and implement a
``.default()`` method that returns a serializable object for ``o`` if
possible; otherwise it should call the superclass implementation (to
raise ``TypeError``).
"""
item_separator = ', '
key_separator = ': '
def __init__(self, skipkeys=False, ensure_ascii=True,
check_circular=True, allow_nan=True, sort_keys=False,
indent=None, separators=None, encoding='utf-8', default=None,
use_decimal=True, namedtuple_as_object=True,
tuple_as_array=True, bigint_as_string=False,
item_sort_key=None, for_json=False, ignore_nan=False,
int_as_string_bitcount=None, iterable_as_array=False):
"""Constructor for JSONEncoder, with sensible defaults.
If skipkeys is false, then it is a TypeError to attempt
encoding of keys that are not str, int, long, float or None. If
skipkeys is True, such items are simply skipped.
If ensure_ascii is true, the output is guaranteed to be str
objects with all incoming unicode characters escaped. If
ensure_ascii is false, the output will be a unicode object.
If check_circular is true, then lists, dicts, and custom encoded
objects will be checked for circular references during encoding to
prevent an infinite recursion (which would cause an OverflowError).
Otherwise, no such check takes place.
If allow_nan is true, then NaN, Infinity, and -Infinity will be
encoded as such. This behavior is not JSON specification compliant,
but is consistent with most JavaScript based encoders and decoders.
Otherwise, it will be a ValueError to encode such floats.
If sort_keys is true, then the output of dictionaries will be
sorted by key; this is useful for regression tests to ensure
that JSON serializations can be compared on a day-to-day basis.
If indent is a string, then JSON array elements and object members
will be pretty-printed with a newline followed by that string repeated
for each level of nesting. ``None`` (the default) selects the most compact
representation without any newlines. For backwards compatibility with
versions of simplejson earlier than 2.1.0, an integer is also accepted
and is converted to a string with that many spaces.
If specified, separators should be an (item_separator, key_separator)
tuple. The default is (', ', ': ') if *indent* is ``None`` and
(',', ': ') otherwise. To get the most compact JSON representation,
you should specify (',', ':') to eliminate whitespace.
If specified, default is a function that gets called for objects
that can't otherwise be serialized. It should return a JSON encodable
version of the object or raise a ``TypeError``.
If encoding is not None, then all input strings will be
transformed into unicode using that encoding prior to JSON-encoding.
The default is UTF-8.
If use_decimal is true (default: ``True``), ``decimal.Decimal`` will
be supported directly by the encoder. For the inverse, decode JSON
with ``parse_float=decimal.Decimal``.
If namedtuple_as_object is true (the default), objects with
``_asdict()`` methods will be encoded as JSON objects.
If tuple_as_array is true (the default), tuple (and subclasses) will
be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If bigint_as_string is true (not the default), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in JavaScript otherwise.
If int_as_string_bitcount is a positive number (n), then int of size
greater than or equal to 2**n or lower than or equal to -2**n will be
encoded as strings.
If specified, item_sort_key is a callable used to sort the items in
each dictionary. This is useful if you want to sort items other than
in alphabetical order by key.
If for_json is true (not the default), objects with a ``for_json()``
method will use the return value of that method for encoding as JSON
instead of the object.
If *ignore_nan* is true (default: ``False``), then out of range
:class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized
as ``null`` in compliance with the ECMA-262 specification. If true,
this will override *allow_nan*.
"""
self.skipkeys = skipkeys
self.ensure_ascii = ensure_ascii
self.check_circular = check_circular
self.allow_nan = allow_nan
self.sort_keys = sort_keys
self.use_decimal = use_decimal
self.namedtuple_as_object = namedtuple_as_object
self.tuple_as_array = tuple_as_array
self.iterable_as_array = iterable_as_array
self.bigint_as_string = bigint_as_string
self.item_sort_key = item_sort_key
self.for_json = for_json
self.ignore_nan = ignore_nan
self.int_as_string_bitcount = int_as_string_bitcount
if indent is not None and not isinstance(indent, string_types):
indent = indent * ' '
self.indent = indent
if separators is not None:
self.item_separator, self.key_separator = separators
elif indent is not None:
self.item_separator = ','
if default is not None:
self.default = default
self.encoding = encoding
def default(self, o):
"""Implement this method in a subclass such that it returns
a serializable object for ``o``, or calls the base implementation
(to raise a ``TypeError``).
For example, to support arbitrary iterators, you could
implement default like this::
def default(self, o):
try:
iterable = iter(o)
except TypeError:
pass
else:
return list(iterable)
return JSONEncoder.default(self, o)
"""
raise TypeError('Object of type %s is not JSON serializable' %
o.__class__.__name__)
def encode(self, o):
"""Return a JSON string representation of a Python data structure.
>>> from simplejson import JSONEncoder
>>> JSONEncoder().encode({"foo": ["bar", "baz"]})
'{"foo": ["bar", "baz"]}'
"""
# This is for extremely simple cases and benchmarks.
if isinstance(o, binary_type):
_encoding = self.encoding
if (_encoding is not None and not (_encoding == 'utf-8')):
o = text_type(o, _encoding)
if isinstance(o, string_types):
if self.ensure_ascii:
return encode_basestring_ascii(o)
else:
return encode_basestring(o)
# This doesn't pass the iterator directly to ''.join() because the
# exceptions aren't as detailed. The list call should be roughly
# equivalent to the PySequence_Fast that ''.join() would do.
chunks = self.iterencode(o, _one_shot=True)
if not isinstance(chunks, (list, tuple)):
chunks = list(chunks)
if self.ensure_ascii:
return ''.join(chunks)
else:
return u''.join(chunks)
def iterencode(self, o, _one_shot=False):
"""Encode the given object and yield each string
representation as available.
For example::
for chunk in JSONEncoder().iterencode(bigobject):
mysocket.write(chunk)
"""
if self.check_circular:
markers = {}
else:
markers = None
if self.ensure_ascii:
_encoder = encode_basestring_ascii
else:
_encoder = encode_basestring
if self.encoding != 'utf-8' and self.encoding is not None:
def _encoder(o, _orig_encoder=_encoder, _encoding=self.encoding):
if isinstance(o, binary_type):
o = text_type(o, _encoding)
return _orig_encoder(o)
def floatstr(o, allow_nan=self.allow_nan, ignore_nan=self.ignore_nan,
_repr=FLOAT_REPR, _inf=PosInf, _neginf=-PosInf):
# Check for specials. Note that this type of test is processor
# and/or platform-specific, so do tests which don't depend on
# the internals.
if o != o:
text = 'NaN'
elif o == _inf:
text = 'Infinity'
elif o == _neginf:
text = '-Infinity'
else:
if type(o) != float:
# See #118, do not trust custom str/repr
o = float(o)
return _repr(o)
if ignore_nan:
text = 'null'
elif not allow_nan:
raise ValueError(
"Out of range float values are not JSON compliant: " +
repr(o))
return text
key_memo = {}
int_as_string_bitcount = (
53 if self.bigint_as_string else self.int_as_string_bitcount)
if (_one_shot and c_make_encoder is not None
and self.indent is None):
_iterencode = c_make_encoder(
markers, self.default, _encoder, self.indent,
self.key_separator, self.item_separator, self.sort_keys,
self.skipkeys, self.allow_nan, key_memo, self.use_decimal,
self.namedtuple_as_object, self.tuple_as_array,
int_as_string_bitcount,
self.item_sort_key, self.encoding, self.for_json,
self.ignore_nan, decimal.Decimal, self.iterable_as_array)
else:
_iterencode = _make_iterencode(
markers, self.default, _encoder, self.indent, floatstr,
self.key_separator, self.item_separator, self.sort_keys,
self.skipkeys, _one_shot, self.use_decimal,
self.namedtuple_as_object, self.tuple_as_array,
int_as_string_bitcount,
self.item_sort_key, self.encoding, self.for_json,
self.iterable_as_array, Decimal=decimal.Decimal)
try:
return _iterencode(o, 0)
finally:
key_memo.clear()
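# A short illustrative sketch of the streaming pattern from the iterencode()
# docstring: chunks are written out as they are produced, so the full JSON
# string never has to be held in memory (the output path is hypothetical).
import simplejson

big_object = {'rows': list(range(100000))}
with open('/tmp/out.json', 'w') as f:
    for chunk in simplejson.JSONEncoder().iterencode(big_object):
        f.write(chunk)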
class JSONEncoderForHTML(JSONEncoder):
"""An encoder that produces JSON safe to embed in HTML.
To embed JSON content in, say, a script tag on a web page, the
characters &, < and > should be escaped. They cannot be escaped
with the usual entities (e.g. &amp;) because they are not expanded
within <script> tags.
This class also escapes the line separator and paragraph separator
characters U+2028 and U+2029, irrespective of the ensure_ascii setting,
as these characters are not valid in JavaScript strings (see
http://timelessrepo.com/json-isnt-a-javascript-subset).
"""
================================================
FILE: simplejson/tests/test_encode_for_html.py
================================================
self.assertEqual(
r'"\u003c/script\u003e\u003cscript\u003e'
r'alert(\"gotcha\")\u003c/script\u003e"',
self.encoder.encode(bad_string))
self.assertEqual(
bad_string, self.decoder.decode(
self.encoder.encode(bad_string)))
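# A standalone sketch matching the assertions above: JSONEncoderForHTML
# escapes <, > and & as \uXXXX sequences so the result can be inlined in a
# <script> tag without terminating it early.
import simplejson

payload = {'msg': '</script><script>alert("gotcha")</script>'}
print(simplejson.JSONEncoderForHTML().encode(payload))
# -> {"msg": "\u003c/script\u003e\u003cscript\u003ealert(\"gotcha\")\u003c/script\u003e"}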
================================================
FILE: simplejson/tests/test_errors.py
================================================
import sys, pickle
from unittest import TestCase
import simplejson as json
from simplejson.compat import text_type, b
class TestErrors(TestCase):
def test_string_keys_error(self):
data = [{'a': 'A', 'b': (2, 4), 'c': 3.0, ('d',): 'D tuple'}]
try:
json.dumps(data)
except TypeError:
err = sys.exc_info()[1]
else:
self.fail('Expected TypeError')
self.assertEqual(str(err),
'keys must be str, int, float, bool or None, not tuple')
def test_not_serializable(self):
try:
json.dumps(json)
except TypeError:
err = sys.exc_info()[1]
else:
self.fail('Expected TypeError')
self.assertEqual(str(err),
'Object of type module is not JSON serializable')
def test_decode_error(self):
err = None
try:
json.loads('{}\na\nb')
except json.JSONDecodeError:
err = sys.exc_info()[1]
else:
self.fail('Expected JSONDecodeError')
self.assertEqual(err.lineno, 2)
self.assertEqual(err.colno, 1)
self.assertEqual(err.endlineno, 3)
self.assertEqual(err.endcolno, 2)
def test_scan_error(self):
err = None
for t in (text_type, b):
try:
json.loads(t('{"asdf": "'))
except json.JSONDecodeError:
err = sys.exc_info()[1]
else:
self.fail('Expected JSONDecodeError')
self.assertEqual(err.lineno, 1)
self.assertEqual(err.colno, 10)
def test_error_is_pickable(self):
err = None
try:
json.loads('{}\na\nb')
except json.JSONDecodeError:
err = sys.exc_info()[1]
else:
self.fail('Expected JSONDecodeError')
s = pickle.dumps(err)
e = pickle.loads(s)
self.assertEqual(err.msg, e.msg)
self.assertEqual(err.doc, e.doc)
self.assertEqual(err.pos, e.pos)
self.assertEqual(err.end, e.end)
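# A quick standalone illustration of the decode-error attributes the tests
# above rely on: JSONDecodeError carries a message plus precise
# line/column/offset information.
import simplejson

try:
    simplejson.loads('{"a": 1,\n"b" 2}')  # missing ':' on line 2
except simplejson.JSONDecodeError as e:
    print('%s at line %d col %d (char %d)' % (e.msg, e.lineno, e.colno, e.pos))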
================================================
FILE: simplejson/tests/test_fail.py
================================================
import sys
from unittest import TestCase
import simplejson as json
# 2007-10-05
JSONDOCS = [
# http://json.org/JSON_checker/test/fail1.json
'"A JSON payload should be an object or array, not a string."',
# http://json.org/JSON_checker/test/fail2.json
'["Unclosed array"',
# http://json.org/JSON_checker/test/fail3.json
'{unquoted_key: "keys must be quoted"}',
# http://json.org/JSON_checker/test/fail4.json
'["extra comma",]',
# http://json.org/JSON_checker/test/fail5.json
'["double extra comma",,]',
# http://json.org/JSON_checker/test/fail6.json
'[ , "<-- missing value"]',
# http://json.org/JSON_checker/test/fail7.json
'["Comma after the close"],',
# http://json.org/JSON_checker/test/fail8.json
'["Extra close"]]',
# http://json.org/JSON_checker/test/fail9.json
'{"Extra comma": true,}',
# http://json.org/JSON_checker/test/fail10.json
'{"Extra value after close": true} "misplaced quoted value"',
# http://json.org/JSON_checker/test/fail11.json
'{"Illegal expression": 1 + 2}',
# http://json.org/JSON_checker/test/fail12.json
'{"Illegal invocation": alert()}',
# http://json.org/JSON_checker/test/fail13.json
'{"Numbers cannot have leading zeroes": 013}',
# http://json.org/JSON_checker/test/fail14.json
'{"Numbers cannot be hex": 0x14}',
# http://json.org/JSON_checker/test/fail15.json
'["Illegal backslash escape: \\x15"]',
# http://json.org/JSON_checker/test/fail16.json
'[\\naked]',
# http://json.org/JSON_checker/test/fail17.json
'["Illegal backslash escape: \\017"]',
# http://json.org/JSON_checker/test/fail18.json
'[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]',
# http://json.org/JSON_checker/test/fail19.json
'{"Missing colon" null}',
# http://json.org/JSON_checker/test/fail20.json
'{"Double colon":: null}',
# http://json.org/JSON_checker/test/fail21.json
'{"Comma instead of colon", null}',
# http://json.org/JSON_checker/test/fail22.json
'["Colon instead of comma": false]',
# http://json.org/JSON_checker/test/fail23.json
'["Bad value", truth]',
# http://json.org/JSON_checker/test/fail24.json
"['single quote']",
# http://json.org/JSON_checker/test/fail25.json
'["\ttab\tcharacter\tin\tstring\t"]',
# http://json.org/JSON_checker/test/fail26.json
'["tab\\ character\\ in\\ string\\ "]',
# http://json.org/JSON_checker/test/fail27.json
'["line\nbreak"]',
# http://json.org/JSON_checker/test/fail28.json
'["line\\\nbreak"]',
# http://json.org/JSON_checker/test/fail29.json
'[0e]',
# http://json.org/JSON_checker/test/fail30.json
'[0e+]',
# http://json.org/JSON_checker/test/fail31.json
'[0e+-1]',
# http://json.org/JSON_checker/test/fail32.json
'{"Comma instead if closing brace": true,',
# http://json.org/JSON_checker/test/fail33.json
'["mismatch"}',
# http://code.google.com/p/simplejson/issues/detail?id=3
u'["A\u001FZ control characters in string"]',
# misc based on coverage
'{',
'{]',
'{"foo": "bar"]',
'{"foo": "bar"',
'nul',
'nulx',
'-',
'-x',
'-e',
'-e0',
'-Infinite',
'-Inf',
'Infinit',
'Infinite',
'NaM',
'NuN',
'falsy',
'fal',
'trug',
'tru',
'1e',
'1ex',
'1e-',
'1e-x',
]
SKIPS = {
1: "why not have a string payload?",
18: "spec doesn't specify any nesting limitations",
}
class TestFail(TestCase):
def test_failures(self):
for idx, doc in enumerate(JSONDOCS):
idx = idx + 1
if idx in SKIPS:
json.loads(doc)
continue
try:
json.loads(doc)
except json.JSONDecodeError:
pass
else:
self.fail("Expected failure for fail%d.json: %r" % (idx, doc))
def test_array_decoder_issue46(self):
# http://code.google.com/p/simplejson/issues/detail?id=46
for doc in [u'[,]', '[,]']:
try:
json.loads(doc)
except json.JSONDecodeError:
e = sys.exc_info()[1]
self.assertEqual(e.pos, 1)
self.assertEqual(e.lineno, 1)
self.assertEqual(e.colno, 2)
except Exception:
e = sys.exc_info()[1]
self.fail("Unexpected exception raised %r %s" % (e, e))
else:
self.fail("Unexpected success parsing '[,]'")
def test_truncated_input(self):
test_cases = [
('', 'Expecting value', 0),
('[', "Expecting value or ']'", 1),
('[42', "Expecting ',' delimiter", 3),
('[42,', 'Expecting value', 4),
('["', 'Unterminated string starting at', 1),
('["spam', 'Unterminated string starting at', 1),
('["spam"', "Expecting ',' delimiter", 7),
('["spam",', 'Expecting value', 8),
('{', 'Expecting property name enclosed in double quotes', 1),
('{"', 'Unterminated string starting at', 1),
('{"spam', 'Unterminated string starting at', 1),
('{"spam"', "Expecting ':' delimiter", 7),
('{"spam":', 'Expecting value', 8),
('{"spam":42', "Expecting ',' delimiter", 10),
('{"spam":42,', 'Expecting property name enclosed in double quotes',
11),
('"', 'Unterminated string starting at', 0),
('"spam', 'Unterminated string starting at', 0),
('[,', "Expecting value", 1),
]
for data, msg, idx in test_cases:
try:
json.loads(data)
except json.JSONDecodeError:
e = sys.exc_info()[1]
self.assertEqual(
e.msg[:len(msg)],
msg,
"%r doesn't start with %r for %r" % (e.msg, msg, data))
self.assertEqual(
e.pos, idx,
"pos %r != %r for %r" % (e.pos, idx, data))
except Exception:
e = sys.exc_info()[1]
self.fail("Unexpected exception raised %r %s" % (e, e))
else:
self.fail("Unexpected success parsing '%r'" % (data,))
================================================
FILE: simplejson/tests/test_float.py
================================================
import math
from unittest import TestCase
from simplejson.compat import long_type, text_type
import simplejson as json
from simplejson.decoder import NaN, PosInf, NegInf
class TestFloat(TestCase):
def test_degenerates_allow(self):
for inf in (PosInf, NegInf):
self.assertEqual(json.loads(json.dumps(inf)), inf)
# Python 2.5 doesn't have math.isnan
nan = json.loads(json.dumps(NaN))
self.assertTrue((0 + nan) != nan)
def test_degenerates_ignore(self):
for f in (PosInf, NegInf, NaN):
self.assertEqual(json.loads(json.dumps(f, ignore_nan=True)), None)
def test_degenerates_deny(self):
for f in (PosInf, NegInf, NaN):
self.assertRaises(ValueError, json.dumps, f, allow_nan=False)
def test_floats(self):
for num in [1617161771.7650001, math.pi, math.pi**100,
math.pi**-100, 3.1]:
self.assertEqual(float(json.dumps(num)), num)
self.assertEqual(json.loads(json.dumps(num)), num)
self.assertEqual(json.loads(text_type(json.dumps(num))), num)
def test_ints(self):
for num in [1, long_type(1), 1<<32, 1<<64]:
self.assertEqual(json.dumps(num), str(num))
self.assertEqual(int(json.dumps(num)), num)
self.assertEqual(json.loads(json.dumps(num)), num)
self.assertEqual(json.loads(text_type(json.dumps(num))), num)
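# The three degenerate-float behaviors tested above, side by side, as a
# standalone sketch:
import simplejson

nan = float('nan')
print(simplejson.dumps(nan))                   # 'NaN' (allow_nan defaults to True)
print(simplejson.dumps(nan, ignore_nan=True))  # 'null'
try:
    simplejson.dumps(nan, allow_nan=False)     # strict JSON compliance
except ValueError as e:
    print(e)                                   # out-of-range floats rejected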
================================================
FILE: simplejson/tests/test_for_json.py
================================================
import unittest
import simplejson as json
class ForJson(object):
def for_json(self):
return {'for_json': 1}
class NestedForJson(object):
def for_json(self):
return {'nested': ForJson()}
class ForJsonList(object):
def for_json(self):
return ['list']
class DictForJson(dict):
def for_json(self):
return {'alpha': 1}
class ListForJson(list):
def for_json(self):
return ['list']
class TestForJson(unittest.TestCase):
def assertRoundTrip(self, obj, other, for_json=True):
if for_json is None:
# None will use the default
s = json.dumps(obj)
else:
s = json.dumps(obj, for_json=for_json)
self.assertEqual(
json.loads(s),
other)
def test_for_json_encodes_stand_alone_object(self):
self.assertRoundTrip(
ForJson(),
ForJson().for_json())
def test_for_json_encodes_object_nested_in_dict(self):
self.assertRoundTrip(
{'hooray': ForJson()},
{'hooray': ForJson().for_json()})
def test_for_json_encodes_object_nested_in_list_within_dict(self):
self.assertRoundTrip(
{'list': [0, ForJson(), 2, 3]},
{'list': [0, ForJson().for_json(), 2, 3]})
def test_for_json_encodes_object_nested_within_object(self):
self.assertRoundTrip(
NestedForJson(),
{'nested': {'for_json': 1}})
def test_for_json_encodes_list(self):
self.assertRoundTrip(
ForJsonList(),
ForJsonList().for_json())
def test_for_json_encodes_list_within_object(self):
self.assertRoundTrip(
{'nested': ForJsonList()},
{'nested': ForJsonList().for_json()})
def test_for_json_encodes_dict_subclass(self):
self.assertRoundTrip(
DictForJson(a=1),
DictForJson(a=1).for_json())
def test_for_json_encodes_list_subclass(self):
self.assertRoundTrip(
ListForJson(['l']),
ListForJson(['l']).for_json())
def test_for_json_ignored_if_not_true_with_dict_subclass(self):
for for_json in (None, False):
self.assertRoundTrip(
DictForJson(a=1),
{'a': 1},
for_json=for_json)
def test_for_json_ignored_if_not_true_with_list_subclass(self):
for for_json in (None, False):
self.assertRoundTrip(
ListForJson(['l']),
['l'],
for_json=for_json)
def test_raises_typeerror_if_for_json_not_true_with_object(self):
self.assertRaises(TypeError, json.dumps, ForJson())
self.assertRaises(TypeError, json.dumps, ForJson(), for_json=False)
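# A compact standalone sketch of the for_json protocol covered above: with
# for_json=True, any object exposing a for_json() method is replaced by its
# return value before encoding (Money is a hypothetical class).
import simplejson

class Money(object):
    def __init__(self, cents):
        self.cents = cents
    def for_json(self):
        return {'cents': self.cents}

print(simplejson.dumps({'price': Money(250)}, for_json=True))
# -> {"price": {"cents": 250}}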
================================================
FILE: simplejson/tests/test_indent.py
================================================
from unittest import TestCase
import textwrap
import simplejson as json
from simplejson.compat import StringIO
class TestIndent(TestCase):
def test_indent(self):
h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh',
'i-vhbjkhnth',
{'nifty': 87}, {'field': 'yes', 'morefield': False} ]
expect = textwrap.dedent("""\
[
\t[
\t\t"blorpie"
\t],
\t[
\t\t"whoops"
\t],
\t[],
\t"d-shtaeou",
\t"d-nthiouh",
\t"i-vhbjkhnth",
\t{
\t\t"nifty": 87
\t},
\t{
\t\t"field": "yes",
\t\t"morefield": false
\t}
]""")
d1 = json.dumps(h)
d2 = json.dumps(h, indent='\t', sort_keys=True, separators=(',', ': '))
d3 = json.dumps(h, indent='  ', sort_keys=True, separators=(',', ': '))
d4 = json.dumps(h, indent=2, sort_keys=True, separators=(',', ': '))
h1 = json.loads(d1)
h2 = json.loads(d2)
h3 = json.loads(d3)
h4 = json.loads(d4)
self.assertEqual(h1, h)
self.assertEqual(h2, h)
self.assertEqual(h3, h)
self.assertEqual(h4, h)
self.assertEqual(d3, expect.replace('\t', '  '))
self.assertEqual(d4, expect.replace('\t', '  '))
# NOTE: Python 2.4 textwrap.dedent converts tabs to spaces,
# so the following is expected to fail. Python 2.4 is not a
# supported platform in simplejson 2.1.0+.
self.assertEqual(d2, expect)
def test_indent0(self):
h = {3: 1}
def check(indent, expected):
d1 = json.dumps(h, indent=indent)
self.assertEqual(d1, expected)
sio = StringIO()
json.dump(h, sio, indent=indent)
self.assertEqual(sio.getvalue(), expected)
# indent=0 should emit newlines
check(0, '{\n"3": 1\n}')
# indent=None is more compact
check(None, '{"3": 1}')
def test_separators(self):
lst = [1,2,3,4]
expect = '[\n1,\n2,\n3,\n4\n]'
expect_spaces = '[\n1, \n2, \n3, \n4\n]'
# Ensure that separators still works
self.assertEqual(
expect_spaces,
json.dumps(lst, indent=0, separators=(', ', ': ')))
# Force the new defaults
self.assertEqual(
expect,
json.dumps(lst, indent=0, separators=(',', ': ')))
# Added in 2.1.4
self.assertEqual(
expect,
json.dumps(lst, indent=0))
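# A condensed standalone sketch of the indent/separators interplay tested
# above: an explicit indent inserts a newline plus padding between elements,
# and (since 2.1.4) also defaults the item separator to ',' so no trailing
# spaces are emitted.
import simplejson

print(simplejson.dumps([1, 2], indent=2))               # "[\n  1,\n  2\n]"
print(simplejson.dumps([1, 2], indent=0))               # "[\n1,\n2\n]"
print(simplejson.dumps([1, 2], separators=(',', ':')))  # compact: "[1,2]"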
================================================
FILE: simplejson/tests/test_item_sort_key.py
================================================
from unittest import TestCase
import simplejson as json
from operator import itemgetter
class TestItemSortKey(TestCase):
def test_simple_first(self):
a = {'a': 1, 'c': 5, 'jack': 'jill', 'pick': 'axe', 'array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'}
self.assertEqual(
'{"a": 1, "c": 5, "crate": "dog", "jack": "jill", "pick": "axe", "zeak": "oh", "array": [1, 5, 6, 9], "tuple": [83, 12, 3]}',
json.dumps(a, item_sort_key=json.simple_first))
def test_case(self):
a = {'a': 1, 'c': 5, 'Jack': 'jill', 'pick': 'axe', 'Array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'}
self.assertEqual(
'{"Array": [1, 5, 6, 9], "Jack": "jill", "a": 1, "c": 5, "crate": "dog", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}',
json.dumps(a, item_sort_key=itemgetter(0)))
self.assertEqual(
'{"a": 1, "Array": [1, 5, 6, 9], "c": 5, "crate": "dog", "Jack": "jill", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}',
json.dumps(a, item_sort_key=lambda kv: kv[0].lower()))
def test_item_sort_key_value(self):
# https://github.com/simplejson/simplejson/issues/173
a = {'a': 1, 'b': 0}
self.assertEqual(
'{"b": 0, "a": 1}',
json.dumps(a, item_sort_key=lambda kv: kv[1]))
================================================
FILE: simplejson/tests/test_iterable.py
================================================
import unittest
from simplejson.compat import StringIO
import simplejson as json
def iter_dumps(obj, **kw):
return ''.join(json.JSONEncoder(**kw).iterencode(obj))
def sio_dump(obj, **kw):
sio = StringIO()
json.dump(obj, sio, **kw)  # write into sio so getvalue() returns the document
return sio.getvalue()
class TestIterable(unittest.TestCase):
def test_iterable(self):
for l in ([], [1], [1, 2], [1, 2, 3]):
for opts in [{}, {'indent': 2}]:
for dumps in (json.dumps, iter_dumps, sio_dump):
expect = dumps(l, **opts)
default_expect = dumps(sum(l), **opts)
# Default is False
self.assertRaises(TypeError, dumps, iter(l), **opts)
self.assertRaises(TypeError, dumps, iter(l), iterable_as_array=False, **opts)
self.assertEqual(expect, dumps(iter(l), iterable_as_array=True, **opts))
# Ensure that the "default" gets called
self.assertEqual(default_expect, dumps(iter(l), default=sum, **opts))
self.assertEqual(default_expect, dumps(iter(l), iterable_as_array=False, default=sum, **opts))
# Ensure that the "default" does not get called
self.assertEqual(
expect,
dumps(iter(l), iterable_as_array=True, default=sum, **opts))
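# A standalone sketch of iterable_as_array, which the case matrix above pins
# down: generators and other non-list iterables are consumed into JSON arrays
# instead of raising TypeError (the default behavior).
import simplejson

squares = (i * i for i in range(4))
print(simplejson.dumps(squares, iterable_as_array=True))  # [0, 1, 4, 9]
try:
    simplejson.dumps(iter([1, 2]))  # iterable_as_array defaults to False
except TypeError as e:
    print(e)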
================================================
FILE: simplejson/tests/test_namedtuple.py
================================================
from __future__ import absolute_import
import unittest
import simplejson as json
from simplejson.compat import StringIO
try:
from collections import namedtuple
except ImportError:
class Value(tuple):
def __new__(cls, *args):
return tuple.__new__(cls, args)
def _asdict(self):
return {'value': self[0]}
class Point(tuple):
def __new__(cls, *args):
return tuple.__new__(cls, args)
def _asdict(self):
return {'x': self[0], 'y': self[1]}
else:
Value = namedtuple('Value', ['value'])
Point = namedtuple('Point', ['x', 'y'])
class DuckValue(object):
def __init__(self, *args):
self.value = Value(*args)
def _asdict(self):
return self.value._asdict()
class DuckPoint(object):
def __init__(self, *args):
self.point = Point(*args)
def _asdict(self):
return self.point._asdict()
class DeadDuck(object):
_asdict = None
class DeadDict(dict):
_asdict = None
CONSTRUCTORS = [
lambda v: v,
lambda v: [v],
lambda v: [{'key': v}],
]
class TestNamedTuple(unittest.TestCase):
def test_namedtuple_dumps(self):
for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]:
d = v._asdict()
self.assertEqual(d, json.loads(json.dumps(v)))
self.assertEqual(
d,
json.loads(json.dumps(v, namedtuple_as_object=True)))
self.assertEqual(d, json.loads(json.dumps(v, tuple_as_array=False)))
self.assertEqual(
d,
json.loads(json.dumps(v, namedtuple_as_object=True,
tuple_as_array=False)))
def test_namedtuple_dumps_false(self):
for v in [Value(1), Point(1, 2)]:
l = list(v)
self.assertEqual(
l,
json.loads(json.dumps(v, namedtuple_as_object=False)))
self.assertRaises(TypeError, json.dumps, v,
tuple_as_array=False, namedtuple_as_object=False)
def test_namedtuple_dump(self):
for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]:
d = v._asdict()
sio = StringIO()
json.dump(v, sio)
self.assertEqual(d, json.loads(sio.getvalue()))
sio = StringIO()
json.dump(v, sio, namedtuple_as_object=True)
self.assertEqual(
d,
json.loads(sio.getvalue()))
sio = StringIO()
json.dump(v, sio, tuple_as_array=False)
self.assertEqual(d, json.loads(sio.getvalue()))
sio = StringIO()
json.dump(v, sio, namedtuple_as_object=True,
tuple_as_array=False)
self.assertEqual(
d,
json.loads(sio.getvalue()))
def test_namedtuple_dump_false(self):
for v in [Value(1), Point(1, 2)]:
l = list(v)
sio = StringIO()
json.dump(v, sio, namedtuple_as_object=False)
self.assertEqual(
l,
json.loads(sio.getvalue()))
self.assertRaises(TypeError, json.dump, v, StringIO(),
tuple_as_array=False, namedtuple_as_object=False)
def test_asdict_not_callable_dump(self):
for f in CONSTRUCTORS:
self.assertRaises(TypeError,
json.dump, f(DeadDuck()), StringIO(), namedtuple_as_object=True)
sio = StringIO()
json.dump(f(DeadDict()), sio, namedtuple_as_object=True)
self.assertEqual(
json.dumps(f({})),
sio.getvalue())
def test_asdict_not_callable_dumps(self):
for f in CONSTRUCTORS:
self.assertRaises(TypeError,
json.dumps, f(DeadDuck()), namedtuple_as_object=True)
self.assertEqual(
json.dumps(f({})),
json.dumps(f(DeadDict()), namedtuple_as_object=True))
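# A standalone sketch of the namedtuple handling tested above: by default a
# namedtuple is encoded through its _asdict() as a JSON object; disabling
# namedtuple_as_object makes it fall back to plain tuple handling.
import simplejson
from collections import namedtuple

Point = namedtuple('Point', ['x', 'y'])
print(simplejson.dumps(Point(1, 2)))                              # {"x": 1, "y": 2}
print(simplejson.dumps(Point(1, 2), namedtuple_as_object=False))  # [1, 2]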
================================================
FILE: simplejson/tests/test_pass1.py
================================================
from unittest import TestCase
import simplejson as json
# from http://json.org/JSON_checker/test/pass1.json
JSON = r'''
[
"JSON Test Pattern pass1",
{"object with 1 member":["array with 1 element"]},
{},
[],
-42,
true,
false,
null,
{
"integer": 1234567890,
"real": -9876.543210,
"e": 0.123456789e-12,
"E": 1.234567890E+34,
"": 23456789012E66,
"zero": 0,
"one": 1,
"space": " ",
"quote": "\"",
"backslash": "\\",
"controls": "\b\f\n\r\t",
"slash": "/ & \/",
"alpha": "abcdefghijklmnopqrstuvwyz",
"ALPHA": "ABCDEFGHIJKLMNOPQRSTUVWYZ",
"digit": "0123456789",
"special": "`1~!@#$%^&*()_+-={':[,]}|;.>?",
"hex": "\u0123\u4567\u89AB\uCDEF\uabcd\uef4A",
"true": true,
"false": false,
"null": null,
"array":[ ],
"object":{ },
"address": "50 St. James Street",
"url": "http://www.JSON.org/",
"comment": "// /* */": " ",
" s p a c e d " :[1,2 , 3
,
4 , 5 , 6 ,7 ],"compact": [1,2,3,4,5,6,7],
"jsontext": "{\"object with 1 member\":[\"array with 1 element\"]}",
"quotes": "" \u0022 %22 0x22 034 "",
"\/\\\"\uCAFE\uBABE\uAB98\uFCDE\ubcda\uef4A\b\f\n\r\t`1~!@#$%^&*()_+-=[]{}|;:',./<>?"
: "A key can be any string"
},
0.5 ,98.6
,
99.44
,
1066,
1e1,
0.1e1,
1e-1,
1e00,2e+00,2e-00
,"rosebud"]
'''
class TestPass1(TestCase):
def test_parse(self):
# test in/out equivalence and parsing
res = json.loads(JSON)
out = json.dumps(res)
self.assertEqual(res, json.loads(out))
================================================
FILE: simplejson/tests/test_pass2.py
================================================
from unittest import TestCase
import simplejson as json
# from http://json.org/JSON_checker/test/pass2.json
JSON = r'''
[[[[[[[[[[[[[[[[[[["Not too deep"]]]]]]]]]]]]]]]]]]]
'''
class TestPass2(TestCase):
def test_parse(self):
# test in/out equivalence and parsing
res = json.loads(JSON)
out = json.dumps(res)
self.assertEqual(res, json.loads(out))
================================================
FILE: simplejson/tests/test_pass3.py
================================================
from unittest import TestCase
import simplejson as json
# from http://json.org/JSON_checker/test/pass3.json
JSON = r'''
{
"JSON Test Pattern pass3": {
"The outermost value": "must be an object or array.",
"In this test": "It is an object."
}
}
'''
class TestPass3(TestCase):
def test_parse(self):
# test in/out equivalence and parsing
res = json.loads(JSON)
out = json.dumps(res)
self.assertEqual(res, json.loads(out))
================================================
FILE: simplejson/tests/test_raw_json.py
================================================
import unittest
import simplejson as json
dct1 = {
'key1': 'value1'
}
dct2 = {
'key2': 'value2',
'd1': dct1
}
dct3 = {
'key2': 'value2',
'd1': json.dumps(dct1)
}
dct4 = {
'key2': 'value2',
'd1': json.RawJSON(json.dumps(dct1))
}
class TestRawJson(unittest.TestCase):
def test_normal_str(self):
self.assertNotEqual(json.dumps(dct2), json.dumps(dct3))
def test_raw_json_str(self):
self.assertEqual(json.dumps(dct2), json.dumps(dct4))
self.assertEqual(dct2, json.loads(json.dumps(dct4)))
def test_list(self):
self.assertEqual(
json.dumps([dct2]),
json.dumps([json.RawJSON(json.dumps(dct2))]))
self.assertEqual(
[dct2],
json.loads(json.dumps([json.RawJSON(json.dumps(dct2))])))
def test_direct(self):
self.assertEqual(
json.dumps(dct2),
json.dumps(json.RawJSON(json.dumps(dct2))))
self.assertEqual(
dct2,
json.loads(json.dumps(json.RawJSON(json.dumps(dct2)))))
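# A standalone sketch of RawJSON as exercised above: a pre-encoded fragment
# is spliced into the output verbatim, avoiding double encoding of
# already-serialized values.
import simplejson

fragment = simplejson.dumps({'cached': True})
print(simplejson.dumps({'body': simplejson.RawJSON(fragment)}))
# -> {"body": {"cached": true}}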
================================================
FILE: simplejson/tests/test_recursion.py
================================================
from unittest import TestCase
import simplejson as json
class JSONTestObject:
pass
class RecursiveJSONEncoder(json.JSONEncoder):
recurse = False
def default(self, o):
if o is JSONTestObject:
if self.recurse:
return [JSONTestObject]
else:
return 'JSONTestObject'
return json.JSONEncoder.default(self, o)
class TestRecursion(TestCase):
def test_listrecursion(self):
x = []
x.append(x)
try:
json.dumps(x)
except ValueError:
pass
else:
self.fail("didn't raise ValueError on list recursion")
x = []
y = [x]
x.append(y)
try:
json.dumps(x)
except ValueError:
pass
else:
self.fail("didn't raise ValueError on alternating list recursion")
y = []
x = [y, y]
# ensure that the marker is cleared
json.dumps(x)
def test_dictrecursion(self):
x = {}
x["test"] = x
try:
json.dumps(x)
except ValueError:
pass
else:
self.fail("didn't raise ValueError on dict recursion")
x = {}
y = {"a": x, "b": x}
# ensure that the marker is cleared
json.dumps(y)
def test_defaultrecursion(self):
enc = RecursiveJSONEncoder()
self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"')
enc.recurse = True
try:
enc.encode(JSONTestObject)
except ValueError:
pass
else:
self.fail("didn't raise ValueError on default recursion")
================================================
FILE: simplejson/tests/test_scanstring.py
================================================
import sys
from unittest import TestCase
import simplejson as json
import simplejson.decoder
from simplejson.compat import b, PY3
class TestScanString(TestCase):
# The bytes type is intentionally not used in most of these tests
# under Python 3 because the decoder immediately coerces to str before
# calling scanstring. In Python 2 we are testing the code paths
# for both unicode and str.
#
# The reason this is done is because Python 3 would require
# entirely different code paths for parsing bytes and str.
#
def test_py_scanstring(self):
self._test_scanstring(simplejson.decoder.py_scanstring)
def test_c_scanstring(self):
if not simplejson.decoder.c_scanstring:
return
self._test_scanstring(simplejson.decoder.c_scanstring)
self.assertTrue(isinstance(simplejson.decoder.c_scanstring('""', 0)[0], str))
def _test_scanstring(self, scanstring):
if sys.maxunicode == 65535:
self.assertEqual(
scanstring(u'"z\U0001d120x"', 1, None, True),
(u'z\U0001d120x', 6))
else:
self.assertEqual(
scanstring(u'"z\U0001d120x"', 1, None, True),
(u'z\U0001d120x', 5))
self.assertEqual(
scanstring('"\\u007b"', 1, None, True),
(u'{', 8))
self.assertEqual(
scanstring('"A JSON payload should be an object or array, not a string."', 1, None, True),
(u'A JSON payload should be an object or array, not a string.', 60))
self.assertEqual(
scanstring('["Unclosed array"', 2, None, True),
(u'Unclosed array', 17))
self.assertEqual(
scanstring('["extra comma",]', 2, None, True),
(u'extra comma', 14))
self.assertEqual(
scanstring('["double extra comma",,]', 2, None, True),
(u'double extra comma', 21))
self.assertEqual(
scanstring('["Comma after the close"],', 2, None, True),
(u'Comma after the close', 24))
self.assertEqual(
scanstring('["Extra close"]]', 2, None, True),
(u'Extra close', 14))
self.assertEqual(
scanstring('{"Extra comma": true,}', 2, None, True),
(u'Extra comma', 14))
self.assertEqual(
scanstring('{"Extra value after close": true} "misplaced quoted value"', 2, None, True),
(u'Extra value after close', 26))
self.assertEqual(
scanstring('{"Illegal expression": 1 + 2}', 2, None, True),
(u'Illegal expression', 21))
self.assertEqual(
scanstring('{"Illegal invocation": alert()}', 2, None, True),
(u'Illegal invocation', 21))
self.assertEqual(
scanstring('{"Numbers cannot have leading zeroes": 013}', 2, None, True),
(u'Numbers cannot have leading zeroes', 37))
self.assertEqual(
scanstring('{"Numbers cannot be hex": 0x14}', 2, None, True),
(u'Numbers cannot be hex', 24))
self.assertEqual(
scanstring('[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', 21, None, True),
(u'Too deep', 30))
self.assertEqual(
scanstring('{"Missing colon" null}', 2, None, True),
(u'Missing colon', 16))
self.assertEqual(
scanstring('{"Double colon":: null}', 2, None, True),
(u'Double colon', 15))
self.assertEqual(
scanstring('{"Comma instead of colon", null}', 2, None, True),
(u'Comma instead of colon', 25))
self.assertEqual(
scanstring('["Colon instead of comma": false]', 2, None, True),
(u'Colon instead of comma', 25))
self.assertEqual(
scanstring('["Bad value", truth]', 2, None, True),
(u'Bad value', 12))
for c in map(chr, range(0x00, 0x1f)):
self.assertEqual(
scanstring(c + '"', 0, None, False),
(c, 2))
self.assertRaises(
ValueError,
scanstring, c + '"', 0, None, True)
self.assertRaises(ValueError, scanstring, '', 0, None, True)
self.assertRaises(ValueError, scanstring, 'a', 0, None, True)
self.assertRaises(ValueError, scanstring, '\\', 0, None, True)
self.assertRaises(ValueError, scanstring, '\\u', 0, None, True)
self.assertRaises(ValueError, scanstring, '\\u0', 0, None, True)
self.assertRaises(ValueError, scanstring, '\\u01', 0, None, True)
self.assertRaises(ValueError, scanstring, '\\u012', 0, None, True)
self.assertRaises(ValueError, scanstring, '\\u0123', 0, None, True)
if sys.maxunicode > 65535:
self.assertRaises(ValueError,
scanstring, '\\ud834\\u"', 0, None, True)
self.assertRaises(ValueError,
scanstring, '\\ud834\\x0123"', 0, None, True)
def test_issue3623(self):
self.assertRaises(ValueError, json.decoder.scanstring, "xxx", 1,
"xxx")
self.assertRaises(UnicodeDecodeError,
json.encoder.encode_basestring_ascii, b("xx\xff"))
def test_overflow(self):
# Python 2.5 does not have maxsize, Python 3 does not have maxint
maxsize = getattr(sys, 'maxsize', getattr(sys, 'maxint', None))
assert maxsize is not None
self.assertRaises(OverflowError, json.decoder.scanstring, "xxx",
maxsize + 1)
def test_surrogates(self):
scanstring = json.decoder.scanstring
def assertScan(given, expect, test_utf8=True):
givens = [given]
if not PY3 and test_utf8:
givens.append(given.encode('utf8'))
for given in givens:
(res, count) = scanstring(given, 1, None, True)
self.assertEqual(len(given), count)
self.assertEqual(res, expect)
assertScan(
u'"z\\ud834\\u0079x"',
u'z\ud834yx')
assertScan(
u'"z\\ud834\\udd20x"',
u'z\U0001d120x')
assertScan(
u'"z\\ud834\\ud834\\udd20x"',
u'z\ud834\U0001d120x')
assertScan(
u'"z\\ud834x"',
u'z\ud834x')
assertScan(
u'"z\\udd20x"',
u'z\udd20x')
assertScan(
u'"z\ud834x"',
u'z\ud834x')
# It may look strange to join strings together, but Python is drunk.
# https://gist.github.com/etrepum/5538443
assertScan(
u'"z\\ud834\udd20x12345"',
u''.join([u'z\ud834', u'\udd20x12345']))
assertScan(
u'"z\ud834\\udd20x"',
u''.join([u'z\ud834', u'\udd20x']))
# these have different behavior given UTF8 input, because the surrogate
# pair may be joined (in maxunicode > 65535 builds)
assertScan(
u''.join([u'"z\ud834', u'\udd20x"']),
u''.join([u'z\ud834', u'\udd20x']),
test_utf8=False)
self.assertRaises(ValueError,
scanstring, u'"z\\ud83x"', 1, None, True)
self.assertRaises(ValueError,
scanstring, u'"z\\ud834\\udd2x"', 1, None, True)
================================================
FILE: simplejson/tests/test_separators.py
================================================
import textwrap
from unittest import TestCase
import simplejson as json
class TestSeparators(TestCase):
def test_separators(self):
h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 'i-vhbjkhnth',
{'nifty': 87}, {'field': 'yes', 'morefield': False} ]
expect = textwrap.dedent("""\
[
[
"blorpie"
] ,
[
"whoops"
] ,
[] ,
"d-shtaeou" ,
"d-nthiouh" ,
"i-vhbjkhnth" ,
{
"nifty" : 87
} ,
{
"field" : "yes" ,
"morefield" : false
}
]""")
d1 = json.dumps(h)
d2 = json.dumps(h, indent='  ', sort_keys=True, separators=(' ,', ' : '))
h1 = json.loads(d1)
h2 = json.loads(d2)
self.assertEqual(h1, h)
self.assertEqual(h2, h)
self.assertEqual(d2, expect)
================================================
FILE: simplejson/tests/test_speedups.py
================================================
from __future__ import with_statement
import sys
import unittest
from unittest import TestCase
import simplejson
from simplejson import encoder, decoder, scanner
from simplejson.compat import PY3, long_type, b
def has_speedups():
return encoder.c_make_encoder is not None
def skip_if_speedups_missing(func):
def wrapper(*args, **kwargs):
if not has_speedups():
if hasattr(unittest, 'SkipTest'):
raise unittest.SkipTest("C Extension not available")
else:
sys.stdout.write("C Extension not available")
return
return func(*args, **kwargs)
return wrapper
class BadBool:
def __bool__(self):
1/0
__nonzero__ = __bool__
class TestDecode(TestCase):
@skip_if_speedups_missing
def test_make_scanner(self):
self.assertRaises(AttributeError, scanner.c_make_scanner, 1)
@skip_if_speedups_missing
def test_bad_bool_args(self):
def test(value):
decoder.JSONDecoder(strict=BadBool()).decode(value)
self.assertRaises(ZeroDivisionError, test, '""')
self.assertRaises(ZeroDivisionError, test, '{}')
if not PY3:
self.assertRaises(ZeroDivisionError, test, u'""')
self.assertRaises(ZeroDivisionError, test, u'{}')
class TestEncode(TestCase):
@skip_if_speedups_missing
def test_make_encoder(self):
self.assertRaises(
TypeError,
encoder.c_make_encoder,
None,
("\xCD\x7D\x3D\x4E\x12\x4C\xF9\x79\xD7"
"\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75"),
None
)
@skip_if_speedups_missing
def test_bad_str_encoder(self):
# Issue #31505: There shouldn't be an assertion failure in case
# c_make_encoder() receives a bad encoder() argument.
import decimal
def bad_encoder1(*args):
return None
enc = encoder.c_make_encoder(
None, lambda obj: str(obj),
bad_encoder1, None, ': ', ', ',
False, False, False, {}, False, False, False,
None, None, 'utf-8', False, False, decimal.Decimal, False)
self.assertRaises(TypeError, enc, 'spam', 4)
self.assertRaises(TypeError, enc, {'spam': 42}, 4)
def bad_encoder2(*args):
1/0
enc = encoder.c_make_encoder(
None, lambda obj: str(obj),
bad_encoder2, None, ': ', ', ',
False, False, False, {}, False, False, False,
None, None, 'utf-8', False, False, decimal.Decimal, False)
self.assertRaises(ZeroDivisionError, enc, 'spam', 4)
@skip_if_speedups_missing
def test_bad_bool_args(self):
def test(name):
encoder.JSONEncoder(**{name: BadBool()}).encode({})
self.assertRaises(ZeroDivisionError, test, 'skipkeys')
self.assertRaises(ZeroDivisionError, test, 'ensure_ascii')
self.assertRaises(ZeroDivisionError, test, 'check_circular')
self.assertRaises(ZeroDivisionError, test, 'allow_nan')
self.assertRaises(ZeroDivisionError, test, 'sort_keys')
self.assertRaises(ZeroDivisionError, test, 'use_decimal')
self.assertRaises(ZeroDivisionError, test, 'namedtuple_as_object')
self.assertRaises(ZeroDivisionError, test, 'tuple_as_array')
self.assertRaises(ZeroDivisionError, test, 'bigint_as_string')
self.assertRaises(ZeroDivisionError, test, 'for_json')
self.assertRaises(ZeroDivisionError, test, 'ignore_nan')
self.assertRaises(ZeroDivisionError, test, 'iterable_as_array')
@skip_if_speedups_missing
def test_int_as_string_bitcount_overflow(self):
long_count = long_type(2)**32+31
def test():
encoder.JSONEncoder(int_as_string_bitcount=long_count).encode(0)
self.assertRaises((TypeError, OverflowError), test)
if PY3:
@skip_if_speedups_missing
def test_bad_encoding(self):
with self.assertRaises(UnicodeEncodeError):
encoder.JSONEncoder(encoding='\udcff').encode({b('key'): 123})
================================================
FILE: simplejson/tests/test_str_subclass.py
================================================
from unittest import TestCase
import simplejson
from simplejson.compat import text_type
# Tests for issue demonstrated in https://github.com/simplejson/simplejson/issues/144
class WonkyTextSubclass(text_type):
def __getslice__(self, start, end):
return self.__class__('not what you wanted!')
class TestStrSubclass(TestCase):
def test_dump_load(self):
for s in ['', '"hello"', 'text', u'\u005c']:
self.assertEqual(
s,
simplejson.loads(simplejson.dumps(WonkyTextSubclass(s))))
self.assertEqual(
s,
simplejson.loads(simplejson.dumps(WonkyTextSubclass(s),
ensure_ascii=False)))
================================================
FILE: simplejson/tests/test_subclass.py
================================================
from unittest import TestCase
import simplejson as json
from decimal import Decimal
class AlternateInt(int):
def __repr__(self):
return 'invalid json'
__str__ = __repr__
class AlternateFloat(float):
def __repr__(self):
return 'invalid json'
__str__ = __repr__
# class AlternateDecimal(Decimal):
# def __repr__(self):
# return 'invalid json'
class TestSubclass(TestCase):
def test_int(self):
self.assertEqual(json.dumps(AlternateInt(1)), '1')
self.assertEqual(json.dumps(AlternateInt(-1)), '-1')
self.assertEqual(json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1})
def test_float(self):
self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0')
self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0')
self.assertEqual(json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1})
# NOTE: Decimal subclasses are not supported as-is
# def test_decimal(self):
# self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0')
# self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0')
================================================
FILE: simplejson/tests/test_tool.py
================================================
from __future__ import with_statement
import os
import sys
import textwrap
import unittest
import subprocess
import tempfile
try:
# Python 3.x
from test.support import strip_python_stderr
except ImportError:
# Python 2.6+
try:
from test.test_support import strip_python_stderr
except ImportError:
# Python 2.5
import re
def strip_python_stderr(stderr):
return re.sub(
r"\[\d+ refs\]\r?\n?$".encode(),
"".encode(),
stderr).strip()
def open_temp_file():
if sys.version_info >= (2, 6):
file = tempfile.NamedTemporaryFile(delete=False)
filename = file.name
else:
fd, filename = tempfile.mkstemp()
file = os.fdopen(fd, 'w+b')
return file, filename
class TestTool(unittest.TestCase):
data = """
[["blorpie"],[ "whoops" ] , [
],\t"d-shtaeou",\r"d-nthiouh",
"i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field"
:"yes"} ]
"""
expect = textwrap.dedent("""\
[
[
"blorpie"
],
[
"whoops"
],
[],
"d-shtaeou",
"d-nthiouh",
"i-vhbjkhnth",
{
"nifty": 87
},
{
"field": "yes",
"morefield": false
}
]
""")
def runTool(self, args=None, data=None):
argv = [sys.executable, '-m', 'simplejson.tool']
if args:
argv.extend(args)
proc = subprocess.Popen(argv,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE)
out, err = proc.communicate(data)
self.assertEqual(strip_python_stderr(err), ''.encode())
self.assertEqual(proc.returncode, 0)
return out.decode('utf8').splitlines()
def test_stdin_stdout(self):
self.assertEqual(
self.runTool(data=self.data.encode()),
self.expect.splitlines())
def test_infile_stdout(self):
infile, infile_name = open_temp_file()
try:
infile.write(self.data.encode())
infile.close()
self.assertEqual(
self.runTool(args=[infile_name]),
self.expect.splitlines())
finally:
os.unlink(infile_name)
def test_infile_outfile(self):
infile, infile_name = open_temp_file()
try:
infile.write(self.data.encode())
infile.close()
# outfile will get overwritten by tool, so the delete
# may not work on some platforms. Do it manually.
outfile, outfile_name = open_temp_file()
try:
outfile.close()
self.assertEqual(
self.runTool(args=[infile_name, outfile_name]),
[])
with open(outfile_name, 'rb') as f:
self.assertEqual(
f.read().decode('utf8').splitlines(),
self.expect.splitlines()
)
finally:
os.unlink(outfile_name)
finally:
os.unlink(infile_name)
================================================
FILE: simplejson/tests/test_tuple.py
================================================
import unittest
from simplejson.compat import StringIO
import simplejson as json
class TestTuples(unittest.TestCase):
def test_tuple_array_dumps(self):
t = (1, 2, 3)
expect = json.dumps(list(t))
# Default is True
self.assertEqual(expect, json.dumps(t))
self.assertEqual(expect, json.dumps(t, tuple_as_array=True))
self.assertRaises(TypeError, json.dumps, t, tuple_as_array=False)
# Ensure that the "default" does not get called
self.assertEqual(expect, json.dumps(t, default=repr))
self.assertEqual(expect, json.dumps(t, tuple_as_array=True,
default=repr))
# Ensure that the "default" gets called
self.assertEqual(
json.dumps(repr(t)),
json.dumps(t, tuple_as_array=False, default=repr))
def test_tuple_array_dump(self):
t = (1, 2, 3)
expect = json.dumps(list(t))
# Default is True
sio = StringIO()
json.dump(t, sio)
self.assertEqual(expect, sio.getvalue())
sio = StringIO()
json.dump(t, sio, tuple_as_array=True)
self.assertEqual(expect, sio.getvalue())
self.assertRaises(TypeError, json.dump, t, StringIO(),
tuple_as_array=False)
# Ensure that the "default" does not get called
sio = StringIO()
json.dump(t, sio, default=repr)
self.assertEqual(expect, sio.getvalue())
sio = StringIO()
json.dump(t, sio, tuple_as_array=True, default=repr)
self.assertEqual(expect, sio.getvalue())
# Ensure that the "default" gets called
sio = StringIO()
json.dump(t, sio, tuple_as_array=False, default=repr)
self.assertEqual(
json.dumps(repr(t)),
sio.getvalue())
================================================
FILE: simplejson/tests/test_unicode.py
================================================
import sys
import codecs
from unittest import TestCase
import simplejson as json
from simplejson.compat import unichr, text_type, b, BytesIO
class TestUnicode(TestCase):
def test_encoding1(self):
encoder = json.JSONEncoder(encoding='utf-8')
u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
s = u.encode('utf-8')
ju = encoder.encode(u)
js = encoder.encode(s)
self.assertEqual(ju, js)
def test_encoding2(self):
u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
s = u.encode('utf-8')
ju = json.dumps(u, encoding='utf-8')
js = json.dumps(s, encoding='utf-8')
self.assertEqual(ju, js)
def test_encoding3(self):
u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
j = json.dumps(u)
self.assertEqual(j, '"\\u03b1\\u03a9"')
def test_encoding4(self):
u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
j = json.dumps([u])
self.assertEqual(j, '["\\u03b1\\u03a9"]')
def test_encoding5(self):
u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
j = json.dumps(u, ensure_ascii=False)
self.assertEqual(j, u'"' + u + u'"')
def test_encoding6(self):
u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
j = json.dumps([u], ensure_ascii=False)
self.assertEqual(j, u'["' + u + u'"]')
def test_big_unicode_encode(self):
u = u'\U0001d120'
self.assertEqual(json.dumps(u), '"\\ud834\\udd20"')
self.assertEqual(json.dumps(u, ensure_ascii=False), u'"\U0001d120"')
def test_big_unicode_decode(self):
u = u'z\U0001d120x'
self.assertEqual(json.loads('"' + u + '"'), u)
self.assertEqual(json.loads('"z\\ud834\\udd20x"'), u)
def test_unicode_decode(self):
for i in range(0, 0xd7ff):
u = unichr(i)
#s = '"\\u{0:04x}"'.format(i)
s = '"\\u%04x"' % (i,)
self.assertEqual(json.loads(s), u)
def test_object_pairs_hook_with_unicode(self):
s = u'{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}'
p = [(u"xkd", 1), (u"kcw", 2), (u"art", 3), (u"hxm", 4),
(u"qrt", 5), (u"pad", 6), (u"hoy", 7)]
self.assertEqual(json.loads(s), eval(s))
self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p)
od = json.loads(s, object_pairs_hook=json.OrderedDict)
self.assertEqual(od, json.OrderedDict(p))
self.assertEqual(type(od), json.OrderedDict)
# the object_pairs_hook takes priority over the object_hook
self.assertEqual(json.loads(s,
object_pairs_hook=json.OrderedDict,
object_hook=lambda x: None),
json.OrderedDict(p))
def test_default_encoding(self):
self.assertEqual(json.loads(u'{"a": "\xe9"}'.encode('utf-8')),
{'a': u'\xe9'})
def test_unicode_preservation(self):
self.assertEqual(type(json.loads(u'""')), text_type)
self.assertEqual(type(json.loads(u'"a"')), text_type)
self.assertEqual(type(json.loads(u'["a"]')[0]), text_type)
def test_ensure_ascii_false_returns_unicode(self):
# http://code.google.com/p/simplejson/issues/detail?id=48
self.assertEqual(type(json.dumps([], ensure_ascii=False)), text_type)
self.assertEqual(type(json.dumps(0, ensure_ascii=False)), text_type)
self.assertEqual(type(json.dumps({}, ensure_ascii=False)), text_type)
self.assertEqual(type(json.dumps("", ensure_ascii=False)), text_type)
def test_ensure_ascii_false_bytestring_encoding(self):
# http://code.google.com/p/simplejson/issues/detail?id=48
doc1 = {u'quux': b('Arr\xc3\xaat sur images')}
doc2 = {u'quux': u'Arr\xeat sur images'}
doc_ascii = '{"quux": "Arr\\u00eat sur images"}'
doc_unicode = u'{"quux": "Arr\xeat sur images"}'
self.assertEqual(json.dumps(doc1), doc_ascii)
self.assertEqual(json.dumps(doc2), doc_ascii)
self.assertEqual(json.dumps(doc1, ensure_ascii=False), doc_unicode)
self.assertEqual(json.dumps(doc2, ensure_ascii=False), doc_unicode)
def test_ensure_ascii_linebreak_encoding(self):
# http://timelessrepo.com/json-isnt-a-javascript-subset
s1 = u'\u2029\u2028'
s2 = s1.encode('utf8')
expect = '"\\u2029\\u2028"'
expect_non_ascii = u'"\u2029\u2028"'
self.assertEqual(json.dumps(s1), expect)
self.assertEqual(json.dumps(s2), expect)
self.assertEqual(json.dumps(s1, ensure_ascii=False), expect_non_ascii)
self.assertEqual(json.dumps(s2, ensure_ascii=False), expect_non_ascii)
def test_invalid_escape_sequences(self):
# incomplete escape sequence
self.assertRaises(json.JSONDecodeError, json.loads, '"\\u')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1234')
# invalid escape sequence
self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123x"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12x4"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1x34"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ux234"')
if sys.maxunicode > 65535:
# invalid escape sequence for low surrogate
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000x"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00x0"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0x00"')
self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\ux000"')
def test_ensure_ascii_still_works(self):
# in the ascii range, ensure that everything is the same
for c in map(unichr, range(0, 127)):
self.assertEqual(
json.dumps(c, ensure_ascii=False),
json.dumps(c))
snowman = u'\N{SNOWMAN}'
self.assertEqual(
json.dumps(c + snowman, ensure_ascii=False),
'"' + c + snowman + '"')
def test_strip_bom(self):
content = u"\u3053\u3093\u306b\u3061\u308f"
json_doc = codecs.BOM_UTF8 + b(json.dumps(content))
self.assertEqual(json.load(BytesIO(json_doc)), content)
for doc in json_doc, json_doc.decode('utf8'):
self.assertEqual(json.loads(doc), content)
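# The core contract these unicode tests establish, condensed into a
# standalone sketch: ensure_ascii=True (the default) escapes non-ASCII
# characters, while ensure_ascii=False preserves them in unicode output.
import simplejson

u = u'\N{GREEK SMALL LETTER ALPHA}'
print(simplejson.dumps(u))                      # "\u03b1"
print(simplejson.dumps(u, ensure_ascii=False))  # the character itself, unescaped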
================================================
FILE: simplejson/tool.py
================================================
r"""Command-line tool to validate and pretty-print JSON
Usage::
$ echo '{"json":"obj"}' | python -m simplejson.tool
{
"json": "obj"
}
$ echo '{ 1.2:3.4}' | python -m simplejson.tool
Expecting property name: line 1 column 2 (char 2)
"""
from __future__ import with_statement
import sys
import simplejson as json
def main():
if len(sys.argv) == 1:
infile = sys.stdin
outfile = sys.stdout
elif len(sys.argv) == 2:
infile = open(sys.argv[1], 'r')
outfile = sys.stdout
elif len(sys.argv) == 3:
infile = open(sys.argv[1], 'r')
outfile = open(sys.argv[2], 'w')
else:
raise SystemExit(sys.argv[0] + " [infile [outfile]]")
with infile:
try:
obj = json.load(infile,
object_pairs_hook=json.OrderedDict,
use_decimal=True)
except ValueError:
raise SystemExit(sys.exc_info()[1])
with outfile:
json.dump(obj, outfile, sort_keys=True, indent='    ', use_decimal=True)
outfile.write('\n')
if __name__ == '__main__':
main()
================================================
FILE: train.py
================================================
import time
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from utils import *
from network import *
from configs import *
import math
import argparse
import models.resnet as resnet
import models.densenet as densenet
from models import create_model
parser = argparse.ArgumentParser(description='Training code for GFNet')
parser.add_argument('--data_url', default='./data', type=str,
help='path to the dataset (ImageNet)')
parser.add_argument('--work_dirs', default='./output', type=str,
help='path to save log and checkpoints')
parser.add_argument('--train_stage', default=-1, type=int,
help='select training stage, see our paper for details \
stage-1 : warm-up \
stage-2 : learn to select patches with RL \
stage-3 : finetune CNNs')
parser.add_argument('--model_arch', default='', type=str,
help='architecture of the model to be trained \
resnet50 / resnet101 / \
densenet121 / densenet169 / densenet201 / \
regnety_600m / regnety_800m / regnety_1.6g / \
mobilenetv3_large_100 / mobilenetv3_large_125 / \
efficientnet_b2 / efficientnet_b3')
parser.add_argument('--patch_size', default=96, type=int,
help='size of local patches (we recommend 96 / 128 / 144)')
parser.add_argument('--T', default=4, type=int,
help='maximum length of the sequence of Glance + Focus')
parser.add_argument('--print_freq', default=100, type=int,
help='the frequency of printing log')
parser.add_argument('--model_prime_path', default='', type=str,
help='path to the pre-trained model of Global Encoder (for training stage-1)')
parser.add_argument('--model_path', default='', type=str,
help='path to the pre-trained model of Local Encoder (for training stage-1)')
parser.add_argument('--checkpoint_path', default='', type=str,
help='path to the stage-2/3 checkpoint (for training stage-2/3)')
parser.add_argument('--resume', default='', type=str,
help='path to the checkpoint for resuming')
args = parser.parse_args()
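# Illustrative invocations of the three training stages using the flags
# defined above (upper-case paths are placeholders, not real files):
#
#   python train.py --train_stage 1 --model_arch resnet50 --patch_size 96 --T 4 \
#       --data_url ./data/ --work_dirs ./output \
#       --model_prime_path PATH_TO_GLOBAL_ENCODER --model_path PATH_TO_LOCAL_ENCODER
#
#   python train.py --train_stage 2 --model_arch resnet50 --patch_size 96 --T 4 \
#       --data_url ./data/ --work_dirs ./output --checkpoint_path PATH_TO_STAGE1_CHECKPOINT
#
#   python train.py --train_stage 3 --model_arch resnet50 --patch_size 96 --T 4 \
#       --data_url ./data/ --work_dirs ./output --checkpoint_path PATH_TO_STAGE2_CHECKPOINT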
def main():
if not os.path.isdir(args.work_dirs):
mkdir_p(args.work_dirs)
record_path = args.work_dirs + '/GF-' + str(args.model_arch) \
+ '_patch-size-' + str(args.patch_size) \
+ '_T' + str(args.T) \
+ '_train-stage' + str(args.train_stage)
if not os.path.isdir(record_path):
mkdir_p(record_path)
record_file = record_path + '/record.txt'
# *create model* #
model_configuration = model_configurations[args.model_arch]
if 'resnet' in args.model_arch:
model_arch = 'resnet'
model = resnet.resnet50(pretrained=False)
model_prime = resnet.resnet50(pretrained=False)
elif 'densenet' in args.model_arch:
model_arch = 'densenet'
model = eval('densenet.' + args.model_arch)(pretrained=False)
model_prime = eval('densenet.' + args.model_arch)(pretrained=False)
elif 'efficientnet' in args.model_arch:
model_arch = 'efficientnet'
model = create_model(args.model_arch, pretrained=False, num_classes=1000,
drop_rate=0.3, drop_connect_rate=0.2)
model_prime = create_model(args.model_arch, pretrained=False, num_classes=1000,
drop_rate=0.3, drop_connect_rate=0.2)
elif 'mobilenetv3' in args.model_arch:
model_arch = 'mobilenetv3'
model = create_model(args.model_arch, pretrained=False, num_classes=1000,
drop_rate=0.2, drop_connect_rate=0.2)
model_prime = create_model(args.model_arch, pretrained=False, num_classes=1000,
drop_rate=0.2, drop_connect_rate=0.2)
elif 'regnet' in args.model_arch:
model_arch = 'regnet'
import pycls.core.model_builder as model_builder
from pycls.core.config import cfg
cfg.merge_from_file(model_configuration['cfg_file'])
cfg.freeze()
model = model_builder.build_model()
model_prime = model_builder.build_model()
fc = Full_layer(model_configuration['feature_num'],
model_configuration['fc_hidden_dim'],
model_configuration['fc_rnn'])
if args.train_stage == 1:
model.load_state_dict(torch.load(args.model_path))
model_prime.load_state_dict(torch.load(args.model_prime_path))
else:
checkpoint = torch.load(args.checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
model_prime.load_state_dict(checkpoint['model_prime_state_dict'])
fc.load_state_dict(checkpoint['fc'])
train_configuration = train_configurations[model_arch]
if args.train_stage != 2:
if train_configuration['train_model_prime']:
optimizer = torch.optim.SGD([{'params': model.parameters()},
{'params': model_prime.parameters()},
{'params': fc.parameters()}],
lr=0, # specify in adjust_learning_rate()
momentum=train_configuration['momentum'],
nesterov=train_configuration['Nesterov'],
weight_decay=train_configuration['weight_decay'])
else:
optimizer = torch.optim.SGD([{'params': model.parameters()},
{'params': fc.parameters()}],
lr=0, # specify in adjust_learning_rate()
momentum=train_configuration['momentum'],
nesterov=train_configuration['Nesterov'],
weight_decay=train_configuration['weight_decay'])
training_epoch_num = train_configuration['epoch_num']
else:
optimizer = None
training_epoch_num = 15
criterion = nn.CrossEntropyLoss().cuda()
model = nn.DataParallel(model.cuda())
model_prime = nn.DataParallel(model_prime.cuda())
fc = fc.cuda()
traindir = os.path.join(args.data_url, 'train')
valdir = os.path.join(args.data_url, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_set = datasets.ImageFolder(traindir, transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
train_set_index = torch.randperm(len(train_set))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, num_workers=32, pin_memory=False,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
train_set_index[:]))
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize, ])),
batch_size=train_configuration['batch_size'], shuffle=False, num_workers=32, pin_memory=False)
if args.train_stage != 1:
state_dim = model_configuration['feature_map_channels'] * math.ceil(args.patch_size / 32) * math.ceil(args.patch_size / 32)
ppo = PPO(model_configuration['feature_map_channels'], state_dim,
model_configuration['policy_hidden_dim'], model_configuration['policy_conv'])
if args.train_stage == 3:
ppo.policy.load_state_dict(checkpoint['policy'])
ppo.policy_old.load_state_dict(checkpoint['policy'])
else:
ppo = None
memory = Memory()
if args.resume:
resume_ckp = torch.load(args.resume)
start_epoch = resume_ckp['epoch']
print('resume from epoch: {}'.format(start_epoch))
model.module.load_state_dict(resume_ckp['model_state_dict'])
model_prime.module.load_state_dict(resume_ckp['model_prime_state_dict'])
fc.load_state_dict(resume_ckp['fc'])
if optimizer:
optimizer.load_state_dict(resume_ckp['optimizer'])
if ppo:
ppo.policy.load_state_dict(resume_ckp['policy'])
ppo.policy_old.load_state_dict(resume_ckp['policy'])
ppo.optimizer.load_state_dict(resume_ckp['ppo_optimizer'])
best_acc = resume_ckp['best_acc']
else:
start_epoch = 0
best_acc = 0
for epoch in range(start_epoch, training_epoch_num):
if args.train_stage != 2:
print('Training Stage: {}, lr:'.format(args.train_stage))
adjust_learning_rate(optimizer, train_configuration,
epoch, training_epoch_num, args)
else:
print('Training Stage: {}, train ppo only'.format(args.train_stage))
train(model_prime, model, fc, memory, ppo, optimizer, train_loader, criterion,
args.print_freq, epoch, train_configuration['batch_size'], record_file, train_configuration, args)
acc = validate(model_prime, model, fc, memory, ppo, optimizer, val_loader, criterion,
args.print_freq, epoch, train_configuration['batch_size'], record_file, train_configuration, args)
if acc > best_acc:
best_acc = acc
is_best = True
else:
is_best = False
save_checkpoint({
'epoch': epoch + 1,
'model_state_dict': model.module.state_dict(),
'model_prime_state_dict': model_prime.module.state_dict(),
'fc': fc.state_dict(),
'acc': acc,
'best_acc': best_acc,
'optimizer': optimizer.state_dict() if optimizer else None,
'ppo_optimizer': ppo.optimizer.state_dict() if ppo else None,
'policy': ppo.policy.state_dict() if ppo else None,
}, is_best, checkpoint=record_path)
def train(model_prime, model, fc, memory, ppo, optimizer, train_loader, criterion,
print_freq, epoch, batch_size, record_file, train_configuration, args):
batch_time = AverageMeter()
losses = [AverageMeter() for _ in range(args.T)]
top1 = [AverageMeter() for _ in range(args.T)]
reward_list = [AverageMeter() for _ in range(args.T - 1)]
train_batches_num = len(train_loader)
if args.train_stage == 2:
model_prime.eval()
model.eval()
fc.eval()
else:
if train_configuration['train_model_prime']:
model_prime.train()
else:
model_prime.eval()
model.train()
fc.train()
if 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch:
dsn_fc_prime = model_prime.module.fc
dsn_fc = model.module.fc
else:
dsn_fc_prime = model_prime.module.classifier
dsn_fc = model.module.classifier
fd = open(record_file, 'a+')
end = time.time()
for i, (x, target) in enumerate(train_loader):
loss_cla = []
loss_list_dsn = []
target_var = target.cuda()
input_var = x.cuda()
input_prime = get_prime(input_var, args.patch_size)
if train_configuration['train_model_prime'] and args.train_stage != 2:
output, state = model_prime(input_prime)
assert 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch
output_dsn = dsn_fc_prime(output)
output = fc(output, restart=True)
else:
with torch.no_grad():
output, state = model_prime(input_prime)
if 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch:
output_dsn = dsn_fc_prime(output)
output = fc(output, restart=True)
else:
_ = fc(output, restart=True)
output = model_prime.module.classifier(output)
output_dsn = output
loss_prime = criterion(output, target_var)
loss_cla.append(loss_prime)
loss_dsn = criterion(output_dsn, target_var)
loss_list_dsn.append(loss_dsn)
losses[0].update(loss_prime.data.item(), x.size(0))
acc = accuracy(output, target_var, topk=(1,))
top1[0].update(acc.sum(0).mul_(100.0 / batch_size).data.item(), x.size(0))
confidence_last = torch.gather(F.softmax(output.detach(), 1), dim=1, index=target_var.view(-1, 1)).view(1, -1)
for patch_step in range(1, args.T):
if args.train_stage == 1:
action = torch.rand(x.size(0), 2).cuda()
else:
if patch_step == 1:
action = ppo.select_action(state.to(0), memory, restart_batch=True)
else:
action = ppo.select_action(state.to(0), memory)
patches = get_patch(input_var, action, args.patch_size)
if args.train_stage != 2:
output, state = model(patches)
output_dsn = dsn_fc(output)
output = fc(output, restart=False)
else:
with torch.no_grad():
output, state = model(patches)
output_dsn = dsn_fc(output)
output = fc(output, restart=False)
loss = criterion(output, target_var)
loss_cla.append(loss)
losses[patch_step].update(loss.data.item(), x.size(0))
loss_dsn = criterion(output_dsn, target_var)
loss_list_dsn.append(loss_dsn)
acc = accuracy(output, target_var, topk=(1,))
            top1[patch_step].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0))  # normalize by the actual batch size
confidence = torch.gather(F.softmax(output.detach(), 1), dim=1, index=target_var.view(-1, 1)).view(1, -1)
reward = confidence - confidence_last
confidence_last = confidence
reward_list[patch_step - 1].update(reward.data.mean(), x.size(0))
memory.rewards.append(reward)
loss = (sum(loss_cla) + train_configuration['dsn_ratio'] * sum(loss_list_dsn)) / args.T
if args.train_stage != 2:
optimizer.zero_grad()
loss.backward()
optimizer.step()
else:
ppo.update(memory)
memory.clear_memory()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0 or i == train_batches_num - 1:
string = ('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.value:.3f} ({batch_time.ave:.3f})\t'
'Loss {loss.value:.4f} ({loss.ave:.4f})\t'.format(
epoch, i + 1, train_batches_num, batch_time=batch_time, loss=losses[-1]))
print(string)
fd.write(string + '\n')
_acc = [acc.ave for acc in top1]
print('accuracy of each step:')
print(_acc)
fd.write('accuracy of each step:\n')
fd.write(str(_acc) + '\n')
_reward = [reward.ave for reward in reward_list]
print('reward of each step:')
print(_reward)
fd.write('reward of each step:\n')
fd.write(str(_reward) + '\n')
fd.close()
def validate(model_prime, model, fc, memory, ppo, _, val_loader, criterion,
print_freq, epoch, batch_size, record_file, __, args):
batch_time = AverageMeter()
losses = [AverageMeter() for _ in range(args.T)]
top1 = [AverageMeter() for _ in range(args.T)]
reward_list = [AverageMeter() for _ in range(args.T - 1)]
train_batches_num = len(val_loader)
model_prime.eval()
model.eval()
fc.eval()
if 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch:
dsn_fc_prime = model_prime.module.fc
dsn_fc = model.module.fc
else:
dsn_fc_prime = model_prime.module.classifier
dsn_fc = model.module.classifier
fd = open(record_file, 'a+')
end = time.time()
with torch.no_grad():
for i, (x, target) in enumerate(val_loader):
loss_cla = []
loss_list_dsn = []
target_var = target.cuda()
input_var = x.cuda()
input_prime = get_prime(input_var, args.patch_size)
output, state = model_prime(input_prime)
if 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch:
output_dsn = dsn_fc_prime(output)
output = fc(output, restart=True)
else:
_ = fc(output, restart=True)
output = model_prime.module.classifier(output)
output_dsn = output
loss_prime = criterion(output, target_var)
loss_cla.append(loss_prime)
loss_dsn = criterion(output_dsn, target_var)
loss_list_dsn.append(loss_dsn)
losses[0].update(loss_prime.data.item(), x.size(0))
acc = accuracy(output, target_var, topk=(1,))
            top1[0].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0))  # normalize by the actual batch size
confidence_last = torch.gather(F.softmax(output.detach(), 1), dim=1, index=target_var.view(-1, 1)).view(1, -1)
for patch_step in range(1, args.T):
if args.train_stage == 1:
action = torch.rand(x.size(0), 2).cuda()
else:
if patch_step == 1:
action = ppo.select_action(state.to(0), memory, restart_batch=True, training=False)
else:
action = ppo.select_action(state.to(0), memory, training=False)
patches = get_patch(input_var, action, args.patch_size)
output, state = model(patches)
output_dsn = dsn_fc(output)
output = fc(output, restart=False)
loss = criterion(output, target_var)
loss_cla.append(loss)
losses[patch_step].update(loss.data.item(), x.size(0))
loss_dsn = criterion(output_dsn, target_var)
loss_list_dsn.append(loss_dsn)
acc = accuracy(output, target_var, topk=(1,))
                top1[patch_step].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0))  # normalize by the actual batch size
confidence = torch.gather(F.softmax(output.detach(), 1), dim=1, index=target_var.view(-1, 1)).view(1, -1)
reward = confidence - confidence_last
confidence_last = confidence
reward_list[patch_step - 1].update(reward.data.mean(), x.size(0))
memory.rewards.append(reward)
memory.clear_memory()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0 or i == train_batches_num - 1:
string = ('Val: [{0}][{1}/{2}]\t'
'Time {batch_time.value:.3f} ({batch_time.ave:.3f})\t'
'Loss {loss.value:.4f} ({loss.ave:.4f})\t'.format(
epoch, i + 1, train_batches_num, batch_time=batch_time, loss=losses[-1]))
print(string)
fd.write(string + '\n')
_acc = [acc.ave for acc in top1]
print('accuracy of each step:')
print(_acc)
fd.write('accuracy of each step:\n')
fd.write(str(_acc) + '\n')
_reward = [reward.ave for reward in reward_list]
print('reward of each step:')
print(_reward)
fd.write('reward of each step:\n')
fd.write(str(_reward) + '\n')
fd.close()
return top1[args.T - 1].ave
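# Illustrative sketch (added for exposition; not part of the original file):
# the per-step RL reward used above is the increase in softmax confidence on
# the ground-truth class between consecutive glimpses. The hypothetical helper
# below shows that computation in isolation on fake logits.
def _demo_confidence_reward():
    import torch
    import torch.nn.functional as F
    logits_prev = torch.randn(4, 10)     # logits after step t-1 (4 samples, 10 classes)
    logits_curr = torch.randn(4, 10)     # logits after step t
    target = torch.randint(0, 10, (4,))  # ground-truth labels
    conf_prev = torch.gather(F.softmax(logits_prev, 1), 1, target.view(-1, 1)).view(1, -1)
    conf_curr = torch.gather(F.softmax(logits_curr, 1), 1, target.view(-1, 1)).view(1, -1)
    return conf_curr - conf_prev         # shape (1, 4): confidence gain per sample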
if __name__ == '__main__':
main()
================================================
FILE: utils.py
================================================
import os
import errno
import math
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
def mkdir_p(path):
    '''Make the directory (and any missing parent directories) if it does not exist.'''
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.value = 0
self.ave = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.value = val
self.sum += val * n
self.count += n
self.ave = self.sum / self.count
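# Illustrative sketch (not part of the original file): AverageMeter keeps a
# running, count-weighted average, e.g. of a per-batch loss.
def _demo_average_meter():
    meter = AverageMeter()
    meter.update(0.8, n=32)  # batch of 32 with mean loss 0.8
    meter.update(0.4, n=16)  # smaller batch of 16 with mean loss 0.4
    assert abs(meter.ave - (0.8 * 32 + 0.4 * 16) / 48) < 1e-9
    return meter.ave         # weighted average over all 48 samples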
def accuracy(output, target, topk=(1,)):
    """Return a 0/1 float vector of per-sample top-1 correctness (only k=1 is used, despite `topk`)."""
maxk = max(topk)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
correct_k = correct[:1].view(-1).float()
return correct_k
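# Illustrative sketch (not part of the original file): accuracy() returns a
# per-sample 0/1 vector of top-1 correctness, which callers sum and scale.
def _demo_accuracy():
    output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # predicts class 1, class 0
    target = torch.tensor([1, 1])                    # true labels
    correct = accuracy(output, target)               # -> tensor([1., 0.])
    return correct.sum(0).mul_(100.0 / target.size(0))  # top-1 accuracy: 50.0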
def get_prime(images, patch_size, interpolation='bicubic'):
"""Get down-sampled original image"""
prime = F.interpolate(images, size=[patch_size, patch_size], mode=interpolation, align_corners=True)
return prime
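# Illustrative sketch (not part of the original file): get_prime produces the
# low-resolution "glance" input by resizing the full image to the patch size.
def _demo_get_prime():
    images = torch.randn(2, 3, 224, 224)      # a fake batch of two RGB images
    prime = get_prime(images, patch_size=96)  # bicubic resize to 96x96
    assert prime.shape == (2, 3, 96, 96)
    return prime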
def get_patch(images, action_sequence, patch_size):
"""Get small patch of the original image"""
batch_size = images.size(0)
image_size = images.size(2)
patch_coordinate = torch.floor(action_sequence * (image_size - patch_size)).int()
patches = []
for i in range(batch_size):
per_patch = images[i, :,
(patch_coordinate[i, 0].item()): ((patch_coordinate[i, 0] + patch_size).item()),
(patch_coordinate[i, 1].item()): ((patch_coordinate[i, 1] + patch_size).item())]
patches.append(per_patch.view(1, per_patch.size(0), per_patch.size(1), per_patch.size(2)))
return torch.cat(patches, 0)
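# Illustrative sketch (not part of the original file): get_patch crops one
# patch per image, with the top-left corner given by a normalized action in
# [0, 1) for each spatial dimension.
def _demo_get_patch():
    images = torch.randn(2, 3, 224, 224)
    actions = torch.rand(2, 2)  # normalized (row, col) per image
    patches = get_patch(images, actions, patch_size=96)
    assert patches.shape == (2, 3, 96, 96)
    return patches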
def adjust_learning_rate(optimizer, train_configuration, epoch, training_epoch_num, args):
"""Sets the learning rate"""
backbone_lr = 0.5 * train_configuration['backbone_lr'] * \
(1 + math.cos(math.pi * epoch / training_epoch_num))
if args.train_stage == 1:
fc_lr = 0.5 * train_configuration['fc_stage_1_lr'] * \
(1 + math.cos(math.pi * epoch / training_epoch_num))
elif args.train_stage == 3:
fc_lr = 0.5 * train_configuration['fc_stage_3_lr'] * \
(1 + math.cos(math.pi * epoch / training_epoch_num))
if train_configuration['train_model_prime']:
optimizer.param_groups[0]['lr'] = backbone_lr
optimizer.param_groups[1]['lr'] = backbone_lr
optimizer.param_groups[2]['lr'] = fc_lr
else:
optimizer.param_groups[0]['lr'] = backbone_lr
optimizer.param_groups[1]['lr'] = fc_lr
for param_group in optimizer.param_groups:
print(param_group['lr'])
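# Illustrative sketch (not part of the original file): the schedule above is
# cosine annealing, lr(e) = 0.5 * base_lr * (1 + cos(pi * e / E)), decaying
# from base_lr at epoch 0 towards 0 at epoch E.
def _demo_cosine_schedule(base_lr=0.1, total_epochs=60):
    lrs = [0.5 * base_lr * (1 + math.cos(math.pi * e / total_epochs))
           for e in range(total_epochs)]
    return lrs  # lrs[0] == base_lr; values decay monotonically towards 0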
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
filepath = checkpoint + '/' + filename
torch.save(state, filepath)
if is_best:
shutil.copyfile(filepath, checkpoint + '/model_best.pth.tar')
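# Illustrative sketch (not part of the original file): save_checkpoint writes
# the state dict to <checkpoint>/checkpoint.pth.tar and, when is_best is True,
# also copies it to <checkpoint>/model_best.pth.tar.
def _demo_save_checkpoint():
    import tempfile
    ckpt_dir = tempfile.mkdtemp()
    save_checkpoint({'epoch': 1, 'best_acc': 0.0}, is_best=True, checkpoint=ckpt_dir)
    assert os.path.isfile(ckpt_dir + '/model_best.pth.tar')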
================================================
FILE: yacs/__init__.py
================================================
================================================
FILE: yacs/config.py
================================================
# Copyright (c) 2018-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""YACS -- Yet Another Configuration System is designed to be a simple
configuration management system for academic and industrial research
projects.
See README.md for usage and examples.
"""
import copy
import io
import logging
import os
import sys
from ast import literal_eval
import yaml
# Flag for py2 and py3 compatibility to use when separate code paths are necessary
# When _PY2 is False, we assume Python 3 is in use
_PY2 = sys.version_info.major == 2
# Filename extensions for loading configs from files
_YAML_EXTS = {"", ".yaml", ".yml"}
_PY_EXTS = {".py"}
# py2 vs py3 compatibility for checking the file object type
if _PY2:
_FILE_TYPES = (file, io.IOBase)
else:
_FILE_TYPES = (io.IOBase,)
# CfgNodes can only contain a limited set of valid types
_VALID_TYPES = {tuple, list, str, int, float, bool}
# py2 allow for str and unicode
if _PY2:
_VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
# Utilities for importing modules from file paths
if _PY2:
# imp is available in both py2 and py3 for now, but is deprecated in py3
import imp
else:
import importlib.util
logger = logging.getLogger(__name__)
class CfgNode(dict):
"""
CfgNode represents an internal node in the configuration tree. It's a simple
dict-like container that allows for attribute-based access to keys.
"""
IMMUTABLE = "__immutable__"
DEPRECATED_KEYS = "__deprecated_keys__"
RENAMED_KEYS = "__renamed_keys__"
NEW_ALLOWED = "__new_allowed__"
def __init__(self, init_dict=None, key_list=None, new_allowed=False):
"""
Args:
            init_dict (dict): the possibly-nested dictionary to initialize the CfgNode.
key_list (list[str]): a list of names which index this CfgNode from the root.
Currently only used for logging purposes.
new_allowed (bool): whether adding new key is allowed when merging with
other configs.
"""
# Recursively convert nested dictionaries in init_dict into CfgNodes
init_dict = {} if init_dict is None else init_dict
key_list = [] if key_list is None else key_list
init_dict = self._create_config_tree_from_dict(init_dict, key_list)
super(CfgNode, self).__init__(init_dict)
# Manage if the CfgNode is frozen or not
self.__dict__[CfgNode.IMMUTABLE] = False
# Deprecated options
# If an option is removed from the code and you don't want to break existing
# yaml configs, you can add the full config key as a string to the set below.
self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
# Renamed options
# If you rename a config option, record the mapping from the old name to the new
# name in the dictionary below. Optionally, if the type also changed, you can
# make the value a tuple that specifies first the renamed key and then
# instructions for how to edit the config file.
self.__dict__[CfgNode.RENAMED_KEYS] = {
# 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow
# 'EXAMPLE.OLD.KEY': ( # A more complex example to follow
# 'EXAMPLE.NEW.KEY',
# "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
# + "'foo:bar' -> ('foo', 'bar')"
# ),
}
# Allow new attributes after initialisation
self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed
@classmethod
def _create_config_tree_from_dict(cls, dic, key_list):
"""
Create a configuration tree using the given dict.
Any dict-like objects inside dict will be treated as a new CfgNode.
Args:
dic (dict):
key_list (list[str]): a list of names which index this CfgNode from the root.
Currently only used for logging purposes.
"""
dic = copy.deepcopy(dic)
for k, v in dic.items():
if isinstance(v, dict):
# Convert dict to CfgNode
dic[k] = cls(v, key_list=key_list + [k])
else:
# Check for valid leaf type or nested CfgNode
_assert_with_logging(
_valid_type(v, allow_cfg_node=False),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list + [k]), type(v), _VALID_TYPES
),
)
return dic
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
if self.is_frozen():
raise AttributeError(
"Attempted to set {} to {}, but CfgNode is immutable".format(
name, value
)
)
_assert_with_logging(
name not in self.__dict__,
"Invalid attempt to modify internal CfgNode state: {}".format(name),
)
_assert_with_logging(
_valid_type(value, allow_cfg_node=True),
"Invalid type {} for key {}; valid types = {}".format(
type(value), name, _VALID_TYPES
),
)
self[name] = value
def __str__(self):
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
r = ""
s = []
for k, v in sorted(self.items()):
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
attr_str = _indent(attr_str, 2)
s.append(attr_str)
r += "\n".join(s)
return r
def __repr__(self):
return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
def dump(self, **kwargs):
"""Dump to a string."""
def convert_to_dict(cfg_node, key_list):
if not isinstance(cfg_node, CfgNode):
_assert_with_logging(
_valid_type(cfg_node),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list), type(cfg_node), _VALID_TYPES
),
)
return cfg_node
else:
cfg_dict = dict(cfg_node)
for k, v in cfg_dict.items():
cfg_dict[k] = convert_to_dict(v, key_list + [k])
return cfg_dict
self_as_dict = convert_to_dict(self, [])
return yaml.safe_dump(self_as_dict, **kwargs)
    def merge_from_file(self, cfg_filename):
        """Load a YAML config file and merge it into this CfgNode."""
with open(cfg_filename, "r") as f:
cfg = self.load_cfg(f)
self.merge_from_other_cfg(cfg)
def merge_from_other_cfg(self, cfg_other):
"""Merge `cfg_other` into this CfgNode."""
_merge_a_into_b(cfg_other, self, self, [])
def merge_from_list(self, cfg_list):
"""Merge config (keys, values) in a list (e.g., from command line) into
this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
"""
_assert_with_logging(
len(cfg_list) % 2 == 0,
"Override list has odd length: {}; it must be a list of pairs".format(
cfg_list
),
)
root = self
for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
if root.key_is_deprecated(full_key):
continue
if root.key_is_renamed(full_key):
root.raise_key_rename_error(full_key)
key_list = full_key.split(".")
d = self
for subkey in key_list[:-1]:
_assert_with_logging(
subkey in d, "Non-existent key: {}".format(full_key)
)
d = d[subkey]
subkey = key_list[-1]
_assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
value = self._decode_cfg_value(v)
value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
d[subkey] = value
def freeze(self):
"""Make this CfgNode and all of its children immutable."""
self._immutable(True)
def defrost(self):
"""Make this CfgNode and all of its children mutable."""
self._immutable(False)
def is_frozen(self):
"""Return mutability."""
return self.__dict__[CfgNode.IMMUTABLE]
def _immutable(self, is_immutable):
"""Set immutability to is_immutable and recursively apply the setting
to all nested CfgNodes.
"""
self.__dict__[CfgNode.IMMUTABLE] = is_immutable
# Recursively set immutable state
for v in self.__dict__.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
for v in self.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
def clone(self):
"""Recursively copy this CfgNode."""
return copy.deepcopy(self)
    def register_deprecated_key(self, key):
        """Register a key (e.g. `FOO.BAR`) as a deprecated option. When merging,
        deprecated keys generate a warning and are ignored.
"""
_assert_with_logging(
key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
"key {} is already registered as a deprecated key".format(key),
)
self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
def register_renamed_key(self, old_name, new_name, message=None):
"""Register a key as having been renamed from `old_name` to `new_name`.
        When merging a renamed key, an exception is raised alerting the user to
the fact that the key has been renamed.
"""
_assert_with_logging(
old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
"key {} is already registered as a renamed cfg key".format(old_name),
)
value = new_name
if message:
value = (new_name, message)
self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
def key_is_deprecated(self, full_key):
"""Test if a key is deprecated."""
if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
logger.warning("Deprecated config key (ignoring): {}".format(full_key))
return True
return False
def key_is_renamed(self, full_key):
"""Test if a key is renamed."""
return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
def raise_key_rename_error(self, full_key):
new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
if isinstance(new_key, tuple):
msg = " Note: " + new_key[1]
new_key = new_key[0]
else:
msg = ""
raise KeyError(
"Key {} was renamed to {}; please update your config.{}".format(
full_key, new_key, msg
)
)
def is_new_allowed(self):
return self.__dict__[CfgNode.NEW_ALLOWED]
@classmethod
def load_cfg(cls, cfg_file_obj_or_str):
"""
Load a cfg.
Args:
cfg_file_obj_or_str (str or file):
Supports loading from:
- A file object backed by a YAML file
- A file object backed by a Python source file that exports an attribute
"cfg" that is either a dict or a CfgNode
- A string that can be parsed as valid YAML
"""
_assert_with_logging(
isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
"Expected first argument to be of type {} or {}, but it was {}".format(
_FILE_TYPES, str, type(cfg_file_obj_or_str)
),
)
if isinstance(cfg_file_obj_or_str, str):
return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
return cls._load_cfg_from_file(cfg_file_obj_or_str)
else:
raise NotImplementedError("Impossible to reach here (unless there's a bug)")
@classmethod
def _load_cfg_from_file(cls, file_obj):
"""Load a config from a YAML file or a Python source file."""
_, file_extension = os.path.splitext(file_obj.name)
if file_extension in _YAML_EXTS:
return cls._load_cfg_from_yaml_str(file_obj.read())
elif file_extension in _PY_EXTS:
return cls._load_cfg_py_source(file_obj.name)
else:
raise Exception(
"Attempt to load from an unsupported file type {}; "
"only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
)
@classmethod
def _load_cfg_from_yaml_str(cls, str_obj):
"""Load a config from a YAML string encoding."""
cfg_as_dict = yaml.safe_load(str_obj)
return cls(cfg_as_dict)
@classmethod
def _load_cfg_py_source(cls, filename):
"""Load a config from a Python source file."""
module = _load_module_from_file("yacs.config.override", filename)
_assert_with_logging(
hasattr(module, "cfg"),
"Python module from file {} must have 'cfg' attr".format(filename),
)
VALID_ATTR_TYPES = {dict, CfgNode}
_assert_with_logging(
type(module.cfg) in VALID_ATTR_TYPES,
"Imported module 'cfg' attr must be in {} but is {} instead".format(
VALID_ATTR_TYPES, type(module.cfg)
),
)
return cls(module.cfg)
@classmethod
    def _decode_cfg_value(cls, value):
        """
        Decodes a raw config value (e.g., from a yaml config file or command
        line argument) into a Python object.
        If the value is a dict, it will be interpreted as a new CfgNode.
        If the value is a str, it will be evaluated as a Python literal.
Otherwise it is returned as-is.
"""
# Configs parsed from raw yaml will contain dictionary keys that need to be
# converted to CfgNode objects
if isinstance(value, dict):
return cls(value)
# All remaining processing is only applied to strings
if not isinstance(value, str):
return value
# Try to interpret `value` as a:
# string, number, tuple, list, dict, boolean, or None
try:
value = literal_eval(value)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return value
load_cfg = (
CfgNode.load_cfg
) # keep this function in global scope for backward compatibility
def _valid_type(value, allow_cfg_node=False):
return (type(value) in _VALID_TYPES) or (
allow_cfg_node and isinstance(value, CfgNode)
)
def _merge_a_into_b(a, b, root, key_list):
"""Merge config dictionary a into config dictionary b, clobbering the
options in b whenever they are also specified in a.
"""
_assert_with_logging(
isinstance(a, CfgNode),
"`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
)
_assert_with_logging(
isinstance(b, CfgNode),
"`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
)
for k, v_ in a.items():
full_key = ".".join(key_list + [k])
v = copy.deepcopy(v_)
v = b._decode_cfg_value(v)
if k in b:
v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
# Recursively merge dicts
if isinstance(v, CfgNode):
                _merge_a_into_b(v, b[k], root, key_list + [k])
else:
b[k] = v
elif b.is_new_allowed():
b[k] = v
else:
if root.key_is_deprecated(full_key):
continue
elif root.key_is_renamed(full_key):
root.raise_key_rename_error(full_key)
else:
raise KeyError("Non-existent config key: {}".format(full_key))
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    """Checks that `replacement`, which is intended to replace `original`, is of
the right type. The type is correct if it matches exactly or is one of a few
cases in which the type can be easily coerced.
"""
original_type = type(original)
replacement_type = type(replacement)
# The types must match (with some exceptions)
if replacement_type == original_type:
return replacement
# Cast replacement from from_type to to_type if the replacement and original
# types match from_type and to_type
def conditional_cast(from_type, to_type):
if replacement_type == from_type and original_type == to_type:
return True, to_type(replacement)
else:
return False, None
# Conditionally casts
# list <-> tuple
casts = [(tuple, list), (list, tuple)]
# For py2: allow converting from str (bytes) to a unicode string
try:
casts.append((str, unicode)) # noqa: F821
except Exception:
pass
for (from_type, to_type) in casts:
converted, converted_value = conditional_cast(from_type, to_type)
if converted:
return converted_value
raise ValueError(
"Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
"key: {}".format(
original_type, replacement_type, original, replacement, full_key
)
)
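# Illustrative sketch (not part of the original file): the coercion above lets
# a list override a tuple (and vice versa) but rejects genuine type mismatches.
def _demo_check_and_coerce():
    coerced = _check_and_coerce_cfg_value_type([1, 2], (3, 4), "SCALES", "TRAIN.SCALES")
    assert coerced == (1, 2)  # list replacement coerced to the original tuple type
    try:
        _check_and_coerce_cfg_value_type("a", 1, "KEY", "KEY")
    except ValueError:
        pass  # str cannot replace int, so a ValueError is raised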
def _assert_with_logging(cond, msg):
if not cond:
logger.debug(msg)
assert cond, msg
def _load_module_from_file(name, filename):
if _PY2:
module = imp.load_source(name, filename)
else:
spec = importlib.util.spec_from_file_location(name, filename)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
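# Illustrative sketch (not part of the original file): minimal end-to-end usage
# of CfgNode -- attribute access, command-line style overrides, and freezing.
def _demo_cfg_node():
    cfg = CfgNode()
    cfg.NUM_GPUS = 1
    cfg.TRAIN = CfgNode()
    cfg.TRAIN.LR = 0.1
    cfg.merge_from_list(["TRAIN.LR", "0.01", "NUM_GPUS", 2])
    assert cfg.TRAIN.LR == 0.01 and cfg.NUM_GPUS == 2
    cfg.freeze()
    try:
        cfg.NUM_GPUS = 4
    except AttributeError:
        pass  # frozen configs are immutable until defrost() is called
    return cfg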
================================================
FILE: yacs/tests.py
================================================
import logging
import tempfile
import unittest
import yacs.config
from yacs.config import CfgNode as CN
try:
_ignore = unicode # noqa: F821
PY2 = True
except Exception as _ignore:
PY2 = False
class SubCN(CN):
pass
def get_cfg(cls=CN):
cfg = cls()
cfg.NUM_GPUS = 8
cfg.TRAIN = cls()
cfg.TRAIN.HYPERPARAMETER_1 = 0.1
cfg.TRAIN.SCALES = (2, 4, 8, 16)
cfg.MODEL = cls()
cfg.MODEL.TYPE = "a_foo_model"
# Some extra stuff to test CfgNode.__str__
cfg.STR = cls()
cfg.STR.KEY1 = 1
cfg.STR.KEY2 = 2
cfg.STR.FOO = cls()
cfg.STR.FOO.KEY1 = 1
cfg.STR.FOO.KEY2 = 2
cfg.STR.FOO.BAR = cls()
cfg.STR.FOO.BAR.KEY1 = 1
cfg.STR.FOO.BAR.KEY2 = 2
cfg.register_deprecated_key("FINAL_MSG")
cfg.register_deprecated_key("MODEL.DILATION")
cfg.register_renamed_key(
"EXAMPLE.OLD.KEY",
"EXAMPLE.NEW.KEY",
        message="Please update your config file.",
)
cfg.KWARGS = cls(new_allowed=True)
cfg.KWARGS.z = 0
cfg.KWARGS.Y = cls()
cfg.KWARGS.Y.X = 1
return cfg
class TestCfgNode(unittest.TestCase):
def test_immutability(self):
# Top level immutable
a = CN()
a.foo = 0
a.freeze()
with self.assertRaises(AttributeError):
a.foo = 1
a.bar = 1
assert a.is_frozen()
assert a.foo == 0
a.defrost()
assert not a.is_frozen()
a.foo = 1
assert a.foo == 1
# Recursively immutable
a.level1 = CN()
a.level1.foo = 0
a.level1.level2 = CN()
a.level1.level2.foo = 0
a.freeze()
assert a.is_frozen()
with self.assertRaises(AttributeError):
a.level1.level2.foo = 1
a.level1.bar = 1
assert a.level1.level2.foo == 0
class TestCfg(unittest.TestCase):
def test_copy_cfg(self):
cfg = get_cfg()
cfg2 = cfg.clone()
s = cfg.MODEL.TYPE
cfg2.MODEL.TYPE = "dummy"
assert cfg.MODEL.TYPE == s
def test_merge_cfg_from_cfg(self):
# Test: merge from clone
cfg = get_cfg()
s = "dummy0"
cfg2 = cfg.clone()
cfg2.MODEL.TYPE = s
cfg.merge_from_other_cfg(cfg2)
assert cfg.MODEL.TYPE == s
# Test: merge from yaml
s = "dummy1"
cfg2 = CN.load_cfg(cfg.dump())
cfg2.MODEL.TYPE = s
cfg.merge_from_other_cfg(cfg2)
assert cfg.MODEL.TYPE == s
# Test: merge with a valid key
s = "dummy2"
cfg2 = CN()
cfg2.MODEL = CN()
cfg2.MODEL.TYPE = s
cfg.merge_from_other_cfg(cfg2)
assert cfg.MODEL.TYPE == s
# Test: merge with an invalid key
s = "dummy3"
cfg2 = CN()
cfg2.FOO = CN()
cfg2.FOO.BAR = s
with self.assertRaises(KeyError):
cfg.merge_from_other_cfg(cfg2)
# Test: merge with converted type
cfg2 = CN()
cfg2.TRAIN = CN()
cfg2.TRAIN.SCALES = [1]
cfg.merge_from_other_cfg(cfg2)
assert type(cfg.TRAIN.SCALES) is tuple
assert cfg.TRAIN.SCALES[0] == 1
# Test str (bytes) <-> unicode conversion for py2
if PY2:
cfg.A_UNICODE_KEY = u"foo"
cfg2 = CN()
cfg2.A_UNICODE_KEY = b"bar"
cfg.merge_from_other_cfg(cfg2)
assert type(cfg.A_UNICODE_KEY) == unicode # noqa: F821
assert cfg.A_UNICODE_KEY == u"bar"
# Test: merge with invalid type
cfg2 = CN()
cfg2.TRAIN = CN()
cfg2.TRAIN.SCALES = 1
with self.assertRaises(ValueError):
cfg.merge_from_other_cfg(cfg2)
def test_merge_cfg_from_file(self):
with tempfile.NamedTemporaryFile(mode="wt") as f:
cfg = get_cfg()
f.write(cfg.dump())
f.flush()
s = cfg.MODEL.TYPE
cfg.MODEL.TYPE = "dummy"
assert cfg.MODEL.TYPE != s
cfg.merge_from_file(f.name)
assert cfg.MODEL.TYPE == s
def test_merge_cfg_from_list(self):
cfg = get_cfg()
opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2]
assert len(cfg.TRAIN.SCALES) > 0
assert cfg.TRAIN.SCALES[0] != 100
assert cfg.MODEL.TYPE != "foobar"
assert cfg.NUM_GPUS != 2
cfg.merge_from_list(opts)
assert type(cfg.TRAIN.SCALES) is tuple
assert len(cfg.TRAIN.SCALES) == 1
assert cfg.TRAIN.SCALES[0] == 100
assert cfg.MODEL.TYPE == "foobar"
assert cfg.NUM_GPUS == 2
def test_deprecated_key_from_list(self):
# You should see logger messages like:
# "Deprecated config key (ignoring): MODEL.DILATION"
cfg = get_cfg()
opts = ["FINAL_MSG", "foobar", "MODEL.DILATION", 2]
with self.assertRaises(AttributeError):
_ = cfg.FINAL_MSG # noqa
with self.assertRaises(AttributeError):
_ = cfg.MODEL.DILATION # noqa
cfg.merge_from_list(opts)
with self.assertRaises(AttributeError):
_ = cfg.FINAL_MSG # noqa
with self.assertRaises(AttributeError):
_ = cfg.MODEL.DILATION # noqa
    def test_nonexistent_key_from_list(self):
cfg = get_cfg()
opts = ["MODEL.DOES_NOT_EXIST", "IGNORE"]
with self.assertRaises(AssertionError):
cfg.merge_from_list(opts)
def test_load_cfg_invalid_type(self):
# FOO.BAR.QUUX will have type None, which is not allowed
cfg_string = "FOO:\n BAR:\n QUUX:"
with self.assertRaises(AssertionError):
yacs.config.load_cfg(cfg_string)
def test_deprecated_key_from_file(self):
# You should see logger messages like:
# "Deprecated config key (ignoring): MODEL.DILATION"
cfg = get_cfg()
with tempfile.NamedTemporaryFile("wt") as f:
cfg2 = cfg.clone()
cfg2.MODEL.DILATION = 2
f.write(cfg2.dump())
f.flush()
with self.assertRaises(AttributeError):
_ = cfg.MODEL.DILATION # noqa
cfg.merge_from_file(f.name)
with self.assertRaises(AttributeError):
_ = cfg.MODEL.DILATION # noqa
def test_renamed_key_from_list(self):
cfg = get_cfg()
opts = ["EXAMPLE.OLD.KEY", "foobar"]
with self.assertRaises(AttributeError):
_ = cfg.EXAMPLE.OLD.KEY # noqa
with self.assertRaises(KeyError):
cfg.merge_from_list(opts)
def test_renamed_key_from_file(self):
cfg = get_cfg()
with tempfile.NamedTemporaryFile("wt") as f:
cfg2 = cfg.clone()
cfg2.EXAMPLE = CN()
cfg2.EXAMPLE.RENAMED = CN()
cfg2.EXAMPLE.RENAMED.KEY = "foobar"
f.write(cfg2.dump())
f.flush()
with self.assertRaises(AttributeError):
_ = cfg.EXAMPLE.RENAMED.KEY # noqa
with self.assertRaises(KeyError):
cfg.merge_from_file(f.name)
def test_load_cfg_from_file(self):
cfg = get_cfg()
with tempfile.NamedTemporaryFile("wt") as f:
f.write(cfg.dump())
f.flush()
with open(f.name, "rt") as f_read:
yacs.config.load_cfg(f_read)
def test_load_from_python_file(self):
# Case 1: exports CfgNode
cfg = get_cfg()
cfg.merge_from_file("example/config_override.py")
assert cfg.TRAIN.HYPERPARAMETER_1 == 0.9
# Case 2: exports dict
cfg = get_cfg()
cfg.merge_from_file("example/config_override_from_dict.py")
assert cfg.TRAIN.HYPERPARAMETER_1 == 0.9
def test_invalid_type(self):
cfg = get_cfg()
with self.assertRaises(AssertionError):
cfg.INVALID_KEY_TYPE = object()
def test__str__(self):
expected_str = """
KWARGS:
Y:
X: 1
z: 0
MODEL:
TYPE: a_foo_model
NUM_GPUS: 8
STR:
FOO:
BAR:
KEY1: 1
KEY2: 2
KEY1: 1
KEY2: 2
KEY1: 1
KEY2: 2
TRAIN:
HYPERPARAMETER_1: 0.1
SCALES: (2, 4, 8, 16)
""".strip()
cfg = get_cfg()
assert str(cfg) == expected_str
def test_new_allowed(self):
cfg = get_cfg()
cfg.merge_from_file("example/config_new_allowed.yaml")
assert cfg.KWARGS.a == 1
assert cfg.KWARGS.B.c == 2
assert cfg.KWARGS.B.D.e == "3"
def test_new_allowed_bad(self):
cfg = get_cfg()
with self.assertRaises(KeyError):
cfg.merge_from_file("example/config_new_allowed_bad.yaml")
class TestCfgNodeSubclass(unittest.TestCase):
def test_merge_cfg_from_file(self):
with tempfile.NamedTemporaryFile(mode="wt") as f:
cfg = get_cfg(SubCN)
f.write(cfg.dump())
f.flush()
s = cfg.MODEL.TYPE
cfg.MODEL.TYPE = "dummy"
assert cfg.MODEL.TYPE != s
cfg.merge_from_file(f.name)
assert cfg.MODEL.TYPE == s
def test_merge_cfg_from_list(self):
cfg = get_cfg(SubCN)
opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2]
assert len(cfg.TRAIN.SCALES) > 0
assert cfg.TRAIN.SCALES[0] != 100
assert cfg.MODEL.TYPE != "foobar"
assert cfg.NUM_GPUS != 2
cfg.merge_from_list(opts)
assert type(cfg.TRAIN.SCALES) is tuple
assert len(cfg.TRAIN.SCALES) == 1
assert cfg.TRAIN.SCALES[0] == 100
assert cfg.MODEL.TYPE == "foobar"
assert cfg.NUM_GPUS == 2
def test_merge_cfg_from_cfg(self):
cfg = get_cfg(SubCN)
cfg2 = get_cfg(SubCN)
s = "dummy0"
cfg2.MODEL.TYPE = s
cfg.merge_from_other_cfg(cfg2)
assert cfg.MODEL.TYPE == s
# Test: merge from yaml
s = "dummy1"
cfg2 = SubCN.load_cfg(cfg.dump())
cfg2.MODEL.TYPE = s
cfg.merge_from_other_cfg(cfg2)
assert cfg.MODEL.TYPE == s
if __name__ == "__main__":
logging.basicConfig()
yacs_logger = logging.getLogger("yacs.config")
yacs_logger.setLevel(logging.DEBUG)
unittest.main()