Repository: blackfeather-wang/GFNet-Pytorch Branch: master Commit: 8a6775c8267c Files: 103 Total size: 610.4 KB Directory structure: gitextract_6lcfcbge/ ├── .gitignore ├── LICENSE ├── README.md ├── configs.py ├── inference.py ├── models/ │ ├── __init__.py │ ├── activations/ │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── activations_autofn.py │ │ ├── activations_jit.py │ │ └── config.py │ ├── config.py │ ├── conv2d_layers.py │ ├── densenet.py │ ├── efficientnet_builder.py │ ├── gen_efficientnet.py │ ├── helpers.py │ ├── mobilenetv3.py │ ├── model_factory.py │ ├── resnet.py │ └── version.py ├── network.py ├── pycls/ │ ├── __init__.py │ ├── cfgs/ │ │ ├── RegNetY-1.6GF_dds_8gpu.yaml │ │ ├── RegNetY-600MF_dds_8gpu.yaml │ │ └── RegNetY-800MF_dds_8gpu.yaml │ ├── core/ │ │ ├── __init__.py │ │ ├── config.py │ │ ├── losses.py │ │ ├── model_builder.py │ │ ├── old_config.py │ │ └── optimizer.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── cifar10.py │ │ ├── imagenet.py │ │ ├── loader.py │ │ ├── paths.py │ │ └── transforms.py │ ├── models/ │ │ ├── __init__.py │ │ ├── anynet.py │ │ ├── effnet.py │ │ ├── regnet.py │ │ └── resnet.py │ └── utils/ │ ├── __init__.py │ ├── benchmark.py │ ├── checkpoint.py │ ├── distributed.py │ ├── error_handler.py │ ├── io.py │ ├── logging.py │ ├── lr_policy.py │ ├── meters.py │ ├── metrics.py │ ├── multiprocessing.py │ ├── net.py │ ├── plotting.py │ └── timer.py ├── simplejson/ │ ├── __init__.py │ ├── _speedups.c │ ├── compat.py │ ├── decoder.py │ ├── encoder.py │ ├── errors.py │ ├── ordered_dict.py │ ├── raw_json.py │ ├── scanner.py │ ├── tests/ │ │ ├── __init__.py │ │ ├── test_bigint_as_string.py │ │ ├── test_bitsize_int_as_string.py │ │ ├── test_check_circular.py │ │ ├── test_decimal.py │ │ ├── test_decode.py │ │ ├── test_default.py │ │ ├── test_dump.py │ │ ├── test_encode_basestring_ascii.py │ │ ├── test_encode_for_html.py │ │ ├── test_errors.py │ │ ├── test_fail.py │ │ ├── test_float.py │ │ ├── test_for_json.py │ │ ├── test_indent.py │ │ ├── test_item_sort_key.py │ │ ├── test_iterable.py │ │ ├── test_namedtuple.py │ │ ├── test_pass1.py │ │ ├── test_pass2.py │ │ ├── test_pass3.py │ │ ├── test_raw_json.py │ │ ├── test_recursion.py │ │ ├── test_scanstring.py │ │ ├── test_separators.py │ │ ├── test_speedups.py │ │ ├── test_str_subclass.py │ │ ├── test_subclass.py │ │ ├── test_tool.py │ │ ├── test_tuple.py │ │ └── test_unicode.py │ └── tool.py ├── train.py ├── utils.py └── yacs/ ├── __init__.py ├── config.py └── tests.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ models/.DS_Store figures/.DS_Store ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Glance-and-Focus Networks (PyTorch) This repo contains the official code and pre-trained models for the glance and focus networks (GFNet). - (NeurIPS 2020) [Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification](https://arxiv.org/abs/2010.05300) - (T-PAMI) [Glance and Focus Networks for Dynamic Visual Recognition](https://arxiv.org/abs/2201.03014) **Update on 2020/12/28: Release Training Code.** **Update on 2020/10/08: Release Pre-trained Models and the Inference Code on ImageNet.** ## Introduction

Inspired by the fact that not all regions in an image are task-relevant, we propose a novel framework that performs efficient image classification by processing a sequence of relatively small inputs, which are strategically cropped from the original image. Experiments on ImageNet show that our method consistently improves the computational efficiency of a wide variety of deep models. For example, it further reduces the average latency of the highly efficient MobileNet-V3 on an iPhone XS Max by 20% without sacrificing accuracy.

## Citation ``` @inproceedings{NeurIPS2020_7866, title={Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification}, author={Wang, Yulin and Lv, Kangchen and Huang, Rui and Song, Shiji and Yang, Le and Huang, Gao}, booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, year={2020}, } @article{huang2023glance, title={Glance and Focus Networks for Dynamic Visual Recognition}, author={Huang, Gao and Wang, Yulin and Lv, Kangchen and Jiang, Haojun and Huang, Wenhui and Qi, Pengfei and Song, Shiji}, journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, year={2023}, volume={45}, number={4}, pages={4605-4621}, doi={10.1109/TPAMI.2022.3196959} } ``` ## Results - Top-1 accuracy on ImageNet v.s. Multiply-Adds

- Top-1 accuracy on ImageNet v.s. Inference Latency (ms) on an iPhone XS Max

- Visualization

## Pre-trained Models |Backbone CNNs|Patch Size|T|Links| |-----|------|-----|-----| |ResNet-50| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4c55dd9472b4416cbdc9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Iun8o4o7cQL-7vSwKyNfefOgwb9-o9kD/view?usp=sharing)| |ResNet-50| 128x128| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1cbed71346e54a129771/?dl=1) / [Google Drive](https://drive.google.com/file/d/1cEj0dXO7BfzQNd5fcYZOQekoAe3_DPia/view?usp=sharing)| |DenseNet-121| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c75c77d2f2054872ac20/?dl=1) / [Google Drive](https://drive.google.com/file/d/1UflIM29Npas0rTQSxPqwAT6zHbFkQq6R/view?usp=sharing)| |DenseNet-169| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/83fef24e667b4dccace1/?dl=1) / [Google Drive](https://drive.google.com/file/d/1pBo22i6VsWJWtw2xJw1_bTSMO3HgJNDL/view?usp=sharing)| |DenseNet-201| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/85a57b82e592470892e0/?dl=1) / [Google Drive](https://drive.google.com/file/d/1sETDr7dP5Q525fRMTIl2jlt9Eg2qh2dx/view?usp=sharing)| |RegNet-Y-600MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/2638f038d3b1465da59e/?dl=1) / [Google Drive](https://drive.google.com/file/d/1FCR14wUiNrIXb81cU1pDPcy4bPSRXrqe/view?usp=sharing)| |RegNet-Y-800MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/686e411e72894b789dde/?dl=1) / [Google Drive](https://drive.google.com/file/d/1On39MwbJY5Zagz7gNtKBFwMWhhHskfZq/view?usp=sharing)| |RegNet-Y-1.6GF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/90116ad21ee74843b0ef/?dl=1) / [Google Drive](https://drive.google.com/file/d/1rMe0LU8m4BF3udII71JT0VLPBE2G-eCJ/view?usp=sharing)| |MobileNet-V3-Large (1.00)| 96x96| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4a4e8486b83b4dbeb06c/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Dw16jPlw2hR8EaWbd_1Ujd6Df9n1gHgj/view?usp=sharing)| |MobileNet-V3-Large (1.00)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/ab0f6fc3997d4771a4c9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Ud_olyer-YgAb667YUKs38C2O1Yb6Nom/view?usp=sharing)| |MobileNet-V3-Large (1.25)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b2052c3af7734f688bc7/?dl=1) / [Google Drive](https://drive.google.com/file/d/14zj1Ci0i4nYceu-f2ZckFMjRmYDGtJpl/view?usp=sharing)| |EfficientNet-B2| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1a490deecd34470580da/?dl=1) / [Google Drive](https://drive.google.com/file/d/1LBBPrrYZzKKqCnoZH1kPfQQZ5ixkEmjz/view?usp=sharing)| |EfficientNet-B3| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/d5182a2257bb481ea622/?dl=1) / [Google Drive](https://drive.google.com/file/d/1fdxwimcuQAXBOsbOdGw8Ee43PgeHHZTA/view?usp=sharing)| |EfficientNet-B3| 144x144| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/f96abfb6de13430aa663/?dl=1) / [Google Drive](https://drive.google.com/file/d/1OVTGI6d2nsN5Hz5T_qLnYeUIBL5oMeeU/view?usp=sharing)| - What are contained in the checkpoints: ``` **.pth.tar ├── model_name: name of the backbone CNNs (e.g., resnet50, densenet121) ├── patch_size: size of image patches (i.e., H' or W' in the paper) ├── model_prime_state_dict, model_state_dict, fc, policy: state dictionaries of the four components of GFNets ├── model_flops, policy_flops, fc_flops: Multiply-Adds of inferring the encoder, patch proposal network and classifier for once ├── flops: a list containing the Multiply-Adds corresponding to each length of the input sequence during inference ├── anytime_classification: results of anytime prediction (in Top-1 accuracy) ├── dynamic_threshold: the confidence thresholds used in budgeted batch classification ├── budgeted_batch_classification: results of budgeted batch classification (a two-item list, [0] and [1] correspond to the two coordinates of a curve) ``` ## Requirements - python 3.7.7 - pytorch 1.3.1 - torchvision 0.4.2 - pyyaml 5.3.1 (for RegNets) ## Evaluate Pre-trained Models Read the evaluation results saved in pre-trained models ``` CUDA_VISIBLE_DEVICES=0 python inference.py --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 0 ``` Read the confidence thresholds saved in pre-trained models and infer the model on the validation set ``` CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 1 ``` Determine confidence thresholds on the training set and infer the model on the validation set ``` CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 2 ``` The dataset is expected to be prepared as follows: ``` ImageNet ├── train │ ├── folder 1 (class 1) │ ├── folder 2 (class 1) │ ├── ... ├── val │ ├── folder 1 (class 1) │ ├── folder 2 (class 1) │ ├── ... ``` ## Training - Here we take training ResNet-50 (96x96, T=5) for example. All the used initialization models and stage-1/2 checkpoints can be found in [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/ac7c47b3f9b04e098862/) / [Google Drive](https://drive.google.com/drive/folders/1yO2GviOnukSUgcTkptNLBBttSJQZk9yn?usp=sharing). Currently, this link includes ResNet and MobileNet-V3. We will update it as soon as possible. If you need other helps, feel free to contact us. - The Results in the paper is based on 2 Tesla V100 GPUs. For most of experiments, up to 4 Titan Xp GPUs may be enough. Training stage 1, the initializations of global encoder (model_prime) and local encoder (model) are required: ``` CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 1 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --model_prime_path PATH_TO_CHECKPOINTS --model_path PATH_TO_CHECKPOINTS ``` Training stage 2, a stage-1 checkpoint is required: ``` CUDA_VISIBLE_DEVICES=0 python train.py --data_url PATH_TO_DATASET --train_stage 2 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS ``` Training stage 3, a stage-2 checkpoint is required: ``` CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 3 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS ``` ## Contact If you have any question, please feel free to contact the authors. Yulin Wang: wang-yl19@mails.tsinghua.edu.cn. ## Acknowledgment Our code of MobileNet-V3 and EfficientNet is from [here](https://github.com/rwightman/pytorch-image-models). Our code of RegNet is from [here](https://github.com/facebookresearch/pycls). ## To Do - Update the code for visualizing. - Update the code for MIXED PRECISION TRAINING。 ================================================ FILE: configs.py ================================================ from PIL import Image model_configurations = { 'resnet50': { 'feature_num': 2048, 'feature_map_channels': 2048, 'policy_conv': False, 'policy_hidden_dim': 1024, 'fc_rnn': True, 'fc_hidden_dim': 1024, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bicubic' }, 'densenet121': { 'feature_num': 1024, 'feature_map_channels': 1024, 'policy_conv': False, 'policy_hidden_dim': 1024, 'fc_rnn': True, 'fc_hidden_dim': 1024, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bilinear' }, 'densenet169': { 'feature_num': 1664, 'feature_map_channels': 1664, 'policy_conv': False, 'policy_hidden_dim': 1024, 'fc_rnn': True, 'fc_hidden_dim': 1024, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bilinear' }, 'densenet201': { 'feature_num': 1920, 'feature_map_channels': 1920, 'policy_conv': False, 'policy_hidden_dim': 1024, 'fc_rnn': True, 'fc_hidden_dim': 1024, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bilinear' }, 'mobilenetv3_large_100': { 'feature_num': 1280, 'feature_map_channels': 960, 'policy_conv': True, 'policy_hidden_dim': 256, 'fc_rnn': False, 'fc_hidden_dim': None, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bicubic' }, 'mobilenetv3_large_125': { 'feature_num': 1280, 'feature_map_channels': 1200, 'policy_conv': True, 'policy_hidden_dim': 256, 'fc_rnn': False, 'fc_hidden_dim': None, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bicubic' }, 'efficientnet_b2': { 'feature_num': 1408, 'feature_map_channels': 1408, 'policy_conv': True, 'policy_hidden_dim': 256, 'fc_rnn': False, 'fc_hidden_dim': None, 'image_size': 260, 'crop_pct': 0.875, 'dataset_interpolation': Image.BICUBIC, 'prime_interpolation': 'bicubic' }, 'efficientnet_b3': { 'feature_num': 1536, 'feature_map_channels': 1536, 'policy_conv': True, 'policy_hidden_dim': 256, 'fc_rnn': False, 'fc_hidden_dim': None, 'image_size': 300, 'crop_pct': 0.904, 'dataset_interpolation': Image.BICUBIC, 'prime_interpolation': 'bicubic' }, 'regnety_600m': { 'feature_num': 608, 'feature_map_channels': 608, 'policy_conv': True, 'policy_hidden_dim': 256, 'fc_rnn': True, 'fc_hidden_dim': 1024, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bilinear', 'cfg_file': 'pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml' }, 'regnety_800m': { 'feature_num': 768, 'feature_map_channels': 768, 'policy_conv': True, 'policy_hidden_dim': 256, 'fc_rnn': True, 'fc_hidden_dim': 1024, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bilinear', 'cfg_file': 'pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml' }, 'regnety_1.6g': { 'feature_num': 888, 'feature_map_channels': 888, 'policy_conv': True, 'policy_hidden_dim': 256, 'fc_rnn': True, 'fc_hidden_dim': 1024, 'image_size': 224, 'crop_pct': 0.875, 'dataset_interpolation': Image.BILINEAR, 'prime_interpolation': 'bilinear', 'cfg_file': 'pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml' } } train_configurations = { 'resnet': { 'backbone_lr': 0.01, 'fc_stage_1_lr': 0.1, 'fc_stage_3_lr': 0.01, 'weight_decay': 1e-4, 'momentum': 0.9, 'Nesterov': True, 'batch_size': 256, 'dsn_ratio': 1, 'epoch_num': 60, 'train_model_prime': True }, 'densenet': { 'backbone_lr': 0.01, 'fc_stage_1_lr': 0.1, 'fc_stage_3_lr': 0.01, 'weight_decay': 1e-4, 'momentum': 0.9, 'Nesterov': True, 'batch_size': 256, 'dsn_ratio': 1, 'epoch_num': 60, 'train_model_prime': True }, 'efficientnet': { 'backbone_lr': 0.005, 'fc_stage_1_lr': 0.1, 'fc_stage_3_lr': 0.01, 'weight_decay': 1e-4, 'momentum': 0.9, 'Nesterov': True, 'batch_size': 256, 'dsn_ratio': 5, 'epoch_num': 30, 'train_model_prime': False }, 'mobilenetv3': { 'backbone_lr': 0.005, 'fc_stage_1_lr': 0.1, 'fc_stage_3_lr': 0.01, 'weight_decay': 1e-4, 'momentum': 0.9, 'Nesterov': True, 'batch_size': 256, 'dsn_ratio': 5, 'epoch_num': 90, 'train_model_prime': False }, 'regnet': { 'backbone_lr': 0.02, 'fc_stage_1_lr': 0.1, 'fc_stage_3_lr': 0.01, 'weight_decay': 5e-5, 'momentum': 0.9, 'Nesterov': True, 'batch_size': 256, 'dsn_ratio': 1, 'epoch_num': 60, 'train_model_prime': True } } ================================================ FILE: inference.py ================================================ import torchvision.transforms as transforms import torchvision.datasets as datasets import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') from utils import * from network import * from configs import * import math import argparse import models.resnet as resnet import models.densenet as densenet from models import create_model device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") parser = argparse.ArgumentParser(description='Inference code for GFNet') parser.add_argument('--data_url', default='./data', type=str, help='path to the dataset (ImageNet)') parser.add_argument('--checkpoint_path', default='', type=str, help='path to the pre-train model (default: none)') parser.add_argument('--eval_mode', default=2, type=int, help='mode 0 : read the evaluation results saved in pre-trained models\ mode 1 : read the confidence thresholds saved in pre-trained models and infer the model on the validation set\ mode 2 : determine confidence thresholds on the training set and infer the model on the validation set') args = parser.parse_args() def main(): # load pretrained model checkpoint = torch.load(args.checkpoint_path) try: model_arch = checkpoint['model_name'] patch_size = checkpoint['patch_size'] prime_size = checkpoint['patch_size'] flops = checkpoint['flops'] model_flops = checkpoint['model_flops'] policy_flops = checkpoint['policy_flops'] fc_flops = checkpoint['fc_flops'] anytime_classification = checkpoint['anytime_classification'] budgeted_batch_classification = checkpoint['budgeted_batch_classification'] dynamic_threshold = checkpoint['dynamic_threshold'] maximum_length = len(checkpoint['flops']) except: print('Error: \n' 'Please provide essential information' 'for customized models (as we have done ' 'in pre-trained models)!\n' 'At least the following information should be Given: \n' '--model_name: name of the backbone CNNs (e.g., resnet50, densenet121)\n' '--patch_size: size of image patches (i.e., H\' or W\' in the paper)\n' '--flops: a list containing the Multiply-Adds corresponding to each ' 'length of the input sequence during inference') model_configuration = model_configurations[model_arch] if args.eval_mode > 0: # create model if 'resnet' in model_arch: model = resnet.resnet50(pretrained=False) model_prime = resnet.resnet50(pretrained=False) elif 'densenet' in model_arch: model = eval('densenet.' + model_arch)(pretrained=False) model_prime = eval('densenet.' + model_arch)(pretrained=False) elif 'efficientnet' in model_arch: model = create_model(model_arch, pretrained=False, num_classes=1000, drop_rate=0.3, drop_connect_rate=0.2) model_prime = create_model(model_arch, pretrained=False, num_classes=1000, drop_rate=0.3, drop_connect_rate=0.2) elif 'mobilenetv3' in model_arch: model = create_model(model_arch, pretrained=False, num_classes=1000, drop_rate=0.2, drop_connect_rate=0.2) model_prime = create_model(model_arch, pretrained=False, num_classes=1000, drop_rate=0.2, drop_connect_rate=0.2) elif 'regnet' in model_arch: import pycls.core.model_builder as model_builder from pycls.core.config import cfg cfg.merge_from_file(model_configuration['cfg_file']) cfg.freeze() model = model_builder.build_model() model_prime = model_builder.build_model() traindir = args.data_url + 'train/' valdir = args.data_url + 'val/' normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_set = datasets.ImageFolder(traindir, transforms.Compose([ transforms.RandomResizedCrop(model_configuration['image_size'], interpolation=model_configuration['dataset_interpolation']), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize ])) train_set_index = torch.randperm(len(train_set)) train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, num_workers=32, pin_memory=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_set_index[-200000:])) val_loader = torch.utils.data.DataLoader( datasets.ImageFolder(valdir, transforms.Compose([ transforms.Resize(int(model_configuration['image_size']/model_configuration['crop_pct']), interpolation = model_configuration['dataset_interpolation']), transforms.CenterCrop(model_configuration['image_size']), transforms.ToTensor(), normalize])), batch_size=256, shuffle=False, num_workers=16, pin_memory=False) state_dim = model_configuration['feature_map_channels'] * math.ceil(patch_size/32) * math.ceil(patch_size/32) memory = Memory() policy = ActorCritic(model_configuration['feature_map_channels'], state_dim, model_configuration['policy_hidden_dim'], model_configuration['policy_conv']) fc = Full_layer(model_configuration['feature_num'], model_configuration['fc_hidden_dim'], model_configuration['fc_rnn']) model = nn.DataParallel(model.cuda()) model_prime = nn.DataParallel(model_prime.cuda()) policy = policy.cuda() fc = fc.cuda() model.load_state_dict(checkpoint['model_state_dict']) model_prime.load_state_dict(checkpoint['model_prime_state_dict']) fc.load_state_dict(checkpoint['fc']) policy.load_state_dict(checkpoint['policy']) budgeted_batch_flops_list = [] budgeted_batch_acc_list = [] print('generate logits on test samples...') test_logits, test_targets, anytime_classification = generate_logits(model_prime, model, fc, memory, policy, val_loader, maximum_length, prime_size, patch_size, model_arch) if args.eval_mode == 2: print('generate logits on training samples...') dynamic_threshold = torch.zeros([39, maximum_length]) train_logits, train_targets, _ = generate_logits(model_prime, model, fc, memory, policy, train_loader, maximum_length, prime_size, patch_size, model_arch) for p in range(1, 40): print('inference: {}/40'.format(p)) _p = torch.FloatTensor(1).fill_(p * 1.0 / 20) probs = torch.exp(torch.log(_p) * torch.range(1, maximum_length)) probs /= probs.sum() if args.eval_mode == 2: dynamic_threshold[p-1] = dynamic_find_threshold(train_logits, train_targets, probs) acc_step, flops_step = dynamic_evaluate(test_logits, test_targets, flops, dynamic_threshold[p-1]) budgeted_batch_acc_list.append(acc_step) budgeted_batch_flops_list.append(flops_step) budgeted_batch_classification = [budgeted_batch_flops_list, budgeted_batch_acc_list] print('model_arch :', model_arch) print('patch_size :', patch_size) print('flops :', flops) print('model_flops :', model_flops) print('policy_flops :', policy_flops) print('fc_flops :', fc_flops) print('anytime_classification :', anytime_classification) print('budgeted_batch_classification :', budgeted_batch_classification) def generate_logits(model_prime, model, fc, memory, policy, dataloader, maximum_length, prime_size, patch_size, model_arch): logits_list = [] targets_list = [] top1 = [AverageMeter() for _ in range(maximum_length)] model.eval() model_prime.eval() fc.eval() for i, (x, target) in enumerate(dataloader): logits_temp = torch.zeros(maximum_length, x.size(0), 1000) target_var = target.cuda() input_var = x.cuda() input_prime = get_prime(input_var, prime_size, model_configurations[model_arch]['prime_interpolation']) with torch.no_grad(): output, state = model_prime(input_prime) if 'resnet' in model_arch or 'densenet' in model_arch: output = fc(output, restart=True) elif 'regnet' in model_arch: _ = fc(output, restart=True) output = model_prime.module.fc(output) else: _ = fc(output, restart=True) output = model_prime.module.classifier(output) logits_temp[0] = F.softmax(output, 1) acc = accuracy(output, target_var, topk=(1,)) top1[0].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0)) for patch_step in range(1, maximum_length): with torch.no_grad(): if patch_step == 1: action = policy.act(state, memory, restart_batch=True) else: action = policy.act(state, memory) patches = get_patch(input_var, action, patch_size) output, state = model(patches) output = fc(output, restart=False) logits_temp[patch_step] = F.softmax(output, 1) acc = accuracy(output, target_var, topk=(1,)) top1[patch_step].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0)) logits_list.append(logits_temp) targets_list.append(target_var) memory.clear_memory() anytime_classification = [] for index in range(maximum_length): anytime_classification.append(top1[index].ave) return torch.cat(logits_list, 1), torch.cat(targets_list, 0), anytime_classification def dynamic_find_threshold(logits, targets, p): n_stage, n_sample, c = logits.size() max_preds, argmax_preds = logits.max(dim=2, keepdim=False) _, sorted_idx = max_preds.sort(dim=1, descending=True) filtered = torch.zeros(n_sample) T = torch.Tensor(n_stage).fill_(1e8) for k in range(n_stage - 1): acc, count = 0.0, 0 out_n = math.floor(n_sample * p[k]) for i in range(n_sample): ori_idx = sorted_idx[k][i] if filtered[ori_idx] == 0: count += 1 if count == out_n: T[k] = max_preds[k][ori_idx] break filtered.add_(max_preds[k].ge(T[k]).type_as(filtered)) T[n_stage - 1] = -1e8 return T def dynamic_evaluate(logits, targets, flops, T): n_stage, n_sample, c = logits.size() max_preds, argmax_preds = logits.max(dim=2, keepdim=False) _, sorted_idx = max_preds.sort(dim=1, descending=True) acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage) acc, expected_flops = 0, 0 for i in range(n_sample): gold_label = targets[i] for k in range(n_stage): if max_preds[k][i].item() >= T[k]: # force the sample to exit at k if int(gold_label.item()) == int(argmax_preds[k][i].item()): acc += 1 acc_rec[k] += 1 exp[k] += 1 break acc_all = 0 for k in range(n_stage): _t = 1.0 * exp[k] / n_sample expected_flops += _t * flops[k] acc_all += acc_rec[k] return acc * 100.0 / n_sample, expected_flops.item() if __name__ == '__main__': main() ================================================ FILE: models/__init__.py ================================================ from .gen_efficientnet import * from .mobilenetv3 import * from .model_factory import create_model from .config import is_exportable, is_scriptable, set_exportable, set_scriptable from .activations import * from .resnet import * from .densenet import * ================================================ FILE: models/activations/__init__.py ================================================ from .config import * from .activations_autofn import * from .activations_jit import * from .activations import * _ACT_FN_DEFAULT = dict( swish=swish, mish=mish, relu=F.relu, relu6=F.relu6, sigmoid=sigmoid, tanh=tanh, hard_sigmoid=hard_sigmoid, hard_swish=hard_swish, ) _ACT_FN_AUTO = dict( swish=swish_auto, mish=mish_auto, ) _ACT_FN_JIT = dict( swish=swish_jit, mish=mish_jit, #hard_swish=hard_swish_jit, #hard_sigmoid_jit=hard_sigmoid_jit, ) _ACT_LAYER_DEFAULT = dict( swish=Swish, mish=Mish, relu=nn.ReLU, relu6=nn.ReLU6, sigmoid=Sigmoid, tanh=Tanh, hard_sigmoid=HardSigmoid, hard_swish=HardSwish, ) _ACT_LAYER_AUTO = dict( swish=SwishAuto, mish=MishAuto, ) _ACT_LAYER_JIT = dict( swish=SwishJit, mish=MishJit, #hard_swish=HardSwishJit, #hard_sigmoid=HardSigmoidJit ) _OVERRIDE_FN = dict() _OVERRIDE_LAYER = dict() def add_override_act_fn(name, fn): global _OVERRIDE_FN _OVERRIDE_FN[name] = fn def update_override_act_fn(overrides): assert isinstance(overrides, dict) global _OVERRIDE_FN _OVERRIDE_FN.update(overrides) def clear_override_act_fn(): global _OVERRIDE_FN _OVERRIDE_FN = dict() def add_override_act_layer(name, fn): _OVERRIDE_LAYER[name] = fn def update_override_act_layer(overrides): assert isinstance(overrides, dict) global _OVERRIDE_LAYER _OVERRIDE_LAYER.update(overrides) def clear_override_act_layer(): global _OVERRIDE_LAYER _OVERRIDE_LAYER = dict() def get_act_fn(name='relu'): """ Activation Function Factory Fetching activation fns by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. """ if name in _OVERRIDE_FN: return _OVERRIDE_FN[name] if not config.is_exportable() and not config.is_scriptable(): # If not exporting or scripting the model, first look for a JIT optimized version # of our activation, then a custom autograd.Function variant before defaulting to # a Python or Torch builtin impl if name in _ACT_FN_JIT: return _ACT_FN_JIT[name] if name in _ACT_FN_AUTO: return _ACT_FN_AUTO[name] return _ACT_FN_DEFAULT[name] def get_act_layer(name='relu'): """ Activation Layer Factory Fetching activation layers by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. """ if name in _OVERRIDE_LAYER: return _OVERRIDE_LAYER[name] if not config.is_exportable() and not config.is_scriptable(): if name in _ACT_LAYER_JIT: return _ACT_LAYER_JIT[name] if name in _ACT_LAYER_AUTO: return _ACT_LAYER_AUTO[name] return _ACT_LAYER_DEFAULT[name] ================================================ FILE: models/activations/activations.py ================================================ from torch import nn as nn from torch.nn import functional as F def swish(x, inplace: bool = False): """Swish - Described in: https://arxiv.org/abs/1710.05941 """ return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) class Swish(nn.Module): def __init__(self, inplace: bool = False): super(Swish, self).__init__() self.inplace = inplace def forward(self, x): return swish(x, self.inplace) def mish(x, inplace: bool = False): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 """ return x.mul(F.softplus(x).tanh()) class Mish(nn.Module): def __init__(self, inplace: bool = False): super(Mish, self).__init__() self.inplace = inplace def forward(self, x): return mish(x, self.inplace) def sigmoid(x, inplace: bool = False): return x.sigmoid_() if inplace else x.sigmoid() # PyTorch has this, but not with a consistent inplace argmument interface class Sigmoid(nn.Module): def __init__(self, inplace: bool = False): super(Sigmoid, self).__init__() self.inplace = inplace def forward(self, x): return x.sigmoid_() if self.inplace else x.sigmoid() def tanh(x, inplace: bool = False): return x.tanh_() if inplace else x.tanh() # PyTorch has this, but not with a consistent inplace argmument interface class Tanh(nn.Module): def __init__(self, inplace: bool = False): super(Tanh, self).__init__() self.inplace = inplace def forward(self, x): return x.tanh_() if self.inplace else x.tanh() def hard_swish(x, inplace: bool = False): inner = F.relu6(x + 3.).div_(6.) return x.mul_(inner) if inplace else x.mul(inner) class HardSwish(nn.Module): def __init__(self, inplace: bool = False): super(HardSwish, self).__init__() self.inplace = inplace def forward(self, x): return hard_swish(x, self.inplace) def hard_sigmoid(x, inplace: bool = False): if inplace: return x.add_(3.).clamp_(0., 6.).div_(6.) else: return F.relu6(x + 3.) / 6. class HardSigmoid(nn.Module): def __init__(self, inplace: bool = False): super(HardSigmoid, self).__init__() self.inplace = inplace def forward(self, x): return hard_sigmoid(x, self.inplace) ================================================ FILE: models/activations/activations_autofn.py ================================================ import torch from torch import nn as nn from torch.nn import functional as F __all__ = ['swish_auto', 'SwishAuto', 'mish_auto', 'MishAuto'] class SwishAutoFn(torch.autograd.Function): """Swish - Described in: https://arxiv.org/abs/1710.05941 Memory efficient variant from: https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76 """ @staticmethod def forward(ctx, x): result = x.mul(torch.sigmoid(x)) ctx.save_for_backward(x) return result @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] x_sigmoid = torch.sigmoid(x) return grad_output.mul(x_sigmoid * (1 + x * (1 - x_sigmoid))) def swish_auto(x, inplace=False): # inplace ignored return SwishAutoFn.apply(x) class SwishAuto(nn.Module): def __init__(self, inplace: bool = False): super(SwishAuto, self).__init__() self.inplace = inplace def forward(self, x): return SwishAutoFn.apply(x) class MishAutoFn(torch.autograd.Function): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 Experimental memory-efficient variant """ @staticmethod def forward(ctx, x): ctx.save_for_backward(x) y = x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x))) return y @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] x_sigmoid = torch.sigmoid(x) x_tanh_sp = F.softplus(x).tanh() return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) def mish_auto(x, inplace=False): # inplace ignored return MishAutoFn.apply(x) class MishAuto(nn.Module): def __init__(self, inplace: bool = False): super(MishAuto, self).__init__() self.inplace = inplace def forward(self, x): return MishAutoFn.apply(x) ================================================ FILE: models/activations/activations_jit.py ================================================ import torch from torch import nn as nn from torch.nn import functional as F __all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit'] #'hard_swish_jit', 'HardSwishJit', 'hard_sigmoid_jit', 'HardSigmoidJit'] @torch.jit.script def swish_jit_fwd(x): return x.mul(torch.sigmoid(x)) @torch.jit.script def swish_jit_bwd(x, grad_output): x_sigmoid = torch.sigmoid(x) return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) class SwishJitAutoFn(torch.autograd.Function): """ torch.jit.script optimised Swish Inspired by conversation btw Jeremy Howard & Adam Pazske https://twitter.com/jeremyphoward/status/1188251041835315200 """ @staticmethod def forward(ctx, x): ctx.save_for_backward(x) return swish_jit_fwd(x) @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] return swish_jit_bwd(x, grad_output) def swish_jit(x, inplace=False): # inplace ignored return SwishJitAutoFn.apply(x) class SwishJit(nn.Module): def __init__(self, inplace: bool = False): super(SwishJit, self).__init__() self.inplace = inplace def forward(self, x): return SwishJitAutoFn.apply(x) @torch.jit.script def mish_jit_fwd(x): return x.mul(torch.tanh(F.softplus(x))) @torch.jit.script def mish_jit_bwd(x, grad_output): x_sigmoid = torch.sigmoid(x) x_tanh_sp = F.softplus(x).tanh() return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) class MishJitAutoFn(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) return mish_jit_fwd(x) @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] return mish_jit_bwd(x, grad_output) def mish_jit(x, inplace=False): # inplace ignored return MishJitAutoFn.apply(x) class MishJit(nn.Module): def __init__(self, inplace: bool = False): super(MishJit, self).__init__() self.inplace = inplace def forward(self, x): return MishJitAutoFn.apply(x) # @torch.jit.script # def hard_swish_jit(x, inplac: bool = False): # return x.mul(F.relu6(x + 3.).mul_(1./6.)) # # # class HardSwishJit(nn.Module): # def __init__(self, inplace: bool = False): # super(HardSwishJit, self).__init__() # # def forward(self, x): # return hard_swish_jit(x) # # # @torch.jit.script # def hard_sigmoid_jit(x, inplace: bool = False): # return F.relu6(x + 3.).mul(1./6.) # # # class HardSigmoidJit(nn.Module): # def __init__(self, inplace: bool = False): # super(HardSigmoidJit, self).__init__() # # def forward(self, x): # return hard_sigmoid_jit(x) ================================================ FILE: models/activations/config.py ================================================ """ Global Config and Constants """ __all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable'] # Set to True if exporting a model with Same padding via ONNX _EXPORTABLE = False # Set to True if wanting to use torch.jit.script on a model _SCRIPTABLE = False def is_exportable(): return _EXPORTABLE def set_exportable(value): global _EXPORTABLE _EXPORTABLE = value def is_scriptable(): return _SCRIPTABLE def set_scriptable(value): global _SCRIPTABLE _SCRIPTABLE = value ================================================ FILE: models/config.py ================================================ """ Global Config and Constants """ __all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable'] # Set to True if exporting a model with Same padding via ONNX _EXPORTABLE = False # Set to True if wanting to use torch.jit.script on a model _SCRIPTABLE = False def is_exportable(): return _EXPORTABLE def set_exportable(value): global _EXPORTABLE _EXPORTABLE = value def is_scriptable(): return _SCRIPTABLE def set_scriptable(value): global _SCRIPTABLE _SCRIPTABLE = value ================================================ FILE: models/conv2d_layers.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torch._six import container_abcs from itertools import repeat from functools import partial from typing import Union, List, Tuple, Optional, Callable import numpy as np import math from .activations.config import * def _ntuple(n): def parse(x): if isinstance(x, container_abcs.Iterable): return x return tuple(repeat(x, n)) return parse _single = _ntuple(1) _pair = _ntuple(2) _triple = _ntuple(3) _quadruple = _ntuple(4) def _is_static_pad(kernel_size, stride=1, dilation=1, **_): return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 def _get_padding(kernel_size, stride=1, dilation=1, **_): padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding def _calc_same_pad(i: int, k: int, s: int, d: int): return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) def _same_pad_arg(input_size, kernel_size, stride, dilation): ih, iw = input_size kh, kw = kernel_size pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] def _split_channels(num_chan, num_groups): split = [num_chan // num_groups for _ in range(num_groups)] split[0] += num_chan - sum(split) return split def conv2d_same( x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): ih, iw = x.size()[-2:] kh, kw = weight.size()[-2:] pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) if pad_h > 0 or pad_w > 0: x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) class Conv2dSame(nn.Conv2d): """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions """ # pylint: disable=unused-argument def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super(Conv2dSame, self).__init__( in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) def forward(self, x): return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) class Conv2dSameExport(nn.Conv2d): """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions NOTE: This does not currently work with torch.jit.script """ # pylint: disable=unused-argument def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super(Conv2dSameExport, self).__init__( in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) self.pad = None self.pad_input_size = (0, 0) def forward(self, x): input_size = x.size()[-2:] if self.pad is None: pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) self.pad = nn.ZeroPad2d(pad_arg) self.pad_input_size = input_size else: assert self.pad_input_size == input_size x = self.pad(x) return F.conv2d( x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) def get_padding_value(padding, kernel_size, **kwargs): dynamic = False if isinstance(padding, str): # for any string padding, the padding will be calculated for you, one of three ways padding = padding.lower() if padding == 'same': # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact if _is_static_pad(kernel_size, **kwargs): # static case, no extra overhead padding = _get_padding(kernel_size, **kwargs) else: # dynamic padding padding = 0 dynamic = True elif padding == 'valid': # 'VALID' padding, same as padding=0 padding = 0 else: # Default to PyTorch style 'same'-ish symmetric padding padding = _get_padding(kernel_size, **kwargs) return padding, dynamic def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): padding = kwargs.pop('padding', '') kwargs.setdefault('bias', False) padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) if is_dynamic: if is_exportable(): assert not is_scriptable() return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) else: return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) else: return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) class MixedConv2d(nn.ModuleDict): """ Mixed Grouped Convolution Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, depthwise=False, **kwargs): super(MixedConv2d, self).__init__() kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] num_groups = len(kernel_size) in_splits = _split_channels(in_channels, num_groups) out_splits = _split_channels(out_channels, num_groups) self.in_channels = sum(in_splits) self.out_channels = sum(out_splits) for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): conv_groups = out_ch if depthwise else 1 self.add_module( str(idx), create_conv2d_pad( in_ch, out_ch, k, stride=stride, padding=padding, dilation=dilation, groups=conv_groups, **kwargs) ) self.splits = in_splits def forward(self, x): x_split = torch.split(x, self.splits, 1) x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())] x = torch.cat(x_out, 1) return x def get_condconv_initializer(initializer, num_experts, expert_shape): def condconv_initializer(weight): """CondConv initializer function.""" num_params = np.prod(expert_shape) if (len(weight.shape) != 2 or weight.shape[0] != num_experts or weight.shape[1] != num_params): raise (ValueError( 'CondConv variables must have shape [num_experts, num_params]')) for i in range(num_experts): initializer(weight[i].view(expert_shape)) return condconv_initializer class CondConv2d(nn.Module): """ Conditional Convolution Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: https://github.com/pytorch/pytorch/issues/17983 """ __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): super(CondConv2d, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) padding_val, is_padding_dynamic = get_padding_value( padding, kernel_size, stride=stride, dilation=dilation) self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript self.padding = _pair(padding_val) self.dilation = _pair(dilation) self.groups = groups self.num_experts = num_experts self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size weight_num_param = 1 for wd in self.weight_shape: weight_num_param *= wd self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) if bias: self.bias_shape = (self.out_channels,) self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): init_weight = get_condconv_initializer( partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) init_weight(self.weight) if self.bias is not None: fan_in = np.prod(self.weight_shape[1:]) bound = 1 / math.sqrt(fan_in) init_bias = get_condconv_initializer( partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) init_bias(self.bias) def forward(self, x, routing_weights): B, C, H, W = x.shape weight = torch.matmul(routing_weights, self.weight) new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size weight = weight.view(new_weight_shape) bias = None if self.bias is not None: bias = torch.matmul(routing_weights, self.bias) bias = bias.view(B * self.out_channels) # move batch elements with channels so each batch element can be efficiently convolved with separate kernel x = x.view(1, B * C, H, W) if self.dynamic_padding: out = conv2d_same( x, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups * B) else: out = F.conv2d( x, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups * B) out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) # Literal port (from TF definition) # x = torch.split(x, 1, 0) # weight = torch.split(weight, 1, 0) # if self.bias is not None: # bias = torch.matmul(routing_weights, self.bias) # bias = torch.split(bias, 1, 0) # else: # bias = [None] * B # out = [] # for xi, wi, bi in zip(x, weight, bias): # wi = wi.view(*self.weight_shape) # if bi is not None: # bi = bi.view(*self.bias_shape) # out.append(self.conv_fn( # xi, wi, bi, stride=self.stride, padding=self.padding, # dilation=self.dilation, groups=self.groups)) # out = torch.cat(out, 0) return out def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): assert 'groups' not in kwargs # only use 'depthwise' bool arg if isinstance(kernel_size, list): assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently # We're going to use only lists for defining the MixedConv2d kernel groups, # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) else: depthwise = kwargs.pop('depthwise', False) groups = out_chs if depthwise else 1 if 'num_experts' in kwargs and kwargs['num_experts'] > 0: m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) else: m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) return m ================================================ FILE: models/densenet.py ================================================ import re import torch import torch.nn as nn import torch.nn.functional as F from collections import OrderedDict __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] model_urls = { 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', } def densenet121(pretrained=False, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) if pretrained: # model.load_state_dict(torch.load(model_urls['densenet121'])) # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') state_dict = torch.load(model_urls['densenet121']) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] model.load_state_dict(state_dict) return model def densenet169(pretrained=False, **kwargs): r"""Densenet-169 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) if pretrained: # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') state_dict = torch.load(model_urls['densenet169']) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] model.load_state_dict(state_dict) return model def densenet201(pretrained=False, **kwargs): r"""Densenet-201 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) if pretrained: # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') state_dict = torch.load(model_urls['densenet201']) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] model.load_state_dict(state_dict) return model def densenet161(pretrained=False, **kwargs): r"""Densenet-161 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs) if pretrained: # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') state_dict = torch.load(model_urls['densenet161']) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] model.load_state_dict(state_dict) return model class _DenseLayer(nn.Sequential): def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): super(_DenseLayer, self).__init__() self.add_module('norm1', nn.BatchNorm2d(num_input_features)), self.add_module('relu1', nn.ReLU(inplace=True)), self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), self.add_module('relu2', nn.ReLU(inplace=True)), self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), self.drop_rate = drop_rate def forward(self, x): new_features = super(_DenseLayer, self).forward(x) if self.drop_rate > 0: new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) return torch.cat([x, new_features], 1) class _DenseBlock(nn.Sequential): def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): super(_DenseBlock, self).__init__() for i in range(num_layers): layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) self.add_module('denselayer%d' % (i + 1), layer) class _Transition(nn.Sequential): def __init__(self, num_input_features, num_output_features): super(_Transition, self).__init__() self.add_module('norm', nn.BatchNorm2d(num_input_features)) self.add_module('relu', nn.ReLU(inplace=True)) self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) class DenseNet(nn.Module): r"""Densenet-BC model class, based on `"Densely Connected Convolutional Networks" `_ Args: growth_rate (int) - how many filters to add each layer (`k` in paper) block_config (list of 4 ints) - how many layers in each pooling block num_init_features (int) - the number of filters to learn in the first convolution layer bn_size (int) - multiplicative factor for number of bottle neck layers (i.e. bn_size * k features in the bottleneck layer) drop_rate (float) - dropout rate after each dense layer num_classes (int) - number of classification classes """ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): super(DenseNet, self).__init__() # First convolution self.features = nn.Sequential(OrderedDict([ ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), ('norm0', nn.BatchNorm2d(num_init_features)), ('relu0', nn.ReLU(inplace=True)), ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), ])) # Each denseblock num_features = num_init_features for i, num_layers in enumerate(block_config): block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) self.features.add_module('denseblock%d' % (i + 1), block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) self.features.add_module('transition%d' % (i + 1), trans) num_features = num_features // 2 # Final batch norm self.features.add_module('norm5', nn.BatchNorm2d(num_features)) self.feature_num = num_features self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Linear layer self.classifier = nn.Linear(num_features, num_classes) # Official init from torch repo. for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal(m.weight.data) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): m.bias.data.zero_() def forward(self, x): features = self.features(x) x = F.relu(features, inplace=True) out = self.avgpool(x).view(features.size(0), -1) return out, x.detach() ================================================ FILE: models/efficientnet_builder.py ================================================ import re from copy import deepcopy from .conv2d_layers import * from .activations import * # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per # papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) # NOTE: momentum varies btw .99 and .9997 depending on source # .99 in official TF TPU impl # .9997 (/w .999 in search space) for paper # # PyTorch defaults are momentum = .1, eps = 1e-5 # BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 BN_EPS_TF_DEFAULT = 1e-3 _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) def get_bn_args_tf(): return _BN_ARGS_TF.copy() def resolve_bn_args(kwargs): bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} bn_momentum = kwargs.pop('bn_momentum', None) if bn_momentum is not None: bn_args['momentum'] = bn_momentum bn_eps = kwargs.pop('bn_eps', None) if bn_eps is not None: bn_args['eps'] = bn_eps return bn_args _SE_ARGS_DEFAULT = dict( gate_fn=sigmoid, act_layer=None, # None == use containing block's activation layer reduce_mid=False, divisor=1) def resolve_se_args(kwargs, in_chs, act_layer=None): se_kwargs = kwargs.copy() if kwargs is not None else {} # fill in args that aren't specified with the defaults for k, v in _SE_ARGS_DEFAULT.items(): se_kwargs.setdefault(k, v) # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch if not se_kwargs.pop('reduce_mid'): se_kwargs['reduced_base_chs'] = in_chs # act_layer override, if it remains None, the containing block's act_layer will be used if se_kwargs['act_layer'] is None: assert act_layer is not None se_kwargs['act_layer'] = act_layer return se_kwargs def resolve_act_layer(kwargs, default='relu'): act_layer = kwargs.pop('act_layer', default) if isinstance(act_layer, str): act_layer = get_act_layer(act_layer) return act_layer def make_divisible(v: int, divisor: int = 8, min_value: int = None): min_value = min_value or divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) if new_v < 0.9 * v: # ensure round down does not go down by more than 10%. new_v += divisor return new_v def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): """Round number of filters based on depth multiplier.""" if not multiplier: return channels channels *= multiplier return make_divisible(channels, divisor, channel_min) def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): """Apply drop connect.""" if not training: return inputs keep_prob = 1 - drop_connect_rate random_tensor = keep_prob + torch.rand( (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) random_tensor.floor_() # binarize output = inputs.div(keep_prob) * random_tensor return output class SqueezeExcite(nn.Module): def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1): super(SqueezeExcite, self).__init__() self.gate_fn = gate_fn reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) self.act1 = act_layer(inplace=True) self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) def forward(self, x): # tensor.view + mean bad for ONNX export (produces mess of gather ops that break TensorRT) x_se = self.avg_pool(x) x_se = self.conv_reduce(x_se) x_se = self.act1(x_se) x_se = self.conv_expand(x_se) x = x * self.gate_fn(x_se) return x class ConvBnAct(nn.Module): def __init__(self, in_chs, out_chs, kernel_size, stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): super(ConvBnAct, self).__init__() assert stride in [1, 2] norm_kwargs = norm_kwargs or {} self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) self.bn1 = norm_layer(out_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) def forward(self, x): x = self.conv(x) x = self.bn1(x) x = self.act1(x) return x class DepthwiseSeparableConv(nn.Module): """ DepthwiseSeparable block Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion factor of 1.0. This is an alternative to having a IR with optional first pw conv. """ def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): super(DepthwiseSeparableConv, self).__init__() assert stride in [1, 2] norm_kwargs = norm_kwargs or {} self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip self.drop_connect_rate = drop_connect_rate self.conv_dw = select_conv2d( in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) self.bn1 = norm_layer(in_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) # Squeeze-and-excitation if se_ratio is not None and se_ratio > 0.: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) else: self.se = nn.Identity() self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn2 = norm_layer(out_chs, **norm_kwargs) self.act2 = act_layer(inplace=True) if pw_act else nn.Identity() def forward(self, x): residual = x x = self.conv_dw(x) x = self.bn1(x) x = self.act1(x) x = self.se(x) x = self.conv_pw(x) x = self.bn2(x) x = self.act2(x) if self.has_residual: if self.drop_connect_rate > 0.: x = drop_connect(x, self.training, self.drop_connect_rate) x += residual return x class InvertedResidual(nn.Module): """ Inverted residual block w/ optional SE""" def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, conv_kwargs=None, drop_connect_rate=0.): super(InvertedResidual, self).__init__() norm_kwargs = norm_kwargs or {} conv_kwargs = conv_kwargs or {} mid_chs: int = make_divisible(in_chs * exp_ratio) self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_connect_rate = drop_connect_rate # Point-wise expansion self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) self.bn1 = norm_layer(mid_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) # Depth-wise convolution self.conv_dw = select_conv2d( mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs) self.bn2 = norm_layer(mid_chs, **norm_kwargs) self.act2 = act_layer(inplace=True) # Squeeze-and-excitation if se_ratio is not None and se_ratio > 0.: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) else: self.se = nn.Identity() # for jit.script compat # Point-wise linear projection self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) self.bn3 = norm_layer(out_chs, **norm_kwargs) def forward(self, x): residual = x # Point-wise expansion x = self.conv_pw(x) x = self.bn1(x) x = self.act1(x) # Depth-wise convolution x = self.conv_dw(x) x = self.bn2(x) x = self.act2(x) # Squeeze-and-excitation x = self.se(x) # Point-wise linear projection x = self.conv_pwl(x) x = self.bn3(x) if self.has_residual: if self.drop_connect_rate > 0.: x = drop_connect(x, self.training, self.drop_connect_rate) x += residual return x class CondConvResidual(InvertedResidual): """ Inverted residual block w/ CondConv routing""" def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, num_experts=0, drop_connect_rate=0.): self.num_experts = num_experts conv_kwargs = dict(num_experts=self.num_experts) super(CondConvResidual, self).__init__( in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, drop_connect_rate=drop_connect_rate) self.routing_fn = nn.Linear(in_chs, self.num_experts) def forward(self, x): residual = x # CondConv routing pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) # Point-wise expansion x = self.conv_pw(x, routing_weights) x = self.bn1(x) x = self.act1(x) # Depth-wise convolution x = self.conv_dw(x, routing_weights) x = self.bn2(x) x = self.act2(x) # Squeeze-and-excitation x = self.se(x) # Point-wise linear projection x = self.conv_pwl(x, routing_weights) x = self.bn3(x) if self.has_residual: if self.drop_connect_rate > 0.: x = drop_connect(x, self.training, self.drop_connect_rate) x += residual return x class EdgeResidual(nn.Module): """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride""" def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): super(EdgeResidual, self).__init__() norm_kwargs = norm_kwargs or {} mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio) self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_connect_rate = drop_connect_rate # Expansion convolution self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) self.bn1 = norm_layer(mid_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) # Squeeze-and-excitation if se_ratio is not None and se_ratio > 0.: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) else: self.se = nn.Identity() # Point-wise linear projection self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs) def forward(self, x): residual = x # Expansion convolution x = self.conv_exp(x) x = self.bn1(x) x = self.act1(x) # Squeeze-and-excitation x = self.se(x) # Point-wise linear projection x = self.conv_pwl(x) x = self.bn2(x) if self.has_residual: if self.drop_connect_rate > 0.: x = drop_connect(x, self.training, self.drop_connect_rate) x += residual return x class EfficientNetBuilder: """ Build Trunk Blocks for Efficient/Mobile Networks This ended up being somewhat of a cross between https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py and https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py """ def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, pad_type='', act_layer=None, se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): self.channel_multiplier = channel_multiplier self.channel_divisor = channel_divisor self.channel_min = channel_min self.pad_type = pad_type self.act_layer = act_layer self.se_kwargs = se_kwargs self.norm_layer = norm_layer self.norm_kwargs = norm_kwargs self.drop_connect_rate = drop_connect_rate # updated during build self.in_chs = None self.block_idx = 0 self.block_count = 0 def _round_channels(self, chs): return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) def _make_block(self, ba): bt = ba.pop('block_type') ba['in_chs'] = self.in_chs ba['out_chs'] = self._round_channels(ba['out_chs']) if 'fake_in_chs' in ba and ba['fake_in_chs']: # FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) ba['norm_layer'] = self.norm_layer ba['norm_kwargs'] = self.norm_kwargs ba['pad_type'] = self.pad_type # block act fn overrides the model default ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer assert ba['act_layer'] is not None if bt == 'ir': ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count ba['se_kwargs'] = self.se_kwargs if ba.get('num_experts', 0) > 0: block = CondConvResidual(**ba) else: block = InvertedResidual(**ba) elif bt == 'ds' or bt == 'dsa': ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count ba['se_kwargs'] = self.se_kwargs block = DepthwiseSeparableConv(**ba) elif bt == 'er': ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count ba['se_kwargs'] = self.se_kwargs block = EdgeResidual(**ba) elif bt == 'cn': block = ConvBnAct(**ba) else: assert False, 'Uknkown block type (%s) while building model.' % bt self.in_chs = ba['out_chs'] # update in_chs for arg of next block return block def _make_stack(self, stack_args): blocks = [] # each stack (stage) contains a list of block arguments for i, ba in enumerate(stack_args): if i >= 1: # only the first block in any stack can have a stride > 1 ba['stride'] = 1 block = self._make_block(ba) blocks.append(block) self.block_idx += 1 # incr global idx (across all stacks) return nn.Sequential(*blocks) def __call__(self, in_chs, block_args): """ Build the blocks Args: in_chs: Number of input-channels passed to first block block_args: A list of lists, outer list defines stages, inner list contains strings defining block configuration(s) Return: List of block stacks (each stack wrapped in nn.Sequential) """ self.in_chs = in_chs self.block_count = sum([len(x) for x in block_args]) self.block_idx = 0 blocks = [] # outer list of block_args defines the stacks ('stages' by some conventions) for stack_idx, stack in enumerate(block_args): assert isinstance(stack, list) stack = self._make_stack(stack) blocks.append(stack) return blocks def _parse_ksize(ss): if ss.isdigit(): return int(ss) else: return [int(k) for k in ss.split('.')] def _decode_block_str(block_str): """ Decode block definition string Gets a list of block arg (dicts) through a string notation of arguments. E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip All args can exist in any order with the exception of the leading string which is assumed to indicate the block type. leading string - block type ( ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) r - number of repeat blocks, k - kernel size, s - strides (1-9), e - expansion ratio, c - output channels, se - squeeze/excitation ratio n - activation fn ('re', 'r6', 'hs', or 'sw') Args: block_str: a string representation of block arguments. Returns: A list of block args (dicts) Raises: ValueError: if the string def not properly specified (TODO) """ assert isinstance(block_str, str) ops = block_str.split('_') block_type = ops[0] # take the block type off the front ops = ops[1:] options = {} noskip = False for op in ops: # string options being checked on individual basis, combine if they grow if op == 'noskip': noskip = True elif op.startswith('n'): # activation fn key = op[0] v = op[1:] if v == 're': value = get_act_layer('relu') elif v == 'r6': value = get_act_layer('relu6') elif v == 'hs': value = get_act_layer('hard_swish') elif v == 'sw': value = get_act_layer('swish') else: continue options[key] = value else: # all numeric options splits = re.split(r'(\d.*)', op) if len(splits) >= 2: key, value = splits[:2] options[key] = value # if act_layer is None, the model default (passed to model init) will be used act_layer = options['n'] if 'n' in options else None exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def num_repeat = int(options['r']) # each type of block has different valid arguments, fill accordingly if block_type == 'ir': block_args = dict( block_type=block_type, dw_kernel_size=_parse_ksize(options['k']), exp_kernel_size=exp_kernel_size, pw_kernel_size=pw_kernel_size, out_chs=int(options['c']), exp_ratio=float(options['e']), se_ratio=float(options['se']) if 'se' in options else None, stride=int(options['s']), act_layer=act_layer, noskip=noskip, ) if 'cc' in options: block_args['num_experts'] = int(options['cc']) elif block_type == 'ds' or block_type == 'dsa': block_args = dict( block_type=block_type, dw_kernel_size=_parse_ksize(options['k']), pw_kernel_size=pw_kernel_size, out_chs=int(options['c']), se_ratio=float(options['se']) if 'se' in options else None, stride=int(options['s']), act_layer=act_layer, pw_act=block_type == 'dsa', noskip=block_type == 'dsa' or noskip, ) elif block_type == 'er': block_args = dict( block_type=block_type, exp_kernel_size=_parse_ksize(options['k']), pw_kernel_size=pw_kernel_size, out_chs=int(options['c']), exp_ratio=float(options['e']), fake_in_chs=fake_in_chs, se_ratio=float(options['se']) if 'se' in options else None, stride=int(options['s']), act_layer=act_layer, noskip=noskip, ) elif block_type == 'cn': block_args = dict( block_type=block_type, kernel_size=int(options['k']), out_chs=int(options['c']), stride=int(options['s']), act_layer=act_layer, ) else: assert False, 'Unknown block type (%s)' % block_type return block_args, num_repeat def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): """ Per-stage depth scaling Scales the block repeats in each stage. This depth scaling impl maintains compatibility with the EfficientNet scaling method, while allowing sensible scaling for other models that may have multiple block arg definitions in each stage. """ # We scale the total repeat count for each stage, there may be multiple # block arg defs per stage so we need to sum. num_repeat = sum(repeats) if depth_trunc == 'round': # Truncating to int by rounding allows stages with few repeats to remain # proportionally smaller for longer. This is a good choice when stage definitions # include single repeat stages that we'd prefer to keep that way as long as possible num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) else: # The default for EfficientNet truncates repeats to int via 'ceil'. # Any multiplier > 1.0 will result in an increased depth for every stage. num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) # Proportionally distribute repeat count scaling to each block definition in the stage. # Allocation is done in reverse as it results in the first block being less likely to be scaled. # The first block makes less sense to repeat in most of the arch definitions. repeats_scaled = [] for r in repeats[::-1]: rs = max(1, round((r / num_repeat * num_repeat_scaled))) repeats_scaled.append(rs) num_repeat -= r num_repeat_scaled -= rs repeats_scaled = repeats_scaled[::-1] # Apply the calculated scaling to each block arg in the stage sa_scaled = [] for ba, rep in zip(stack_args, repeats_scaled): sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) return sa_scaled def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): arch_args = [] for stack_idx, block_strings in enumerate(arch_def): assert isinstance(block_strings, list) stack_args = [] repeats = [] for block_str in block_strings: assert isinstance(block_str, str) ba, rep = _decode_block_str(block_str) if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: ba['num_experts'] *= experts_multiplier stack_args.append(ba) repeats.append(rep) if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) else: arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) return arch_args def initialize_weight_goog(m, n='', fix_group_fanout=True): # weight init as per Tensorflow Official impl # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py if isinstance(m, CondConv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels if fix_group_fanout: fan_out //= m.groups init_weight_fn = get_condconv_initializer( lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) init_weight_fn(m.weight) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels if fix_group_fanout: fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1.0) m.bias.data.zero_() elif isinstance(m, nn.Linear): fan_out = m.weight.size(0) # fan-out fan_in = 0 if 'routing_fn' in n: fan_in = m.weight.size(1) init_range = 1.0 / math.sqrt(fan_in + fan_out) m.weight.data.uniform_(-init_range, init_range) m.bias.data.zero_() def initialize_weight_default(m, n=''): if isinstance(m, CondConv2d): init_fn = get_condconv_initializer(partial( nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) init_fn(m.weight) elif isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1.0) m.bias.data.zero_() elif isinstance(m, nn.Linear): nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') ================================================ FILE: models/gen_efficientnet.py ================================================ """ Generic Efficient Networks A generic MobileNet class with building blocks to support a variety of models: * EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports) - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 * EfficientNet-Lite * MixNet (Small, Medium, and Large) - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 * MNasNet B1, A1 (SE), Small - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 * FBNet-C - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 * Single-Path NAS Pixel1 - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 * And likely more... Hacked together by Ross Wightman """ import torch.nn as nn import torch.nn.functional as F from .helpers import load_pretrained from .efficientnet_builder import * __all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', 'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', 'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el', 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', 'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap', 'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap', 'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns', 'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns', 'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475', 'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el', 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', 'tf_efficientnet_lite4', 'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l'] model_urls = { 'mnasnet_050': None, 'mnasnet_075': None, 'mnasnet_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', 'mnasnet_140': None, 'mnasnet_small': None, 'semnasnet_050': None, 'semnasnet_075': None, 'semnasnet_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', 'semnasnet_140': None, 'fbnetc_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', 'spnasnet_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', 'efficientnet_b0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', 'efficientnet_b1': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', 'efficientnet_b2': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', 'efficientnet_b3': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra-a5e2fbc7.pth', 'efficientnet_b4': None, 'efficientnet_b5': None, 'efficientnet_b6': None, 'efficientnet_b7': None, 'efficientnet_b8': None, 'efficientnet_l2': None, 'efficientnet_es': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', 'efficientnet_em': None, 'efficientnet_el': None, 'efficientnet_cc_b0_4e': None, 'efficientnet_cc_b0_8e': None, 'efficientnet_cc_b1_8e': None, 'efficientnet_lite0': None, 'efficientnet_lite1': None, 'efficientnet_lite2': None, 'efficientnet_lite3': None, 'efficientnet_lite4': None, 'tf_efficientnet_b0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', 'tf_efficientnet_b1': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', 'tf_efficientnet_b2': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', 'tf_efficientnet_b3': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', 'tf_efficientnet_b4': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', 'tf_efficientnet_b5': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', 'tf_efficientnet_b6': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', 'tf_efficientnet_b7': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', 'tf_efficientnet_b8': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', 'tf_efficientnet_b0_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', 'tf_efficientnet_b1_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', 'tf_efficientnet_b2_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', 'tf_efficientnet_b3_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', 'tf_efficientnet_b4_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', 'tf_efficientnet_b5_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', 'tf_efficientnet_b6_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', 'tf_efficientnet_b7_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', 'tf_efficientnet_b8_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', 'tf_efficientnet_b0_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', 'tf_efficientnet_b1_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', 'tf_efficientnet_b2_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', 'tf_efficientnet_b3_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', 'tf_efficientnet_b4_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', 'tf_efficientnet_b5_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', 'tf_efficientnet_b6_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', 'tf_efficientnet_b7_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', 'tf_efficientnet_l2_ns_475': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', 'tf_efficientnet_l2_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', 'tf_efficientnet_es': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', 'tf_efficientnet_em': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', 'tf_efficientnet_el': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', 'tf_efficientnet_cc_b0_4e': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', 'tf_efficientnet_cc_b0_8e': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', 'tf_efficientnet_cc_b1_8e': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', 'tf_efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', 'tf_efficientnet_lite1': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', 'tf_efficientnet_lite2': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', 'tf_efficientnet_lite3': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', 'tf_efficientnet_lite4': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', 'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', 'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', 'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', 'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', 'tf_mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', 'tf_mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth', 'tf_mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth', } class GenEfficientNet(nn.Module): """ Generic EfficientNets An implementation of mobile optimized networks that covers: * EfficientNet (B0-B8, L2, CondConv, EdgeTPU) * MixNet (Small, Medium, and Large, XL) * MNASNet A1, B1, and small * FBNet C * Single-Path NAS Pixel1 """ def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False, channel_multiplier=1.0, channel_divisor=8, channel_min=None, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): super(GenEfficientNet, self).__init__() self.drop_rate = drop_rate if not fix_stem: stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) in_chs = stem_size builder = EfficientNetBuilder( channel_multiplier, channel_divisor, channel_min, pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate) self.blocks = nn.Sequential(*builder(in_chs, block_args)) in_chs = builder.in_chs self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type) self.bn2 = norm_layer(num_features, **norm_kwargs) self.act2 = act_layer(inplace=True) self.global_pool = nn.AdaptiveAvgPool2d(1) self.classifier = nn.Linear(num_features, num_classes) for n, m in self.named_modules(): if weight_init == 'goog': initialize_weight_goog(m, n) else: initialize_weight_default(m, n) self.feature_num = num_features def features(self, x): x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) x = self.blocks(x) x = self.conv_head(x) x = self.bn2(x) x = self.act2(x) return x def as_sequential(self): layers = [self.conv_stem, self.bn1, self.act1] layers.extend(self.blocks) layers.extend([ self.conv_head, self.bn2, self.act2, self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) def forward(self, x): feature = self.features(x) x = self.global_pool(feature) x = x.flatten(1) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return x, feature.detach() def _create_model(model_kwargs, variant, pretrained=False): as_sequential = model_kwargs.pop('as_sequential', False) model = GenEfficientNet(**model_kwargs) if pretrained: load_pretrained(model, model_urls[variant]) if as_sequential: model = model.as_sequential() return model def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a mnasnet-a1 model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet Paper: https://arxiv.org/pdf/1807.11626.pdf. Args: channel_multiplier: multiplier to number of channels per layer. """ arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_noskip'], # stage 1, 112x112 in ['ir_r2_k3_s2_e6_c24'], # stage 2, 56x56 in ['ir_r3_k5_s2_e3_c40_se0.25'], # stage 3, 28x28 in ['ir_r4_k3_s2_e6_c80'], # stage 4, 14x14in ['ir_r2_k3_s1_e6_c112_se0.25'], # stage 5, 14x14in ['ir_r3_k5_s2_e6_c160_se0.25'], # stage 6, 7x7 in ['ir_r1_k3_s1_e6_c320'], ] model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a mnasnet-b1 model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet Paper: https://arxiv.org/pdf/1807.11626.pdf. Args: channel_multiplier: multiplier to number of channels per layer. """ arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_c16_noskip'], # stage 1, 112x112 in ['ir_r3_k3_s2_e3_c24'], # stage 2, 56x56 in ['ir_r3_k5_s2_e3_c40'], # stage 3, 28x28 in ['ir_r3_k5_s2_e6_c80'], # stage 4, 14x14in ['ir_r2_k3_s1_e6_c96'], # stage 5, 14x14in ['ir_r4_k5_s2_e6_c192'], # stage 6, 7x7 in ['ir_r1_k3_s1_e6_c320_noskip'] ] model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a mnasnet-b1 model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet Paper: https://arxiv.org/pdf/1807.11626.pdf. Args: channel_multiplier: multiplier to number of channels per layer. """ arch_def = [ ['ds_r1_k3_s1_c8'], ['ir_r1_k3_s2_e3_c16'], ['ir_r2_k3_s2_e6_c16'], ['ir_r4_k5_s2_e6_c32_se0.25'], ['ir_r3_k3_s1_e6_c32_se0.25'], ['ir_r3_k5_s2_e6_c88_se0.25'], ['ir_r1_k3_s1_e6_c144'] ] model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=8, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """ FBNet-C Paper: https://arxiv.org/abs/1812.03443 Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, it was used to confirm some building block details """ arch_def = [ ['ir_r1_k3_s1_e1_c16'], ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], ['ir_r4_k5_s2_e6_c184'], ['ir_r1_k3_s1_e6_c352'], ] model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=16, num_features=1984, # paper suggests this, but is not 100% clear channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates the Single-Path NAS model from search targeted for Pixel1 phone. Paper: https://arxiv.org/abs/1904.02877 Args: channel_multiplier: multiplier to number of channels per layer. """ arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_c16_noskip'], # stage 1, 112x112 in ['ir_r3_k3_s2_e3_c24'], # stage 2, 56x56 in ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], # stage 3, 28x28 in ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], # stage 4, 14x14in ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], # stage 5, 14x14in ['ir_r4_k5_s2_e6_c192'], # stage 6, 7x7 in ['ir_r1_k3_s1_e6_c320_noskip'] ] model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): """Creates an EfficientNet model. Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py Paper: https://arxiv.org/abs/1905.11946 EfficientNet params name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 'efficientnet-b8': (2.2, 3.6, 672, 0.5), Args: channel_multiplier: multiplier to number of channels per layer depth_multiplier: multiplier to number of repeats per stage """ arch_def = [ ['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25'], ['ir_r2_k5_s2_e6_c40_se0.25'], ['ir_r3_k3_s2_e6_c80_se0.25'], ['ir_r3_k5_s1_e6_c112_se0.25'], ['ir_r4_k5_s2_e6_c192_se0.25'], ['ir_r1_k3_s1_e6_c320_se0.25'], ] model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier), num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'swish'), norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): arch_def = [ # NOTE `fc` is present to override a mismatch between stem channels and in chs not # present in other models ['er_r1_k3_s1_e4_c24_fc24_noskip'], ['er_r2_k3_s2_e8_c32'], ['er_r4_k3_s2_e8_c48'], ['ir_r5_k5_s2_e8_c96'], ['ir_r4_k5_s1_e8_c144'], ['ir_r2_k5_s2_e8_c192'], ] model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier), num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_efficientnet_condconv( variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): """Creates an efficientnet-condconv model.""" arch_def = [ ['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25'], ['ir_r2_k5_s2_e6_c40_se0.25'], ['ir_r3_k3_s2_e6_c80_se0.25'], ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], ] model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'swish'), norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): """Creates an EfficientNet-Lite model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite Paper: https://arxiv.org/abs/1905.11946 EfficientNet params name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), Args: channel_multiplier: multiplier to number of channels per layer depth_multiplier: multiplier to number of repeats per stage """ arch_def = [ ['ds_r1_k3_s1_e1_c16'], ['ir_r2_k3_s2_e6_c24'], ['ir_r2_k5_s2_e6_c40'], ['ir_r3_k3_s2_e6_c80'], ['ir_r3_k5_s1_e6_c112'], ['ir_r4_k5_s2_e6_c192'], ['ir_r1_k3_s1_e6_c320'], ] model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), num_features=1280, stem_size=32, fix_stem=True, channel_multiplier=channel_multiplier, act_layer=nn.ReLU6, norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MixNet Small model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet Paper: https://arxiv.org/abs/1907.09595 """ arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16'], # relu # stage 1, 112x112 in ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu # stage 2, 56x56 in ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish # stage 3, 28x28 in ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish # stage 4, 14x14in ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish # stage 5, 14x14in ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish # 7x7 ] model_kwargs = dict( block_args=decode_arch_def(arch_def), num_features=1536, stem_size=16, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): """Creates a MixNet Medium-Large model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet Paper: https://arxiv.org/abs/1907.09595 """ arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c24'], # relu # stage 1, 112x112 in ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu # stage 2, 56x56 in ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish # stage 3, 28x28 in ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish # stage 4, 14x14in ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish # stage 5, 14x14in ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish # 7x7 ] model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), num_features=1536, stem_size=24, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def mnasnet_050(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 0.5. """ model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) return model def mnasnet_075(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 0.75. """ model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) return model def mnasnet_100(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.0. """ model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model def mnasnet_b1(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.0. """ return mnasnet_100(pretrained, **kwargs) def mnasnet_140(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.4 """ model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) return model def semnasnet_050(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) return model def semnasnet_075(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) return model def semnasnet_100(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model def mnasnet_a1(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ return semnasnet_100(pretrained, **kwargs) def semnasnet_140(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) return model def mnasnet_small(pretrained=False, **kwargs): """ MNASNet Small, depth multiplier of 1.0. """ model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) return model def fbnetc_100(pretrained=False, **kwargs): """ FBNet-C """ if pretrained: # pretrained model trained with non-default BN epsilon kwargs['bn_eps'] = BN_EPS_TF_DEFAULT model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) return model def spnasnet_100(pretrained=False, **kwargs): """ Single-Path NAS Pixel1""" model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model def efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0 """ # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def efficientnet_b1(pretrained=False, **kwargs): """ EfficientNet-B1 """ # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def efficientnet_b2(pretrained=False, **kwargs): """ EfficientNet-B2 """ # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model def efficientnet_b3(pretrained=False, **kwargs): """ EfficientNet-B3 """ # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model def efficientnet_b4(pretrained=False, **kwargs): """ EfficientNet-B4 """ # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) return model def efficientnet_b5(pretrained=False, **kwargs): """ EfficientNet-B5 """ # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) return model def efficientnet_b6(pretrained=False, **kwargs): """ EfficientNet-B6 """ # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) return model def efficientnet_b7(pretrained=False, **kwargs): """ EfficientNet-B7 """ # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) return model def efficientnet_b8(pretrained=False, **kwargs): """ EfficientNet-B8 """ # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) return model def efficientnet_l2(pretrained=False, **kwargs): """ EfficientNet-L2. """ # NOTE for train, drop_rate should be 0.5 model = _gen_efficientnet( 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) return model def efficientnet_es(pretrained=False, **kwargs): """ EfficientNet-Edge Small. """ model = _gen_efficientnet_edge( 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def efficientnet_em(pretrained=False, **kwargs): """ EfficientNet-Edge-Medium. """ model = _gen_efficientnet_edge( 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def efficientnet_el(pretrained=False, **kwargs): """ EfficientNet-Edge-Large. """ model = _gen_efficientnet_edge( 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model def efficientnet_cc_b0_4e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B0 w/ 8 Experts """ # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 model = _gen_efficientnet_condconv( 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def efficientnet_cc_b0_8e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B0 w/ 8 Experts """ # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 model = _gen_efficientnet_condconv( 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, pretrained=pretrained, **kwargs) return model def efficientnet_cc_b1_8e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B1 w/ 8 Experts """ # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 model = _gen_efficientnet_condconv( 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, pretrained=pretrained, **kwargs) return model def efficientnet_lite0(pretrained=False, **kwargs): """ EfficientNet-Lite0 """ model = _gen_efficientnet_lite( 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def efficientnet_lite1(pretrained=False, **kwargs): """ EfficientNet-Lite1 """ model = _gen_efficientnet_lite( 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def efficientnet_lite2(pretrained=False, **kwargs): """ EfficientNet-Lite2 """ model = _gen_efficientnet_lite( 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model def efficientnet_lite3(pretrained=False, **kwargs): """ EfficientNet-Lite3 """ model = _gen_efficientnet_lite( 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model def efficientnet_lite4(pretrained=False, **kwargs): """ EfficientNet-Lite4 """ model = _gen_efficientnet_lite( 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0 AutoAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b1(pretrained=False, **kwargs): """ EfficientNet-B1 AutoAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b2(pretrained=False, **kwargs): """ EfficientNet-B2 AutoAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b3(pretrained=False, **kwargs): """ EfficientNet-B3 AutoAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b4(pretrained=False, **kwargs): """ EfficientNet-B4 AutoAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b5(pretrained=False, **kwargs): """ EfficientNet-B5 RandAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b6(pretrained=False, **kwargs): """ EfficientNet-B6 AutoAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b7(pretrained=False, **kwargs): """ EfficientNet-B7 RandAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b8(pretrained=False, **kwargs): """ EfficientNet-B8 RandAug. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b0_ap(pretrained=False, **kwargs): """ EfficientNet-B0 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b1_ap(pretrained=False, **kwargs): """ EfficientNet-B1 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b2_ap(pretrained=False, **kwargs): """ EfficientNet-B2 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b3_ap(pretrained=False, **kwargs): """ EfficientNet-B3 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b4_ap(pretrained=False, **kwargs): """ EfficientNet-B4 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b5_ap(pretrained=False, **kwargs): """ EfficientNet-B5 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b6_ap(pretrained=False, **kwargs): """ EfficientNet-B6 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ # NOTE for train, drop_rate should be 0.5 kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b7_ap(pretrained=False, **kwargs): """ EfficientNet-B7 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ # NOTE for train, drop_rate should be 0.5 kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b8_ap(pretrained=False, **kwargs): """ EfficientNet-B8 AdvProp. Tensorflow compatible variant Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) """ # NOTE for train, drop_rate should be 0.5 kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b0_ns(pretrained=False, **kwargs): """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b1_ns(pretrained=False, **kwargs): """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b2_ns(pretrained=False, **kwargs): """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b3_ns(pretrained=False, **kwargs): """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b4_ns(pretrained=False, **kwargs): """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b5_ns(pretrained=False, **kwargs): """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b6_ns(pretrained=False, **kwargs): """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ # NOTE for train, drop_rate should be 0.5 kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) return model def tf_efficientnet_b7_ns(pretrained=False, **kwargs): """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ # NOTE for train, drop_rate should be 0.5 kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) return model def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ # NOTE for train, drop_rate should be 0.5 kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) return model def tf_efficientnet_l2_ns(pretrained=False, **kwargs): """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) """ # NOTE for train, drop_rate should be 0.5 kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) return model def tf_efficientnet_es(pretrained=False, **kwargs): """ EfficientNet-Edge Small. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def tf_efficientnet_em(pretrained=False, **kwargs): """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def tf_efficientnet_el(pretrained=False, **kwargs): """ EfficientNet-Edge-Large. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B0 w/ 4 Experts """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_condconv( 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B0 w/ 8 Experts """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_condconv( 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B1 w/ 8 Experts """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_condconv( 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_lite0(pretrained=False, **kwargs): """ EfficientNet-Lite0. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_lite( 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def tf_efficientnet_lite1(pretrained=False, **kwargs): """ EfficientNet-Lite1. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_lite( 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def tf_efficientnet_lite2(pretrained=False, **kwargs): """ EfficientNet-Lite2. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_lite( 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model def tf_efficientnet_lite3(pretrained=False, **kwargs): """ EfficientNet-Lite3. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_lite( 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) return model def tf_efficientnet_lite4(pretrained=False, **kwargs): """ EfficientNet-Lite4. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_lite( 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) return model def mixnet_s(pretrained=False, **kwargs): """Creates a MixNet Small model. """ # NOTE for train set drop_rate=0.2 model = _gen_mixnet_s( 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) return model def mixnet_m(pretrained=False, **kwargs): """Creates a MixNet Medium model. """ # NOTE for train set drop_rate=0.25 model = _gen_mixnet_m( 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) return model def mixnet_l(pretrained=False, **kwargs): """Creates a MixNet Large model. """ # NOTE for train set drop_rate=0.25 model = _gen_mixnet_m( 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) return model def mixnet_xl(pretrained=False, **kwargs): """Creates a MixNet Extra-Large model. Not a paper spec, experimental def by RW w/ depth scaling. """ # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 model = _gen_mixnet_m( 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) return model def mixnet_xxl(pretrained=False, **kwargs): """Creates a MixNet Double Extra Large model. Not a paper spec, experimental def by RW w/ depth scaling. """ # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 model = _gen_mixnet_m( 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) return model def tf_mixnet_s(pretrained=False, **kwargs): """Creates a MixNet Small model. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_s( 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) return model def tf_mixnet_m(pretrained=False, **kwargs): """Creates a MixNet Medium model. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_m( 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) return model def tf_mixnet_l(pretrained=False, **kwargs): """Creates a MixNet Large model. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_m( 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) return model ================================================ FILE: models/helpers.py ================================================ import torch import os from collections import OrderedDict try: from torch.hub import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_url as load_state_dict_from_url def load_checkpoint(model, checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path): print("=> Loading checkpoint '{}'".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: new_state_dict = OrderedDict() for k, v in checkpoint['state_dict'].items(): if k.startswith('module'): name = k[7:] # remove `module.` else: name = k new_state_dict[name] = v model.load_state_dict(new_state_dict) else: model.load_state_dict(checkpoint) print("=> Loaded checkpoint '{}'".format(checkpoint_path)) else: print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() def load_pretrained(model, url, filter_fn=None, strict=True): if not url: print("=> Warning: Pretrained model URL is empty, using random initialization.") return state_dict = torch.load(url, map_location='cpu') input_conv = 'conv_stem' classifier = 'classifier' in_chans = getattr(model, input_conv).weight.shape[1] num_classes = getattr(model, classifier).weight.shape[0] input_conv_weight = input_conv + '.weight' pretrained_in_chans = state_dict[input_conv_weight].shape[1] if in_chans != pretrained_in_chans: if in_chans == 1: print('=> Converting pretrained input conv {} from {} to 1 channel'.format( input_conv_weight, pretrained_in_chans)) conv1_weight = state_dict[input_conv_weight] state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) else: print('=> Discarding pretrained input conv {} since input channel count != {}'.format( input_conv_weight, pretrained_in_chans)) del state_dict[input_conv_weight] strict = False classifier_weight = classifier + '.weight' pretrained_num_classes = state_dict[classifier_weight].shape[0] if num_classes != pretrained_num_classes: print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) del state_dict[classifier_weight] del state_dict[classifier + '.bias'] strict = False if filter_fn is not None: state_dict = filter_fn(state_dict) model.load_state_dict(state_dict, strict=strict) ================================================ FILE: models/mobilenetv3.py ================================================ """ MobileNet-V3 A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by Ross Wightman """ import torch.nn as nn import torch.nn.functional as F from .helpers import load_pretrained from .efficientnet_builder import * __all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100', 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100', 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100', 'mobilenetv3_large_125'] model_urls = { 'mobilenetv3_rw': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', 'mobilenetv3_large_075': None, 'mobilenetv3_large_100': None, 'mobilenetv3_large_125': None, 'mobilenetv3_large_minimal_100': None, 'mobilenetv3_small_075': None, 'mobilenetv3_small_100': None, 'mobilenetv3_small_minimal_100': None, 'tf_mobilenetv3_large_075': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', 'tf_mobilenetv3_large_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', 'tf_mobilenetv3_large_minimal_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', 'tf_mobilenetv3_small_075': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', 'tf_mobilenetv3_small_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', 'tf_mobilenetv3_small_minimal_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', } class MobileNetV3(nn.Module): """ MobileNet-V3 A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the head convolution without a final batch-norm layer before the classifier. Paper: https://arxiv.org/abs/1905.02244 """ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): super(MobileNetV3, self).__init__() self.drop_rate = drop_rate stem_size = round_channels(stem_size, channel_multiplier) self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) in_chs = stem_size builder = EfficientNetBuilder( channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs, norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate) self.blocks = nn.Sequential(*builder(in_chs, block_args)) in_chs = builder.in_chs self.global_pool = nn.AdaptiveAvgPool2d(1) self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias) self.act2 = act_layer(inplace=True) self.feature_num = num_features print(num_features) self.classifier = nn.Linear(num_features, num_classes) for m in self.modules(): if weight_init == 'goog': initialize_weight_goog(m) else: initialize_weight_default(m) def as_sequential(self): layers = [self.conv_stem, self.bn1, self.act1] layers.extend(self.blocks) layers.extend([ self.global_pool, self.conv_head, self.act2, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) def features(self, x): x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) features = self.blocks(x) x = self.global_pool(features) x = self.conv_head(x) x = self.act2(x) return x, features.detach() def forward(self, x): x, features = self.features(x) x = x.flatten(1) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return (x), features.detach() def _create_model(model_kwargs, variant, pretrained=False): as_sequential = model_kwargs.pop('as_sequential', False) model = MobileNetV3(**model_kwargs) if pretrained and model_urls[variant]: load_pretrained(model, model_urls[variant]) if as_sequential: model = model.as_sequential() return model def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MobileNet-V3 model (RW variant). Paper: https://arxiv.org/abs/1905.02244 This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the eventual Tensorflow reference impl but has a few differences: 1. This model has no bias on the head convolution 2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet 3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer from their parent block 4. This model does not enforce divisible by 8 limitation on the SE reduction channel count Overall the changes are fairly minor and result in a very small parameter count difference and no top-1/5 Args: channel_multiplier: multiplier to number of channels per layer. """ arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu # stage 1, 112x112 in ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu # stage 2, 56x56 in ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu # stage 3, 28x28 in ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish # stage 4, 14x14in ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish # stage 5, 14x14in ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish # stage 6, 7x7 in ['cn_r1_k1_s1_c960'], # hard-swish ] model_kwargs = dict( block_args=decode_arch_def(arch_def), head_bias=False, # one of my mistakes channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'hard_swish'), se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True), norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MobileNet-V3 large/small/minimal models. Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py Paper: https://arxiv.org/abs/1905.02244 Args: channel_multiplier: multiplier to number of channels per layer. """ if 'small' in variant: num_features = 1024 if 'minimal' in variant: act_layer = 'relu' arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s2_e1_c16'], # stage 1, 56x56 in ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], # stage 2, 28x28 in ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], # stage 3, 14x14 in ['ir_r2_k3_s1_e3_c48'], # stage 4, 14x14in ['ir_r3_k3_s2_e6_c96'], # stage 6, 7x7 in ['cn_r1_k1_s1_c576'], ] else: act_layer = 'hard_swish' arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu # stage 1, 56x56 in ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu # stage 2, 28x28 in ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish # stage 3, 14x14 in ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish # stage 4, 14x14in ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish # stage 6, 7x7 in ['cn_r1_k1_s1_c576'], # hard-swish ] else: num_features = 1280 if 'minimal' in variant: act_layer = 'relu' arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16'], # stage 1, 112x112 in ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], # stage 2, 56x56 in ['ir_r3_k3_s2_e3_c40'], # stage 3, 28x28 in ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # stage 4, 14x14in ['ir_r2_k3_s1_e6_c112'], # stage 5, 14x14in ['ir_r3_k3_s2_e6_c160'], # stage 6, 7x7 in ['cn_r1_k1_s1_c960'], ] else: act_layer = 'hard_swish' arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_nre'], # relu # stage 1, 112x112 in ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu # stage 2, 56x56 in ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu # stage 3, 28x28 in ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish # stage 4, 14x14in ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish # stage 5, 14x14in ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish # stage 6, 7x7 in ['cn_r1_k1_s1_c960'], # hard-swish ] model_kwargs = dict( block_args=decode_arch_def(arch_def), num_features=num_features, stem_size=16, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, act_layer), se_kwargs=dict( act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8), norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) model = _create_model(model_kwargs, variant, pretrained) return model def mobilenetv3_rw(pretrained=False, **kwargs): """ MobileNet-V3 RW Attn: See note in gen function for this variant. """ # NOTE for train set drop_rate=0.2 if pretrained: # pretrained model trained with non-default BN epsilon kwargs['bn_eps'] = BN_EPS_TF_DEFAULT model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) return model def mobilenetv3_large_075(pretrained=False, **kwargs): """ MobileNet V3 Large 0.75""" # NOTE for train set drop_rate=0.2 model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) return model def mobilenetv3_large_100(pretrained=False, **kwargs): """ MobileNet V3 Large 1.0 """ # NOTE for train set drop_rate=0.2 model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) return model def mobilenetv3_large_125(pretrained=False, **kwargs): """ MobileNet V3 Large 1.25 """ # NOTE for train set drop_rate=0.2 model = _gen_mobilenet_v3('mobilenetv3_large_125', 1.25, pretrained=pretrained, **kwargs) return model def mobilenetv3_large_minimal_100(pretrained=False, **kwargs): """ MobileNet V3 Large (Minimalistic) 1.0 """ # NOTE for train set drop_rate=0.2 model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) return model def mobilenetv3_small_075(pretrained=False, **kwargs): """ MobileNet V3 Small 0.75 """ model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) return model def mobilenetv3_small_100(pretrained=False, **kwargs): """ MobileNet V3 Small 1.0 """ model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) return model def mobilenetv3_small_minimal_100(pretrained=False, **kwargs): """ MobileNet V3 Small (Minimalistic) 1.0 """ model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) return model def tf_mobilenetv3_large_075(pretrained=False, **kwargs): """ MobileNet V3 Large 0.75. Tensorflow compat variant. """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) return model def tf_mobilenetv3_large_100(pretrained=False, **kwargs): """ MobileNet V3 Large 1.0. Tensorflow compat variant. """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) return model def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) return model def tf_mobilenetv3_small_075(pretrained=False, **kwargs): """ MobileNet V3 Small 0.75. Tensorflow compat variant. """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) return model def tf_mobilenetv3_small_100(pretrained=False, **kwargs): """ MobileNet V3 Small 1.0. Tensorflow compat variant.""" kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) return model def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) return model ================================================ FILE: models/model_factory.py ================================================ from .mobilenetv3 import * from .gen_efficientnet import * from .helpers import load_checkpoint def create_model( model_name='mnasnet_100', pretrained=None, num_classes=1000, in_chans=3, checkpoint_path='', **kwargs): margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained) if model_name in globals(): create_fn = globals()[model_name] model = create_fn(**margs, **kwargs) else: raise RuntimeError('Unknown model (%s)' % model_name) if checkpoint_path and not pretrained: load_checkpoint(model, checkpoint_path) return model ================================================ FILE: models/resnet.py ================================================ import torch.nn as nn import torch.utils.model_zoo as model_zoo __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', } def conv3x3(in_planes, out_planes, stride=1, groups=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, norm_layer=None): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, norm_layer=None): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, norm_layer=None): super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self.inplanes = 64 self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): if norm_layer is None: norm_layer = nn.BatchNorm2d downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, norm_layer)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, norm_layer=norm_layer)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) output = self.avgpool(x) output = output.view(x.size(0), -1) return output, x.detach() def resnet18(pretrained=False, **kwargs): """Constructs a ResNet-18 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) return model def resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) return model def resnet50(pretrained=False, **kwargs): """Constructs a ResNet-50 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) return model def resnet101(pretrained=False, **kwargs): """Constructs a ResNet-101 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) return model def resnet152(pretrained=False, **kwargs): """Constructs a ResNet-152 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) return model def resnext50_32x4d(pretrained=False, **kwargs): model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) return model def resnext101_32x8d(pretrained=False, **kwargs): model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) return model def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) ================================================ FILE: models/version.py ================================================ __version__ = '0.9.8' ================================================ FILE: network.py ================================================ import torch import torchvision import torch.nn as nn import torch.nn.functional as F import math class Memory: def __init__(self): self.actions = [] self.states = [] self.logprobs = [] self.rewards = [] self.is_terminals = [] self.hidden = [] def clear_memory(self): del self.actions[:] del self.states[:] del self.logprobs[:] del self.rewards[:] del self.is_terminals[:] del self.hidden[:] class ActorCritic(nn.Module): def __init__(self, feature_dim, state_dim, hidden_state_dim=1024, policy_conv=True, action_std=0.1): super(ActorCritic, self).__init__() # encoder with convolution layer for MobileNetV3, EfficientNet and RegNet if policy_conv: self.state_encoder = nn.Sequential( nn.Conv2d(feature_dim, 32, kernel_size=1, stride=1, padding=0, bias=False), nn.ReLU(), nn.Flatten(), nn.Linear(int(state_dim * 32 / feature_dim), hidden_state_dim), nn.ReLU() ) # encoder with linear layer for ResNet and DenseNet else: self.state_encoder = nn.Sequential( nn.Linear(state_dim, 2048), nn.ReLU(), nn.Linear(2048, hidden_state_dim), nn.ReLU() ) self.gru = nn.GRU(hidden_state_dim, hidden_state_dim, batch_first=False) self.actor = nn.Sequential( nn.Linear(hidden_state_dim, 2), nn.Sigmoid()) self.critic = nn.Sequential( nn.Linear(hidden_state_dim, 1)) self.action_var = torch.full((2,), action_std).cuda() self.hidden_state_dim = hidden_state_dim self.policy_conv = policy_conv self.feature_dim = feature_dim self.feature_ratio = int(math.sqrt(state_dim/feature_dim)) def forward(self): raise NotImplementedError def act(self, state_ini, memory, restart_batch=False, training=False): if restart_batch: del memory.hidden[:] memory.hidden.append(torch.zeros(1, state_ini.size(0), self.hidden_state_dim).cuda()) if not self.policy_conv: state = state_ini.flatten(1) else: state = state_ini state = self.state_encoder(state) state, hidden_output = self.gru(state.view(1, state.size(0), state.size(1)), memory.hidden[-1]) memory.hidden.append(hidden_output) state = state[0] action_mean = self.actor(state) cov_mat = torch.diag(self.action_var).cuda() dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat) action = dist.sample().cuda() if training: action = F.relu(action) action = 1 - F.relu(1 - action) action_logprob = dist.log_prob(action).cuda() memory.states.append(state_ini) memory.actions.append(action) memory.logprobs.append(action_logprob) else: action = action_mean return action.detach() def evaluate(self, state, action): seq_l = state.size(0) batch_size = state.size(1) if not self.policy_conv: state = state.flatten(2) state = state.view(seq_l * batch_size, state.size(2)) else: state = state.view(seq_l * batch_size, state.size(2), state.size(3), state.size(4)) state = self.state_encoder(state) state = state.view(seq_l, batch_size, -1) state, hidden = self.gru(state, torch.zeros(1, batch_size, state.size(2)).cuda()) state = state.view(seq_l * batch_size, -1) action_mean = self.actor(state) cov_mat = torch.diag(self.action_var).cuda() dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat) action_logprobs = dist.log_prob(torch.squeeze(action.view(seq_l * batch_size, -1))).cuda() dist_entropy = dist.entropy().cuda() state_value = self.critic(state) return action_logprobs.view(seq_l, batch_size), \ state_value.view(seq_l, batch_size), \ dist_entropy.view(seq_l, batch_size) class PPO: def __init__(self, feature_dim, state_dim, hidden_state_dim, policy_conv, action_std=0.1, lr=0.0003, betas=(0.9, 0.999), gamma=0.7, K_epochs=1, eps_clip=0.2): self.lr = lr self.betas = betas self.gamma = gamma self.eps_clip = eps_clip self.K_epochs = K_epochs self.policy = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda() self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas) self.policy_old = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda() self.policy_old.load_state_dict(self.policy.state_dict()) self.MseLoss = nn.MSELoss() def select_action(self, state, memory, restart_batch=False, training=True): return self.policy_old.act(state, memory, restart_batch, training) def update(self, memory): rewards = [] discounted_reward = 0 for reward in reversed(memory.rewards): discounted_reward = reward + (self.gamma * discounted_reward) rewards.insert(0, discounted_reward) rewards = torch.cat(rewards, 0).cuda() rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5) old_states = torch.stack(memory.states, 0).cuda().detach() old_actions = torch.stack(memory.actions, 0).cuda().detach() old_logprobs = torch.stack(memory.logprobs, 0).cuda().detach() for _ in range(self.K_epochs): logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions) ratios = torch.exp(logprobs - old_logprobs.detach()) advantages = rewards - state_values.detach() surr1 = ratios * advantages surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy self.optimizer.zero_grad() loss.mean().backward() self.optimizer.step() self.policy_old.load_state_dict(self.policy.state_dict()) class Full_layer(torch.nn.Module): def __init__(self, feature_num, hidden_state_dim=1024, fc_rnn=True, class_num=1000): super(Full_layer, self).__init__() self.class_num = class_num self.feature_num = feature_num self.hidden_state_dim = hidden_state_dim self.hidden = None self.fc_rnn = fc_rnn # classifier with RNN for ResNet, DenseNet and RegNet if fc_rnn: self.rnn = nn.GRU(feature_num, self.hidden_state_dim) self.fc = nn.Linear(self.hidden_state_dim, class_num) # cascaded classifier for MobileNetV3 and EfficientNet else: self.fc_2 = nn.Linear(self.feature_num * 2, class_num) self.fc_3 = nn.Linear(self.feature_num * 3, class_num) self.fc_4 = nn.Linear(self.feature_num * 4, class_num) self.fc_5 = nn.Linear(self.feature_num * 5, class_num) def forward(self, x, restart=False): if self.fc_rnn: if restart: output, h_n = self.rnn(x.view(1, x.size(0), x.size(1)), torch.zeros(1, x.size(0), self.hidden_state_dim).cuda()) self.hidden = h_n else: output, h_n = self.rnn(x.view(1, x.size(0), x.size(1)), self.hidden) self.hidden = h_n return self.fc(output[0]) else: if restart: self.hidden = x else: self.hidden = torch.cat([self.hidden, x], 1) if self.hidden.size(1) == self.feature_num: return None elif self.hidden.size(1) == self.feature_num * 2: return self.fc_2(self.hidden) elif self.hidden.size(1) == self.feature_num * 3: return self.fc_3(self.hidden) elif self.hidden.size(1) == self.feature_num * 4: return self.fc_4(self.hidden) elif self.hidden.size(1) == self.feature_num * 5: return self.fc_5(self.hidden) else: print(self.hidden.size()) exit() ================================================ FILE: pycls/__init__.py ================================================ ================================================ FILE: pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml ================================================ MODEL: TYPE: regnet NUM_CLASSES: 1000 REGNET: SE_ON: True DEPTH: 27 W0: 48 WA: 20.71 WM: 2.65 GROUP_W: 24 OPTIM: LR_POLICY: cos BASE_LR: 0.8 MAX_EPOCH: 100 MOMENTUM: 0.9 WEIGHT_DECAY: 5e-5 WARMUP_EPOCHS: 5 TRAIN: DATASET: imagenet IM_SIZE: 224 BATCH_SIZE: 1024 TEST: DATASET: imagenet IM_SIZE: 256 BATCH_SIZE: 800 NUM_GPUS: 1 OUT_DIR: . ================================================ FILE: pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml ================================================ MODEL: TYPE: regnet NUM_CLASSES: 1000 REGNET: SE_ON: True DEPTH: 15 W0: 48 WA: 32.54 WM: 2.32 GROUP_W: 16 OPTIM: LR_POLICY: cos BASE_LR: 0.8 MAX_EPOCH: 100 MOMENTUM: 0.9 WEIGHT_DECAY: 5e-5 WARMUP_EPOCHS: 5 TRAIN: DATASET: imagenet IM_SIZE: 224 BATCH_SIZE: 1024 TEST: DATASET: imagenet IM_SIZE: 256 BATCH_SIZE: 800 NUM_GPUS: 1 OUT_DIR: . ================================================ FILE: pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml ================================================ MODEL: TYPE: regnet NUM_CLASSES: 1000 REGNET: SE_ON: True DEPTH: 14 W0: 56 WA: 38.84 WM: 2.4 GROUP_W: 16 OPTIM: LR_POLICY: cos BASE_LR: 0.8 MAX_EPOCH: 100 MOMENTUM: 0.9 WEIGHT_DECAY: 5e-5 WARMUP_EPOCHS: 5 TRAIN: DATASET: imagenet IM_SIZE: 224 BATCH_SIZE: 1024 TEST: DATASET: imagenet IM_SIZE: 256 BATCH_SIZE: 800 NUM_GPUS: 1 OUT_DIR: . ================================================ FILE: pycls/core/__init__.py ================================================ ================================================ FILE: pycls/core/config.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Configuration file (powered by YACS).""" import os from pycls.utils.io import cache_url from yacs.config import CfgNode as CN # Global config object _C = CN() # Example usage: # from core.config import cfg cfg = _C # ---------------------------------------------------------------------------- # # Model options # ---------------------------------------------------------------------------- # _C.MODEL = CN() # Model type _C.MODEL.TYPE = "" # Number of weight layers _C.MODEL.DEPTH = 0 # Number of classes _C.MODEL.NUM_CLASSES = 10 # Loss function (see pycls/models/loss.py for options) _C.MODEL.LOSS_FUN = "cross_entropy" # ---------------------------------------------------------------------------- # # ResNet options # ---------------------------------------------------------------------------- # _C.RESNET = CN() # Transformation function (see pycls/models/resnet.py for options) _C.RESNET.TRANS_FUN = "basic_transform" # Number of groups to use (1 -> ResNet; > 1 -> ResNeXt) _C.RESNET.NUM_GROUPS = 1 # Width of each group (64 -> ResNet; 4 -> ResNeXt) _C.RESNET.WIDTH_PER_GROUP = 64 # Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch) _C.RESNET.STRIDE_1X1 = True # ---------------------------------------------------------------------------- # # AnyNet options # ---------------------------------------------------------------------------- # _C.ANYNET = CN() # Stem type _C.ANYNET.STEM_TYPE = "plain_block" # Stem width _C.ANYNET.STEM_W = 32 # Block type _C.ANYNET.BLOCK_TYPE = "plain_block" # Depth for each stage (number of blocks in the stage) _C.ANYNET.DEPTHS = [] # Width for each stage (width of each block in the stage) _C.ANYNET.WIDTHS = [] # Strides for each stage (applies to the first block of each stage) _C.ANYNET.STRIDES = [] # Bottleneck multipliers for each stage (applies to bottleneck block) _C.ANYNET.BOT_MULS = [] # Group widths for each stage (applies to bottleneck block) _C.ANYNET.GROUP_WS = [] # Whether SE is enabled for res_bottleneck_block _C.ANYNET.SE_ON = False # SE ratio _C.ANYNET.SE_R = 0.25 # ---------------------------------------------------------------------------- # # RegNet options # ---------------------------------------------------------------------------- # _C.REGNET = CN() # Stem type _C.REGNET.STEM_TYPE = "simple_stem_in" # Stem width _C.REGNET.STEM_W = 32 # Block type _C.REGNET.BLOCK_TYPE = "res_bottleneck_block" # Stride of each stage _C.REGNET.STRIDE = 2 # Squeeze-and-Excitation (RegNetY) _C.REGNET.SE_ON = False _C.REGNET.SE_R = 0.25 # Depth _C.REGNET.DEPTH = 10 # Initial width _C.REGNET.W0 = 32 # Slope _C.REGNET.WA = 5.0 # Quantization _C.REGNET.WM = 2.5 # Group width _C.REGNET.GROUP_W = 16 # Bottleneck multiplier (bm = 1 / b from the paper) _C.REGNET.BOT_MUL = 1.0 # ---------------------------------------------------------------------------- # # EfficientNet options # ---------------------------------------------------------------------------- # _C.EN = CN() # Stem width _C.EN.STEM_W = 32 # Depth for each stage (number of blocks in the stage) _C.EN.DEPTHS = [] # Width for each stage (width of each block in the stage) _C.EN.WIDTHS = [] # Expansion ratios for MBConv blocks in each stage _C.EN.EXP_RATIOS = [] # Squeeze-and-Excitation (SE) ratio _C.EN.SE_R = 0.25 # Strides for each stage (applies to the first block of each stage) _C.EN.STRIDES = [] # Kernel sizes for each stage _C.EN.KERNELS = [] # Head width _C.EN.HEAD_W = 1280 # Drop connect ratio _C.EN.DC_RATIO = 0.0 # Dropout ratio _C.EN.DROPOUT_RATIO = 0.0 # ---------------------------------------------------------------------------- # # Batch norm options # ---------------------------------------------------------------------------- # _C.BN = CN() # BN epsilon _C.BN.EPS = 1e-5 # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) _C.BN.MOM = 0.1 # Precise BN stats _C.BN.USE_PRECISE_STATS = False _C.BN.NUM_SAMPLES_PRECISE = 1024 # Initialize the gamma of the final BN of each block to zero _C.BN.ZERO_INIT_FINAL_GAMMA = False # Use a different weight decay for BN layers _C.BN.USE_CUSTOM_WEIGHT_DECAY = False _C.BN.CUSTOM_WEIGHT_DECAY = 0.0 # ---------------------------------------------------------------------------- # # Optimizer options # ---------------------------------------------------------------------------- # _C.OPTIM = CN() # Base learning rate _C.OPTIM.BASE_LR = 0.1 # Learning rate policy select from {'cos', 'exp', 'steps'} _C.OPTIM.LR_POLICY = "cos" # Exponential decay factor _C.OPTIM.GAMMA = 0.1 # Steps for 'steps' policy (in epochs) _C.OPTIM.STEPS = [] # Learning rate multiplier for 'steps' policy _C.OPTIM.LR_MULT = 0.1 # Maximal number of epochs _C.OPTIM.MAX_EPOCH = 200 # Momentum _C.OPTIM.MOMENTUM = 0.9 # Momentum dampening _C.OPTIM.DAMPENING = 0.0 # Nesterov momentum _C.OPTIM.NESTEROV = True # L2 regularization _C.OPTIM.WEIGHT_DECAY = 5e-4 # Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR _C.OPTIM.WARMUP_FACTOR = 0.1 # Gradually warm up the OPTIM.BASE_LR over this number of epochs _C.OPTIM.WARMUP_EPOCHS = 0 # ---------------------------------------------------------------------------- # # Training options # ---------------------------------------------------------------------------- # _C.TRAIN = CN() # Dataset and split _C.TRAIN.DATASET = "" _C.TRAIN.SPLIT = "train" # Total mini-batch size _C.TRAIN.BATCH_SIZE = 128 # Image size _C.TRAIN.IM_SIZE = 224 # Evaluate model on test data every eval period epochs _C.TRAIN.EVAL_PERIOD = 1 # Save model checkpoint every checkpoint period epochs _C.TRAIN.CHECKPOINT_PERIOD = 1 # Resume training from the latest checkpoint in the output directory _C.TRAIN.AUTO_RESUME = True # Weights to start training from _C.TRAIN.WEIGHTS = "" # ---------------------------------------------------------------------------- # # Testing options # ---------------------------------------------------------------------------- # _C.TEST = CN() # Dataset and split _C.TEST.DATASET = "" _C.TEST.SPLIT = "val" # Total mini-batch size _C.TEST.BATCH_SIZE = 200 # Image size _C.TEST.IM_SIZE = 256 # Weights to use for testing _C.TEST.WEIGHTS = "" # ---------------------------------------------------------------------------- # # Common train/test data loader options # ---------------------------------------------------------------------------- # _C.DATA_LOADER = CN() # Number of data loader workers per training process _C.DATA_LOADER.NUM_WORKERS = 4 # Load data to pinned host memory _C.DATA_LOADER.PIN_MEMORY = True # ---------------------------------------------------------------------------- # # Memory options # ---------------------------------------------------------------------------- # _C.MEM = CN() # Perform ReLU inplace _C.MEM.RELU_INPLACE = True # ---------------------------------------------------------------------------- # # CUDNN options # ---------------------------------------------------------------------------- # _C.CUDNN = CN() # Perform benchmarking to select the fastest CUDNN algorithms to use # Note that this may increase the memory usage and will likely not result # in overall speedups when variable size inputs are used (e.g. COCO training) _C.CUDNN.BENCHMARK = True # ---------------------------------------------------------------------------- # # Precise timing options # ---------------------------------------------------------------------------- # _C.PREC_TIME = CN() # Perform precise timing at the start of training _C.PREC_TIME.ENABLED = False # Total mini-batch size _C.PREC_TIME.BATCH_SIZE = 128 # Number of iterations to warm up the caches _C.PREC_TIME.WARMUP_ITER = 3 # Number of iterations to compute avg time _C.PREC_TIME.NUM_ITER = 30 # ---------------------------------------------------------------------------- # # Misc options # ---------------------------------------------------------------------------- # # Number of GPUs to use (applies to both training and testing) _C.NUM_GPUS = 1 # Output directory _C.OUT_DIR = "/tmp" # Config destination (in OUT_DIR) _C.CFG_DEST = "config.yaml" # Note that non-determinism may still be present due to non-deterministic # operator implementations in GPU operator libraries _C.RNG_SEED = 1 # Log destination ('stdout' or 'file') _C.LOG_DEST = "stdout" # Log period in iters _C.LOG_PERIOD = 10 # Distributed backend _C.DIST_BACKEND = "nccl" # Hostname and port for initializing multi-process groups _C.HOST = "localhost" _C.PORT = 10001 # Models weights referred to by URL are downloaded to this local cache _C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache" def assert_and_infer_cfg(cache_urls=True): """Checks config values invariants.""" assert ( not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0 ), "The first lr step must start at 0" assert _C.TRAIN.SPLIT in [ "train", "val", "test", ], "Train split '{}' not supported".format(_C.TRAIN.SPLIT) assert ( _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0 ), "Train mini-batch size should be a multiple of NUM_GPUS." assert _C.TEST.SPLIT in [ "train", "val", "test", ], "Test split '{}' not supported".format(_C.TEST.SPLIT) assert ( _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0 ), "Test mini-batch size should be a multiple of NUM_GPUS." assert ( not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1 ), "Precise BN stats computation not verified for > 1 GPU" assert _C.LOG_DEST in [ "stdout", "file", ], "Log destination '{}' not supported".format(_C.LOG_DEST) assert ( not _C.PREC_TIME.ENABLED or _C.NUM_GPUS == 1 ), "Precise iter time computation not verified for > 1 GPU" if cache_urls: cache_cfg_urls() def cache_cfg_urls(): """Download URLs in the config, cache them locally, and rewrite cfg to make use of the locally cached file. """ _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE) _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE) def dump_cfg(): """Dumps the config to the output directory.""" cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST) with open(cfg_file, "w") as f: _C.dump(stream=f) def load_cfg(out_dir, cfg_dest="config.yaml"): """Loads config from specified output directory.""" cfg_file = os.path.join(out_dir, cfg_dest) _C.merge_from_file(cfg_file) ================================================ FILE: pycls/core/losses.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Loss functions.""" import torch.nn as nn from pycls.core.config import cfg # Supported loss functions _loss_funs = {"cross_entropy": nn.CrossEntropyLoss} def get_loss_fun(): """Retrieves the loss function.""" assert ( cfg.MODEL.LOSS_FUN in _loss_funs.keys() ), "Loss function '{}' not supported".format(cfg.TRAIN.LOSS) return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda() def register_loss_fun(name, ctor): """Registers a loss function dynamically.""" _loss_funs[name] = ctor ================================================ FILE: pycls/core/model_builder.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Model construction functions.""" import pycls.utils.logging as lu import torch from pycls.core.config import cfg from pycls.models.anynet import AnyNet from pycls.models.effnet import EffNet from pycls.models.regnet import RegNet from pycls.models.resnet import ResNet logger = lu.get_logger(__name__) # Supported models _models = {"anynet": AnyNet, "effnet": EffNet, "resnet": ResNet, "regnet": RegNet} def build_model(): """Builds the model.""" assert cfg.MODEL.TYPE in _models.keys(), "Model type '{}' not supported".format( cfg.MODEL.TYPE ) # print(torch.cuda.device_count()) # print(torch.cuda.current_device()) assert ( cfg.NUM_GPUS <= torch.cuda.device_count() ), "Cannot use more GPU devices than available" # Construct the model model = _models[cfg.MODEL.TYPE]() # Determine the GPU used by the current process cur_device = torch.cuda.current_device() # Transfer the model to the current GPU device model = model.cuda(device=cur_device) # Use multi-process data parallel model in the multi-gpu setting if cfg.NUM_GPUS > 1: # Make model replica operate on the current device model = torch.nn.parallel.DistributedDataParallel( module=model, device_ids=[cur_device], output_device=cur_device ) return model def register_model(name, ctor): """Registers a model dynamically.""" _models[name] = ctor ================================================ FILE: pycls/core/old_config.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Configuration file (powered by YACS).""" import os from pycls.utils.io import cache_url from yacs.config import CfgNode as CN # Global config object _C = CN() # Example usage: # from core.config import cfg cfg = _C # ---------------------------------------------------------------------------- # # Model options # ---------------------------------------------------------------------------- # _C.MODEL = CN() # Model type _C.MODEL.TYPE = "" # Number of weight layers _C.MODEL.DEPTH = 0 # Number of classes _C.MODEL.NUM_CLASSES = 10 # Loss function (see pycls/models/loss.py for options) _C.MODEL.LOSS_FUN = "cross_entropy" # ---------------------------------------------------------------------------- # # ResNet options # ---------------------------------------------------------------------------- # _C.RESNET = CN() # Transformation function (see pycls/models/resnet.py for options) _C.RESNET.TRANS_FUN = "basic_transform" # Number of groups to use (1 -> ResNet; > 1 -> ResNeXt) _C.RESNET.NUM_GROUPS = 1 # Width of each group (64 -> ResNet; 4 -> ResNeXt) _C.RESNET.WIDTH_PER_GROUP = 64 # Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch) _C.RESNET.STRIDE_1X1 = True # ---------------------------------------------------------------------------- # # AnyNet options # ---------------------------------------------------------------------------- # _C.ANYNET = CN() # Stem type _C.ANYNET.STEM_TYPE = "plain_block" # Stem width _C.ANYNET.STEM_W = 32 # Block type _C.ANYNET.BLOCK_TYPE = "plain_block" # Depth for each stage (number of blocks in the stage) _C.ANYNET.DEPTHS = [] # Width for each stage (width of each block in the stage) _C.ANYNET.WIDTHS = [] # Strides for each stage (applies to the first block of each stage) _C.ANYNET.STRIDES = [] # Bottleneck multipliers for each stage (applies to bottleneck block) _C.ANYNET.BOT_MULS = [] # Group widths for each stage (applies to bottleneck block) _C.ANYNET.GROUP_WS = [] # Whether SE is enabled for res_bottleneck_block _C.ANYNET.SE_ON = False # SE ratio _C.ANYNET.SE_R = 0.25 # ---------------------------------------------------------------------------- # # RegNet options # ---------------------------------------------------------------------------- # _C.REGNET = CN() # Stem type _C.REGNET.STEM_TYPE = "simple_stem_in" # Stem width _C.REGNET.STEM_W = 32 # Block type _C.REGNET.BLOCK_TYPE = "res_bottleneck_block" # Stride of each stage _C.REGNET.STRIDE = 2 # Squeeze-and-Excitation (RegNetY) _C.REGNET.SE_ON = False _C.REGNET.SE_R = 0.25 # Depth _C.REGNET.DEPTH = 10 # Initial width _C.REGNET.W0 = 32 # Slope _C.REGNET.WA = 5.0 # Quantization _C.REGNET.WM = 2.5 # Group width _C.REGNET.GROUP_W = 16 # Bottleneck multiplier (bm = 1 / b from the paper) _C.REGNET.BOT_MUL = 1.0 # ---------------------------------------------------------------------------- # # EfficientNet options # ---------------------------------------------------------------------------- # _C.EN = CN() # Stem width _C.EN.STEM_W = 32 # Depth for each stage (number of blocks in the stage) _C.EN.DEPTHS = [] # Width for each stage (width of each block in the stage) _C.EN.WIDTHS = [] # Expansion ratios for MBConv blocks in each stage _C.EN.EXP_RATIOS = [] # Squeeze-and-Excitation (SE) ratio _C.EN.SE_R = 0.25 # Strides for each stage (applies to the first block of each stage) _C.EN.STRIDES = [] # Kernel sizes for each stage _C.EN.KERNELS = [] # Head width _C.EN.HEAD_W = 1280 # Drop connect ratio _C.EN.DC_RATIO = 0.0 # Dropout ratio _C.EN.DROPOUT_RATIO = 0.0 # ---------------------------------------------------------------------------- # # Batch norm options # ---------------------------------------------------------------------------- # _C.BN = CN() # BN epsilon _C.BN.EPS = 1e-5 # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) _C.BN.MOM = 0.1 # Precise BN stats _C.BN.USE_PRECISE_STATS = False _C.BN.NUM_SAMPLES_PRECISE = 1024 # Initialize the gamma of the final BN of each block to zero _C.BN.ZERO_INIT_FINAL_GAMMA = False # Use a different weight decay for BN layers _C.BN.USE_CUSTOM_WEIGHT_DECAY = False _C.BN.CUSTOM_WEIGHT_DECAY = 0.0 # ---------------------------------------------------------------------------- # # Optimizer options # ---------------------------------------------------------------------------- # _C.OPTIM = CN() # Base learning rate _C.OPTIM.BASE_LR = 0.1 # Learning rate policy select from {'cos', 'exp', 'steps'} _C.OPTIM.LR_POLICY = "cos" # Exponential decay factor _C.OPTIM.GAMMA = 0.1 # Steps for 'steps' policy (in epochs) _C.OPTIM.STEPS = [] # Learning rate multiplier for 'steps' policy _C.OPTIM.LR_MULT = 0.1 # Maximal number of epochs _C.OPTIM.MAX_EPOCH = 200 # Momentum _C.OPTIM.MOMENTUM = 0.9 # Momentum dampening _C.OPTIM.DAMPENING = 0.0 # Nesterov momentum _C.OPTIM.NESTEROV = True # L2 regularization _C.OPTIM.WEIGHT_DECAY = 5e-4 # Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR _C.OPTIM.WARMUP_FACTOR = 0.1 # Gradually warm up the OPTIM.BASE_LR over this number of epochs _C.OPTIM.WARMUP_EPOCHS = 0 # ---------------------------------------------------------------------------- # # Training options # ---------------------------------------------------------------------------- # _C.TRAIN = CN() # Dataset and split _C.TRAIN.DATASET = "" _C.TRAIN.SPLIT = "train" # Total mini-batch size _C.TRAIN.BATCH_SIZE = 128 # Image size _C.TRAIN.IM_SIZE = 224 # Evaluate model on test data every eval period epochs _C.TRAIN.EVAL_PERIOD = 1 # Save model checkpoint every checkpoint period epochs _C.TRAIN.CHECKPOINT_PERIOD = 1 # Resume training from the latest checkpoint in the output directory _C.TRAIN.AUTO_RESUME = True # Weights to start training from _C.TRAIN.WEIGHTS = "" # ---------------------------------------------------------------------------- # # Testing options # ---------------------------------------------------------------------------- # _C.TEST = CN() # Dataset and split _C.TEST.DATASET = "" _C.TEST.SPLIT = "val" # Total mini-batch size _C.TEST.BATCH_SIZE = 200 # Image size _C.TEST.IM_SIZE = 256 # Weights to use for testing _C.TEST.WEIGHTS = "" # ---------------------------------------------------------------------------- # # Common train/test data loader options # ---------------------------------------------------------------------------- # _C.DATA_LOADER = CN() # Number of data loader workers per training process _C.DATA_LOADER.NUM_WORKERS = 4 # Load data to pinned host memory _C.DATA_LOADER.PIN_MEMORY = True # ---------------------------------------------------------------------------- # # Memory options # ---------------------------------------------------------------------------- # _C.MEM = CN() # Perform ReLU inplace _C.MEM.RELU_INPLACE = True # ---------------------------------------------------------------------------- # # CUDNN options # ---------------------------------------------------------------------------- # _C.CUDNN = CN() # Perform benchmarking to select the fastest CUDNN algorithms to use # Note that this may increase the memory usage and will likely not result # in overall speedups when variable size inputs are used (e.g. COCO training) _C.CUDNN.BENCHMARK = True # ---------------------------------------------------------------------------- # # Precise timing options # ---------------------------------------------------------------------------- # _C.PREC_TIME = CN() # Perform precise timing at the start of training _C.PREC_TIME.ENABLED = False # Total mini-batch size _C.PREC_TIME.BATCH_SIZE = 128 # Number of iterations to warm up the caches _C.PREC_TIME.WARMUP_ITER = 3 # Number of iterations to compute avg time _C.PREC_TIME.NUM_ITER = 30 # ---------------------------------------------------------------------------- # # Misc options # ---------------------------------------------------------------------------- # # Number of GPUs to use (applies to both training and testing) _C.NUM_GPUS = 1 # Output directory _C.OUT_DIR = "/tmp" # Config destination (in OUT_DIR) _C.CFG_DEST = "config.yaml" # Note that non-determinism may still be present due to non-deterministic # operator implementations in GPU operator libraries _C.RNG_SEED = 1 # Log destination ('stdout' or 'file') _C.LOG_DEST = "stdout" # Log period in iters _C.LOG_PERIOD = 10 # Distributed backend _C.DIST_BACKEND = "nccl" # Hostname and port for initializing multi-process groups _C.HOST = "localhost" _C.PORT = 10001 # Models weights referred to by URL are downloaded to this local cache _C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache" def assert_and_infer_cfg(cache_urls=True): """Checks config values invariants.""" assert ( not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0 ), "The first lr step must start at 0" assert _C.TRAIN.SPLIT in [ "train", "val", "test", ], "Train split '{}' not supported".format(_C.TRAIN.SPLIT) assert ( _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0 ), "Train mini-batch size should be a multiple of NUM_GPUS." assert _C.TEST.SPLIT in [ "train", "val", "test", ], "Test split '{}' not supported".format(_C.TEST.SPLIT) assert ( _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0 ), "Test mini-batch size should be a multiple of NUM_GPUS." assert ( not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1 ), "Precise BN stats computation not verified for > 1 GPU" assert _C.LOG_DEST in [ "stdout", "file", ], "Log destination '{}' not supported".format(_C.LOG_DEST) assert ( not _C.PREC_TIME.ENABLED or _C.NUM_GPUS == 1 ), "Precise iter time computation not verified for > 1 GPU" if cache_urls: cache_cfg_urls() def cache_cfg_urls(): """Download URLs in the config, cache them locally, and rewrite cfg to make use of the locally cached file. """ _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE) _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE) def dump_cfg(): """Dumps the config to the output directory.""" cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST) with open(cfg_file, "w") as f: _C.dump(stream=f) def load_cfg(out_dir, cfg_dest="config.yaml"): """Loads config from specified output directory.""" cfg_file = os.path.join(out_dir, cfg_dest) _C.merge_from_file(cfg_file) ================================================ FILE: pycls/core/optimizer.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Optimizer.""" import pycls.utils.lr_policy as lr_policy import torch from pycls.core.config import cfg def construct_optimizer(model): """Constructs the optimizer. Note that the momentum update in PyTorch differs from the one in Caffe2. In particular, Caffe2: V := mu * V + lr * g p := p - V PyTorch: V := mu * V + g p := p - lr * V where V is the velocity, mu is the momentum factor, lr is the learning rate, g is the gradient and p are the parameters. Since V is defined independently of the learning rate in PyTorch, when the learning rate is changed there is no need to perform the momentum correction by scaling V (unlike in the Caffe2 case). """ # Batchnorm parameters. bn_params = [] # Non-batchnorm parameters. non_bn_parameters = [] for name, p in model.named_parameters(): if "bn" in name: bn_params.append(p) else: non_bn_parameters.append(p) # Apply different weight decay to Batchnorm and non-batchnorm parameters. bn_weight_decay = ( cfg.BN.CUSTOM_WEIGHT_DECAY if cfg.BN.USE_CUSTOM_WEIGHT_DECAY else cfg.OPTIM.WEIGHT_DECAY ) optim_params = [ {"params": bn_params, "weight_decay": bn_weight_decay}, {"params": non_bn_parameters, "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, ] # Check all parameters will be passed into optimizer. assert len(list(model.parameters())) == len(non_bn_parameters) + len( bn_params ), "parameter size does not match: {} + {} != {}".format( len(non_bn_parameters), len(bn_params), len(list(model.parameters())) ) return torch.optim.SGD( optim_params, lr=cfg.OPTIM.BASE_LR, momentum=cfg.OPTIM.MOMENTUM, weight_decay=cfg.OPTIM.WEIGHT_DECAY, dampening=cfg.OPTIM.DAMPENING, nesterov=cfg.OPTIM.NESTEROV, ) def get_epoch_lr(cur_epoch): """Retrieves the lr for the given epoch (as specified by the lr policy).""" return lr_policy.get_epoch_lr(cur_epoch) def set_lr(optimizer, new_lr): """Sets the optimizer lr to the specified value.""" for param_group in optimizer.param_groups: param_group["lr"] = new_lr ================================================ FILE: pycls/datasets/__init__.py ================================================ ================================================ FILE: pycls/datasets/cifar10.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """CIFAR10 dataset.""" import os import pickle import numpy as np import pycls.datasets.transforms as transforms import pycls.utils.logging as lu import torch import torch.utils.data from pycls.core.config import cfg logger = lu.get_logger(__name__) # Per-channel mean and SD values in BGR order _MEAN = [125.3, 123.0, 113.9] _SD = [63.0, 62.1, 66.7] class Cifar10(torch.utils.data.Dataset): """CIFAR-10 dataset.""" def __init__(self, data_path, split): assert os.path.exists(data_path), "Data path '{}' not found".format(data_path) assert split in ["train", "test"], "Split '{}' not supported for cifar".format( split ) logger.info("Constructing CIFAR-10 {}...".format(split)) self._data_path = data_path self._split = split # Data format: # self._inputs - (split_size, 3, im_size, im_size) ndarray # self._labels - split_size list self._inputs, self._labels = self._load_data() def _load_batch(self, batch_path): with open(batch_path, "rb") as f: d = pickle.load(f, encoding="bytes") return d[b"data"], d[b"labels"] def _load_data(self): """Loads data in memory.""" logger.info("{} data path: {}".format(self._split, self._data_path)) # Compute data batch names if self._split == "train": batch_names = ["data_batch_{}".format(i) for i in range(1, 6)] else: batch_names = ["test_batch"] # Load data batches inputs, labels = [], [] for batch_name in batch_names: batch_path = os.path.join(self._data_path, batch_name) inputs_batch, labels_batch = self._load_batch(batch_path) inputs.append(inputs_batch) labels += labels_batch # Combine and reshape the inputs inputs = np.vstack(inputs).astype(np.float32) inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE)) return inputs, labels def _prepare_im(self, im): """Prepares the image for network input.""" im = transforms.color_norm(im, _MEAN, _SD) if self._split == "train": im = transforms.horizontal_flip(im=im, p=0.5) im = transforms.random_crop(im=im, size=cfg.TRAIN.IM_SIZE, pad_size=4) return im def __getitem__(self, index): im, label = self._inputs[index, ...].copy(), self._labels[index] im = self._prepare_im(im) return im, label def __len__(self): return self._inputs.shape[0] ================================================ FILE: pycls/datasets/imagenet.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """ImageNet dataset.""" import os import re import cv2 import numpy as np import pycls.datasets.transforms as transforms import pycls.utils.logging as lu import torch import torch.utils.data from pycls.core.config import cfg logger = lu.get_logger(__name__) # Per-channel mean and SD values in BGR order _MEAN = [0.406, 0.456, 0.485] _SD = [0.225, 0.224, 0.229] # Eig vals and vecs of the cov mat _EIG_VALS = np.array([[0.2175, 0.0188, 0.0045]]) _EIG_VECS = np.array( [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]] ) class ImageNet(torch.utils.data.Dataset): """ImageNet dataset.""" def __init__(self, data_path, split): assert os.path.exists(data_path), "Data path '{}' not found".format(data_path) assert split in [ "train", "val", ], "Split '{}' not supported for ImageNet".format(split) logger.info("Constructing ImageNet {}...".format(split)) self._data_path = data_path self._split = split self._construct_imdb() def _construct_imdb(self): """Constructs the imdb.""" # Compile the split data path split_path = os.path.join(self._data_path, self._split) logger.info("{} data path: {}".format(self._split, split_path)) # Images are stored per class in subdirs (format: n) self._class_ids = sorted( f for f in os.listdir(split_path) if re.match(r"^n[0-9]+$", f) ) # Map ImageNet class ids to contiguous ids self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} # Construct the image db self._imdb = [] for class_id in self._class_ids: cont_id = self._class_id_cont_id[class_id] im_dir = os.path.join(split_path, class_id) for im_name in os.listdir(im_dir): self._imdb.append( {"im_path": os.path.join(im_dir, im_name), "class": cont_id} ) logger.info("Number of images: {}".format(len(self._imdb))) logger.info("Number of classes: {}".format(len(self._class_ids))) def _prepare_im(self, im): """Prepares the image for network input.""" # Train and test setups differ if self._split == "train": # Scale and aspect ratio im = transforms.random_sized_crop( im=im, size=cfg.TRAIN.IM_SIZE, area_frac=0.08 ) # Horizontal flip im = transforms.horizontal_flip(im=im, p=0.5, order="HWC") else: # Scale and center crop im = transforms.scale(cfg.TEST.IM_SIZE, im) im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im) # HWC -> CHW im = im.transpose([2, 0, 1]) # [0, 255] -> [0, 1] im = im / 255.0 # PCA jitter if self._split == "train": im = transforms.lighting(im, 0.1, _EIG_VALS, _EIG_VECS) # Color normalization im = transforms.color_norm(im, _MEAN, _SD) return im def __getitem__(self, index): # Load the image im = cv2.imread(self._imdb[index]["im_path"]) im = im.astype(np.float32, copy=False) # Prepare the image for training / testing im = self._prepare_im(im) # Retrieve the label label = self._imdb[index]["class"] return im, label def __len__(self): return len(self._imdb) ================================================ FILE: pycls/datasets/loader.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Data loader.""" import pycls.datasets.paths as dp import torch from pycls.core.config import cfg from pycls.datasets.cifar10 import Cifar10 from pycls.datasets.imagenet import ImageNet from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler # Supported datasets _DATASET_CATALOG = {"cifar10": Cifar10, "imagenet": ImageNet} def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last): """Constructs the data loader for the given dataset.""" assert dataset_name in _DATASET_CATALOG.keys(), "Dataset '{}' not supported".format( dataset_name ) assert dp.has_data_path(dataset_name), "Dataset '{}' has no data path".format( dataset_name ) # Retrieve the data path for the dataset data_path = dp.get_data_path(dataset_name) # Construct the dataset dataset = _DATASET_CATALOG[dataset_name](data_path, split) # Create a sampler for multi-process training sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None # Create a loader loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=(False if sampler else shuffle), sampler=sampler, num_workers=cfg.DATA_LOADER.NUM_WORKERS, pin_memory=cfg.DATA_LOADER.PIN_MEMORY, drop_last=drop_last, ) return loader def construct_train_loader(): """Train loader wrapper.""" return _construct_loader( dataset_name=cfg.TRAIN.DATASET, split=cfg.TRAIN.SPLIT, batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS), shuffle=True, drop_last=True, ) def construct_test_loader(): """Test loader wrapper.""" return _construct_loader( dataset_name=cfg.TEST.DATASET, split=cfg.TEST.SPLIT, batch_size=int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS), shuffle=False, drop_last=False, ) def shuffle(loader, cur_epoch): """"Shuffles the data.""" assert isinstance( loader.sampler, (RandomSampler, DistributedSampler) ), "Sampler type '{}' not supported".format(type(loader.sampler)) # RandomSampler handles shuffling automatically if isinstance(loader.sampler, DistributedSampler): # DistributedSampler shuffles data based on epoch loader.sampler.set_epoch(cur_epoch) ================================================ FILE: pycls/datasets/paths.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Dataset paths.""" import os # Default data directory (/path/pycls/pycls/datasets/data) _DEF_DATA_DIR = os.path.join(os.path.dirname(__file__), "data") # Data paths _paths = { "cifar10": _DEF_DATA_DIR + "/cifar10", "imagenet": _DEF_DATA_DIR + "/imagenet", } def has_data_path(dataset_name): """Determines if the dataset has a data path.""" return dataset_name in _paths.keys() def get_data_path(dataset_name): """Retrieves data path for the dataset.""" return _paths[dataset_name] def register_path(name, path): """Registers a dataset path dynamically.""" _paths[name] = path ================================================ FILE: pycls/datasets/transforms.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Image transformations.""" import math import cv2 import numpy as np def color_norm(im, mean, std): """Performs per-channel normalization (CHW format).""" for i in range(im.shape[0]): im[i] = im[i] - mean[i] im[i] = im[i] / std[i] return im def zero_pad(im, pad_size): """Performs zero padding (CHW format).""" pad_width = ((0, 0), (pad_size, pad_size), (pad_size, pad_size)) return np.pad(im, pad_width, mode="constant") def horizontal_flip(im, p, order="CHW"): """Performs horizontal flip (CHW or HWC format).""" assert order in ["CHW", "HWC"] if np.random.uniform() < p: if order == "CHW": im = im[:, :, ::-1] else: im = im[:, ::-1, :] return im def random_crop(im, size, pad_size=0): """Performs random crop (CHW format).""" if pad_size > 0: im = zero_pad(im=im, pad_size=pad_size) h, w = im.shape[1:] y = np.random.randint(0, h - size) x = np.random.randint(0, w - size) im_crop = im[:, y : (y + size), x : (x + size)] assert im_crop.shape[1:] == (size, size) return im_crop def scale(size, im): """Performs scaling (HWC format).""" h, w = im.shape[:2] if (w <= h and w == size) or (h <= w and h == size): return im h_new, w_new = size, size if w < h: h_new = int(math.floor((float(h) / w) * size)) else: w_new = int(math.floor((float(w) / h) * size)) im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_LINEAR) return im.astype(np.float32) def center_crop(size, im): """Performs center cropping (HWC format).""" h, w = im.shape[:2] y = int(math.ceil((h - size) / 2)) x = int(math.ceil((w - size) / 2)) im_crop = im[y : (y + size), x : (x + size), :] assert im_crop.shape[:2] == (size, size) return im_crop def random_sized_crop(im, size, area_frac=0.08, max_iter=10): """Performs Inception-style cropping (HWC format).""" h, w = im.shape[:2] area = h * w for _ in range(max_iter): target_area = np.random.uniform(area_frac, 1.0) * area aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0) w_crop = int(round(math.sqrt(float(target_area) * aspect_ratio))) h_crop = int(round(math.sqrt(float(target_area) / aspect_ratio))) if np.random.uniform() < 0.5: w_crop, h_crop = h_crop, w_crop if h_crop <= h and w_crop <= w: y = 0 if h_crop == h else np.random.randint(0, h - h_crop) x = 0 if w_crop == w else np.random.randint(0, w - w_crop) im_crop = im[y : (y + h_crop), x : (x + w_crop), :] assert im_crop.shape[:2] == (h_crop, w_crop) im_crop = cv2.resize(im_crop, (size, size), interpolation=cv2.INTER_LINEAR) return im_crop.astype(np.float32) return center_crop(size, scale(size, im)) def lighting(im, alpha_std, eig_val, eig_vec): """Performs AlexNet-style PCA jitter (CHW format).""" if alpha_std == 0: return im alpha = np.random.normal(0, alpha_std, size=(1, 3)) rgb = np.sum( eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), axis=1 ) for i in range(im.shape[0]): im[i] = im[i] + rgb[2 - i] return im ================================================ FILE: pycls/models/__init__.py ================================================ ================================================ FILE: pycls/models/anynet.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """AnyNet models.""" import pycls.utils.logging as lu import pycls.utils.net as nu import torch.nn as nn from pycls.core.config import cfg logger = lu.get_logger(__name__) def get_stem_fun(stem_type): """Retrives the stem function by name.""" stem_funs = { "res_stem_cifar": ResStemCifar, "res_stem_in": ResStemIN, "simple_stem_in": SimpleStemIN, } assert stem_type in stem_funs.keys(), "Stem type '{}' not supported".format( stem_type ) return stem_funs[stem_type] def get_block_fun(block_type): """Retrieves the block function by name.""" block_funs = { "vanilla_block": VanillaBlock, "res_basic_block": ResBasicBlock, "res_bottleneck_block": ResBottleneckBlock, } assert block_type in block_funs.keys(), "Block type '{}' not supported".format( block_type ) return block_funs[block_type] class AnyHead(nn.Module): """AnyNet head.""" def __init__(self, w_in, nc): super(AnyHead, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) # self.fc = nn.Linear(w_in, nc, bias=True) def forward(self, x): x = self.avg_pool(x) # print(x.shape) x = x.view(x.size(0), -1) # x = self.fc(x) return x class VanillaBlock(nn.Module): """Vanilla block: [3x3 conv, BN, Relu] x2""" def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None): assert ( bm is None and gw is None and se_r is None ), "Vanilla block does not support bm, gw, and se_r options" super(VanillaBlock, self).__init__() self._construct(w_in, w_out, stride) def _construct(self, w_in, w_out, stride): # 3x3, BN, ReLU self.a = nn.Conv2d( w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False ) self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) # 3x3, BN, ReLU self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False) self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) def forward(self, x): for layer in self.children(): x = layer(x) return x class BasicTransform(nn.Module): """Basic transformation: [3x3 conv, BN, Relu] x2""" def __init__(self, w_in, w_out, stride): super(BasicTransform, self).__init__() self._construct(w_in, w_out, stride) def _construct(self, w_in, w_out, stride): # 3x3, BN, ReLU self.a = nn.Conv2d( w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False ) self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) # 3x3, BN self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False) self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.b_bn.final_bn = True def forward(self, x): for layer in self.children(): x = layer(x) return x class ResBasicBlock(nn.Module): """Residual basic block: x + F(x), F = basic transform""" def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None): assert ( bm is None and gw is None and se_r is None ), "Basic transform does not support bm, gw, and se_r options" super(ResBasicBlock, self).__init__() self._construct(w_in, w_out, stride) def _add_skip_proj(self, w_in, w_out, stride): self.proj = nn.Conv2d( w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False ) self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) def _construct(self, w_in, w_out, stride): # Use skip connection with projection if shape changes self.proj_block = (w_in != w_out) or (stride != 1) if self.proj_block: self._add_skip_proj(w_in, w_out, stride) self.f = BasicTransform(w_in, w_out, stride) self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) def forward(self, x): if self.proj_block: x = self.bn(self.proj(x)) + self.f(x) else: x = x + self.f(x) x = self.relu(x) return x class SE(nn.Module): """Squeeze-and-Excitation (SE) block""" def __init__(self, w_in, w_se): super(SE, self).__init__() self._construct(w_in, w_se) def _construct(self, w_in, w_se): # AvgPool self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) # FC, Activation, FC, Sigmoid self.f_ex = nn.Sequential( nn.Conv2d(w_in, w_se, kernel_size=1, bias=True), nn.ReLU(inplace=cfg.MEM.RELU_INPLACE), nn.Conv2d(w_se, w_in, kernel_size=1, bias=True), nn.Sigmoid(), ) def forward(self, x): return x * self.f_ex(self.avg_pool(x)) class BottleneckTransform(nn.Module): """Bottlenect transformation: 1x1, 3x3, 1x1""" def __init__(self, w_in, w_out, stride, bm, gw, se_r): super(BottleneckTransform, self).__init__() self._construct(w_in, w_out, stride, bm, gw, se_r) def _construct(self, w_in, w_out, stride, bm, gw, se_r): # Compute the bottleneck width w_b = int(round(w_out * bm)) # Compute the number of groups num_gs = w_b // gw # 1x1, BN, ReLU self.a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=False) self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) # 3x3, BN, ReLU self.b = nn.Conv2d( w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_gs, bias=False ) self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) # Squeeze-and-Excitation (SE) if se_r: w_se = int(round(w_in * se_r)) self.se = SE(w_b, w_se) # 1x1, BN self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False) self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.c_bn.final_bn = True def forward(self, x): for layer in self.children(): x = layer(x) return x class ResBottleneckBlock(nn.Module): """Residual bottleneck block: x + F(x), F = bottleneck transform""" def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None): super(ResBottleneckBlock, self).__init__() self._construct(w_in, w_out, stride, bm, gw, se_r) def _add_skip_proj(self, w_in, w_out, stride): self.proj = nn.Conv2d( w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False ) self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) def _construct(self, w_in, w_out, stride, bm, gw, se_r): # Use skip connection with projection if shape changes self.proj_block = (w_in != w_out) or (stride != 1) if self.proj_block: self._add_skip_proj(w_in, w_out, stride) self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r) self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) def forward(self, x): if self.proj_block: x = self.bn(self.proj(x)) + self.f(x) else: x = x + self.f(x) x = self.relu(x) return x class ResStemCifar(nn.Module): """ResNet stem for CIFAR.""" def __init__(self, w_in, w_out): super(ResStemCifar, self).__init__() self._construct(w_in, w_out) def _construct(self, w_in, w_out): # 3x3, BN, ReLU self.conv = nn.Conv2d( w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False ) self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) def forward(self, x): for layer in self.children(): x = layer(x) return x class ResStemIN(nn.Module): """ResNet stem for ImageNet.""" def __init__(self, w_in, w_out): super(ResStemIN, self).__init__() self._construct(w_in, w_out) def _construct(self, w_in, w_out): # 7x7, BN, ReLU, maxpool self.conv = nn.Conv2d( w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False ) self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) def forward(self, x): for layer in self.children(): x = layer(x) return x class SimpleStemIN(nn.Module): """Simple stem for ImageNet.""" def __init__(self, in_w, out_w): super(SimpleStemIN, self).__init__() self._construct(in_w, out_w) def _construct(self, in_w, out_w): # 3x3, BN, ReLU self.conv = nn.Conv2d( in_w, out_w, kernel_size=3, stride=2, padding=1, bias=False ) self.bn = nn.BatchNorm2d(out_w, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) def forward(self, x): for layer in self.children(): x = layer(x) return x class AnyStage(nn.Module): """AnyNet stage (sequence of blocks w/ the same output shape).""" def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r): super(AnyStage, self).__init__() self._construct(w_in, w_out, stride, d, block_fun, bm, gw, se_r) def _construct(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r): # Construct the blocks for i in range(d): # Stride and w_in apply to the first block of the stage b_stride = stride if i == 0 else 1 b_w_in = w_in if i == 0 else w_out # Construct the block self.add_module( "b{}".format(i + 1), block_fun(b_w_in, w_out, b_stride, bm, gw, se_r) ) def forward(self, x): for block in self.children(): x = block(x) return x class AnyNet(nn.Module): """AnyNet model.""" def __init__(self, **kwargs): super(AnyNet, self).__init__() if kwargs: self._construct( stem_type=kwargs["stem_type"], stem_w=kwargs["stem_w"], block_type=kwargs["block_type"], ds=kwargs["ds"], ws=kwargs["ws"], ss=kwargs["ss"], bms=kwargs["bms"], gws=kwargs["gws"], se_r=kwargs["se_r"], nc=kwargs["nc"], ) else: self._construct( stem_type=cfg.ANYNET.STEM_TYPE, stem_w=cfg.ANYNET.STEM_W, block_type=cfg.ANYNET.BLOCK_TYPE, ds=cfg.ANYNET.DEPTHS, ws=cfg.ANYNET.WIDTHS, ss=cfg.ANYNET.STRIDES, bms=cfg.ANYNET.BOT_MULS, gws=cfg.ANYNET.GROUP_WS, se_r=cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None, nc=cfg.MODEL.NUM_CLASSES, ) self.apply(nu.init_weights) def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc): logger.info("Constructing AnyNet: ds={}, ws={}".format(ds, ws)) # Generate dummy bot muls and gs for models that do not use them bms = bms if bms else [1.0 for _d in ds] gws = gws if gws else [1 for _d in ds] # Group params by stage stage_params = list(zip(ds, ws, ss, bms, gws)) # Construct the stem stem_fun = get_stem_fun(stem_type) self.stem = stem_fun(3, stem_w) # Construct the stages block_fun = get_block_fun(block_type) prev_w = stem_w for i, (d, w, s, bm, gw) in enumerate(stage_params): self.add_module( "s{}".format(i + 1), AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r) ) prev_w = w # Construct the head self.head = AnyHead(w_in=prev_w, nc=nc) self.fc = nn.Linear(prev_w, nc, bias=True) self.feature_num = prev_w def forward(self, x): # for module in self.children(): # x = module(x) # return x for name, module in self.named_children(): if name != 'head' and name != 'fc': x = module(x) # print(x.shape) output = self.head(x) # print(output.shape) return output, x.detach() ================================================ FILE: pycls/models/effnet.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """EfficientNet models.""" import pycls.utils.logging as logging import pycls.utils.net as nu import torch import torch.nn as nn from pycls.core.config import cfg logger = logging.get_logger(__name__) class EffHead(nn.Module): """EfficientNet head.""" def __init__(self, w_in, w_out, nc): super(EffHead, self).__init__() self._construct(w_in, w_out, nc) def _construct(self, w_in, w_out, nc): # 1x1, BN, Swish self.conv = nn.Conv2d( w_in, w_out, kernel_size=1, stride=1, padding=0, bias=False ) self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.conv_swish = Swish() # AvgPool self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) # Dropout if cfg.EN.DROPOUT_RATIO > 0.0: self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO) # FC self.fc = nn.Linear(w_out, nc, bias=True) def forward(self, x): x = self.conv_swish(self.conv_bn(self.conv(x))) x = self.avg_pool(x) x = x.view(x.size(0), -1) x = self.dropout(x) if hasattr(self, "dropout") else x x = self.fc(x) return x class Swish(nn.Module): """Swish activation function: x * sigmoid(x)""" def __init__(self): super(Swish, self).__init__() def forward(self, x): return x * torch.sigmoid(x) class SE(nn.Module): """Squeeze-and-Excitation (SE) block w/ Swish.""" def __init__(self, w_in, w_se): super(SE, self).__init__() self._construct(w_in, w_se) def _construct(self, w_in, w_se): # AvgPool self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) # FC, Swish, FC, Sigmoid self.f_ex = nn.Sequential( nn.Conv2d(w_in, w_se, kernel_size=1, bias=True), Swish(), nn.Conv2d(w_se, w_in, kernel_size=1, bias=True), nn.Sigmoid(), ) def forward(self, x): return x * self.f_ex(self.avg_pool(x)) class MBConv(nn.Module): """Mobile inverted bottleneck block w/ SE (MBConv).""" def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out): super(MBConv, self).__init__() self._construct(w_in, exp_r, kernel, stride, se_r, w_out) def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out): # Expansion ratio is wrt the input width self.exp = None w_exp = int(w_in * exp_r) # Include exp ops only if the exp ratio is different from 1 if w_exp != w_in: # 1x1, BN, Swish self.exp = nn.Conv2d( w_in, w_exp, kernel_size=1, stride=1, padding=0, bias=False ) self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.exp_swish = Swish() # 3x3 dwise, BN, Swish self.dwise = nn.Conv2d( w_exp, w_exp, kernel_size=kernel, stride=stride, groups=w_exp, bias=False, # Hacky padding to preserve res (supports only 3x3 and 5x5) padding=(1 if kernel == 3 else 2), ) self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.dwise_swish = Swish() # Squeeze-and-Excitation (SE) w_se = int(w_in * se_r) self.se = SE(w_exp, w_se) # 1x1, BN self.lin_proj = nn.Conv2d( w_exp, w_out, kernel_size=1, stride=1, padding=0, bias=False ) self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) # Skip connection if in and out shapes are the same (MN-V2 style) self.has_skip = (stride == 1) and (w_in == w_out) def forward(self, x): f_x = x # Expansion if self.exp: f_x = self.exp_swish(self.exp_bn(self.exp(f_x))) # Depthwise f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x))) # SE f_x = self.se(f_x) # Linear projection f_x = self.lin_proj_bn(self.lin_proj(f_x)) # Skip connection if self.has_skip: # Drop connect if self.training and cfg.EN.DC_RATIO > 0.0: f_x = nu.drop_connect(f_x, cfg.EN.DC_RATIO) f_x = x + f_x return f_x class EffStage(nn.Module): """EfficientNet stage.""" def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d): super(EffStage, self).__init__() self._construct(w_in, exp_r, kernel, stride, se_r, w_out, d) def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out, d): # Construct the blocks for i in range(d): # Stride and input width apply to the first block of the stage b_stride = stride if i == 0 else 1 b_w_in = w_in if i == 0 else w_out # Construct the block self.add_module( "b{}".format(i + 1), MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out), ) def forward(self, x): for block in self.children(): x = block(x) return x class StemIN(nn.Module): """EfficientNet stem for ImageNet.""" def __init__(self, w_in, w_out): super(StemIN, self).__init__() self._construct(w_in, w_out) def _construct(self, w_in, w_out): # 3x3, BN, Swish self.conv = nn.Conv2d( w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False ) self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.swish = Swish() def forward(self, x): for layer in self.children(): x = layer(x) return x class EffNet(nn.Module): """EfficientNet model.""" def __init__(self): assert cfg.TRAIN.DATASET in [ "imagenet" ], "Training on {} is not supported".format(cfg.TRAIN.DATASET) assert cfg.TEST.DATASET in [ "imagenet" ], "Testing on {} is not supported".format(cfg.TEST.DATASET) super(EffNet, self).__init__() self._construct( stem_w=cfg.EN.STEM_W, ds=cfg.EN.DEPTHS, ws=cfg.EN.WIDTHS, exp_rs=cfg.EN.EXP_RATIOS, se_r=cfg.EN.SE_R, ss=cfg.EN.STRIDES, ks=cfg.EN.KERNELS, head_w=cfg.EN.HEAD_W, nc=cfg.MODEL.NUM_CLASSES, ) self.apply(nu.init_weights) def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc): # Group params by stage stage_params = list(zip(ds, ws, exp_rs, ss, ks)) logger.info("Constructing: EfficientNet-{}".format(stage_params)) # Construct the stem self.stem = StemIN(3, stem_w) prev_w = stem_w # Construct the stages for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params): self.add_module( "s{}".format(i + 1), EffStage(prev_w, exp_r, kernel, stride, se_r, w, d) ) prev_w = w # Construct the head self.head = EffHead(prev_w, head_w, nc) def forward(self, x): for module in self.children(): x = module(x) return x ================================================ FILE: pycls/models/regnet.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """RegNet models.""" import numpy as np import pycls.utils.logging as lu from pycls.core.config import cfg from pycls.models.anynet import AnyNet logger = lu.get_logger(__name__) def quantize_float(f, q): """Converts a float to closest non-zero int divisible by q.""" return int(round(f / q) * q) def adjust_ws_gs_comp(ws, bms, gs): """Adjusts the compatibility of widths and groups.""" ws_bot = [int(w * b) for w, b in zip(ws, bms)] gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)] ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)] ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)] return ws, gs def get_stages_from_blocks(ws, rs): """Gets ws/ds of network at each stage from per block values.""" ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs) ts = [w != wp or r != rp for w, wp, r, rp in ts_temp] s_ws = [w for w, t in zip(ws, ts[:-1]) if t] s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist() return s_ws, s_ds def generate_regnet(w_a, w_0, w_m, d, q=8): """Generates per block ws from RegNet parameters.""" assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0 ws_cont = np.arange(d) * w_a + w_0 ks = np.round(np.log(ws_cont / w_0) / np.log(w_m)) ws = w_0 * np.power(w_m, ks) ws = np.round(np.divide(ws, q)) * q num_stages, max_stage = len(np.unique(ws)), ks.max() + 1 ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist() return ws, num_stages, max_stage, ws_cont class RegNet(AnyNet): """RegNet model.""" def __init__(self): # Generate RegNet ws per block b_ws, num_s, _, _ = generate_regnet( cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH ) # print(cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH, cfg.REGNET.GROUP_W) # Convert to per stage format ws, ds = get_stages_from_blocks(b_ws, b_ws) # Generate group widths and bot muls gws = [cfg.REGNET.GROUP_W for _ in range(num_s)] bms = [cfg.REGNET.BOT_MUL for _ in range(num_s)] # Adjust the compatibility of ws and gws ws, gws = adjust_ws_gs_comp(ws, bms, gws) # Use the same stride for each stage ss = [cfg.REGNET.STRIDE for _ in range(num_s)] # Use SE for RegNetY se_r = cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None # Construct the model kwargs = { "stem_type": cfg.REGNET.STEM_TYPE, "stem_w": cfg.REGNET.STEM_W, "block_type": cfg.REGNET.BLOCK_TYPE, "ss": ss, "ds": ds, "ws": ws, "bms": bms, "gws": gws, "se_r": se_r, "nc": cfg.MODEL.NUM_CLASSES, } super(RegNet, self).__init__(**kwargs) ================================================ FILE: pycls/models/resnet.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """ResNe(X)t models.""" import pycls.utils.logging as lu import pycls.utils.net as nu import torch.nn as nn from pycls.core.config import cfg logger = lu.get_logger(__name__) # Stage depths for ImageNet models _IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)} def get_trans_fun(name): """Retrieves the transformation function by name.""" trans_funs = { "basic_transform": BasicTransform, "bottleneck_transform": BottleneckTransform, } assert ( name in trans_funs.keys() ), "Transformation function '{}' not supported".format(name) return trans_funs[name] class ResHead(nn.Module): """ResNet head.""" def __init__(self, w_in, nc): super(ResHead, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(w_in, nc, bias=True) def forward(self, x): x = self.avg_pool(x) x = x.view(x.size(0), -1) x = self.fc(x) return x class BasicTransform(nn.Module): """Basic transformation: 3x3, 3x3""" def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1): assert ( w_b is None and num_gs == 1 ), "Basic transform does not support w_b and num_gs options" super(BasicTransform, self).__init__() self._construct(w_in, w_out, stride) def _construct(self, w_in, w_out, stride): # 3x3, BN, ReLU self.a = nn.Conv2d( w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False ) self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) # 3x3, BN self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False) self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.b_bn.final_bn = True def forward(self, x): for layer in self.children(): x = layer(x) return x class BottleneckTransform(nn.Module): """Bottleneck transformation: 1x1, 3x3, 1x1""" def __init__(self, w_in, w_out, stride, w_b, num_gs): super(BottleneckTransform, self).__init__() self._construct(w_in, w_out, stride, w_b, num_gs) def _construct(self, w_in, w_out, stride, w_b, num_gs): # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3 (str1x1, str3x3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) # 1x1, BN, ReLU self.a = nn.Conv2d( w_in, w_b, kernel_size=1, stride=str1x1, padding=0, bias=False ) self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) # 3x3, BN, ReLU self.b = nn.Conv2d( w_b, w_b, kernel_size=3, stride=str3x3, padding=1, groups=num_gs, bias=False ) self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) # 1x1, BN self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False) self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.c_bn.final_bn = True def forward(self, x): for layer in self.children(): x = layer(x) return x class ResBlock(nn.Module): """Residual block: x + F(x)""" def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1): super(ResBlock, self).__init__() self._construct(w_in, w_out, stride, trans_fun, w_b, num_gs) def _add_skip_proj(self, w_in, w_out, stride): self.proj = nn.Conv2d( w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False ) self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) def _construct(self, w_in, w_out, stride, trans_fun, w_b, num_gs): # Use skip connection with projection if shape changes self.proj_block = (w_in != w_out) or (stride != 1) if self.proj_block: self._add_skip_proj(w_in, w_out, stride) self.f = trans_fun(w_in, w_out, stride, w_b, num_gs) self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) def forward(self, x): if self.proj_block: x = self.bn(self.proj(x)) + self.f(x) else: x = x + self.f(x) x = self.relu(x) return x class ResStage(nn.Module): """Stage of ResNet.""" def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1): super(ResStage, self).__init__() self._construct(w_in, w_out, stride, d, w_b, num_gs) def _construct(self, w_in, w_out, stride, d, w_b, num_gs): # Construct the blocks for i in range(d): # Stride and w_in apply to the first block of the stage b_stride = stride if i == 0 else 1 b_w_in = w_in if i == 0 else w_out # Retrieve the transformation function trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN) # Construct the block res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs) self.add_module("b{}".format(i + 1), res_block) def forward(self, x): for block in self.children(): x = block(x) return x class ResStem(nn.Module): """Stem of ResNet.""" def __init__(self, w_in, w_out): assert ( cfg.TRAIN.DATASET == cfg.TEST.DATASET ), "Train and test dataset must be the same for now" super(ResStem, self).__init__() if "cifar" in cfg.TRAIN.DATASET: self._construct_cifar(w_in, w_out) else: self._construct_imagenet(w_in, w_out) def _construct_cifar(self, w_in, w_out): # 3x3, BN, ReLU self.conv = nn.Conv2d( w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False ) self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) def _construct_imagenet(self, w_in, w_out): # 7x7, BN, ReLU, maxpool self.conv = nn.Conv2d( w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False ) self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) def forward(self, x): for layer in self.children(): x = layer(x) return x class ResNet(nn.Module): """ResNet model.""" def __init__(self): assert cfg.TRAIN.DATASET in [ "cifar10", "imagenet", ], "Training ResNet on {} is not supported".format(cfg.TRAIN.DATASET) assert cfg.TEST.DATASET in [ "cifar10", "imagenet", ], "Testing ResNet on {} is not supported".format(cfg.TEST.DATASET) super(ResNet, self).__init__() if "cifar" in cfg.TRAIN.DATASET: self._construct_cifar() else: self._construct_imagenet() self.apply(nu.init_weights) def _construct_cifar(self): assert ( cfg.MODEL.DEPTH - 2 ) % 6 == 0, "Model depth should be of the format 6n + 2 for cifar" logger.info("Constructing: ResNet-{}".format(cfg.MODEL.DEPTH)) # Each stage has the same number of blocks for cifar d = int((cfg.MODEL.DEPTH - 2) / 6) # Stem: (N, 3, 32, 32) -> (N, 16, 32, 32) self.stem = ResStem(w_in=3, w_out=16) # Stage 1: (N, 16, 32, 32) -> (N, 16, 32, 32) self.s1 = ResStage(w_in=16, w_out=16, stride=1, d=d) # Stage 2: (N, 16, 32, 32) -> (N, 32, 16, 16) self.s2 = ResStage(w_in=16, w_out=32, stride=2, d=d) # Stage 3: (N, 32, 16, 16) -> (N, 64, 8, 8) self.s3 = ResStage(w_in=32, w_out=64, stride=2, d=d) # Head: (N, 64, 8, 8) -> (N, num_classes) self.head = ResHead(w_in=64, nc=cfg.MODEL.NUM_CLASSES) def _construct_imagenet(self): logger.info( "Constructing: ResNe(X)t-{}-{}x{}, {}".format( cfg.MODEL.DEPTH, cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP, cfg.RESNET.TRANS_FUN, ) ) # Retrieve the number of blocks per stage (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH] # Compute the initial bottleneck width num_gs = cfg.RESNET.NUM_GROUPS w_b = cfg.RESNET.WIDTH_PER_GROUP * num_gs # Stem: (N, 3, 224, 224) -> (N, 64, 56, 56) self.stem = ResStem(w_in=3, w_out=64) # Stage 1: (N, 64, 56, 56) -> (N, 256, 56, 56) self.s1 = ResStage(w_in=64, w_out=256, stride=1, d=d1, w_b=w_b, num_gs=num_gs) # Stage 2: (N, 256, 56, 56) -> (N, 512, 28, 28) self.s2 = ResStage( w_in=256, w_out=512, stride=2, d=d2, w_b=w_b * 2, num_gs=num_gs ) # Stage 3: (N, 512, 56, 56) -> (N, 1024, 14, 14) self.s3 = ResStage( w_in=512, w_out=1024, stride=2, d=d3, w_b=w_b * 4, num_gs=num_gs ) # Stage 4: (N, 1024, 14, 14) -> (N, 2048, 7, 7) self.s4 = ResStage( w_in=1024, w_out=2048, stride=2, d=d4, w_b=w_b * 8, num_gs=num_gs ) # Head: (N, 2048, 7, 7) -> (N, num_classes) self.head = ResHead(w_in=2048, nc=cfg.MODEL.NUM_CLASSES) def forward(self, x): for module in self.children(): x = module(x) return x ================================================ FILE: pycls/utils/__init__.py ================================================ ================================================ FILE: pycls/utils/benchmark.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Functions for benchmarking networks.""" import pycls.utils.logging as lu import torch from pycls.core.config import cfg from pycls.utils.timer import Timer @torch.no_grad() def compute_fw_test_time(model, inputs): """Computes forward test time (no grad, eval mode).""" # Use eval mode model.eval() # Warm up the caches for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER): model(inputs) # Make sure warmup kernels completed torch.cuda.synchronize() # Compute precise forward pass time timer = Timer() for _cur_iter in range(cfg.PREC_TIME.NUM_ITER): timer.tic() model(inputs) torch.cuda.synchronize() timer.toc() # Make sure forward kernels completed torch.cuda.synchronize() return timer.average_time def compute_fw_bw_time(model, loss_fun, inputs, labels): """Computes forward backward time.""" # Use train mode model.train() # Warm up the caches for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER): preds = model(inputs) loss = loss_fun(preds, labels) loss.backward() # Make sure warmup kernels completed torch.cuda.synchronize() # Compute precise forward backward pass time fw_timer = Timer() bw_timer = Timer() for _cur_iter in range(cfg.PREC_TIME.NUM_ITER): # Forward fw_timer.tic() preds = model(inputs) loss = loss_fun(preds, labels) torch.cuda.synchronize() fw_timer.toc() # Backward bw_timer.tic() loss.backward() torch.cuda.synchronize() bw_timer.toc() # Make sure forward backward kernels completed torch.cuda.synchronize() return fw_timer.average_time, bw_timer.average_time def compute_precise_time(model, loss_fun): """Computes precise time.""" # Generate a dummy mini-batch im_size = cfg.TRAIN.IM_SIZE inputs = torch.rand(cfg.PREC_TIME.BATCH_SIZE, 3, im_size, im_size) labels = torch.zeros(cfg.PREC_TIME.BATCH_SIZE, dtype=torch.int64) # Copy the data to the GPU inputs = inputs.cuda(non_blocking=False) labels = labels.cuda(non_blocking=False) # Compute precise time fw_test_time = compute_fw_test_time(model, inputs) fw_time, bw_time = compute_fw_bw_time(model, loss_fun, inputs, labels) # Log precise time lu.log_json_stats( { "prec_test_fw_time": fw_test_time, "prec_train_fw_time": fw_time, "prec_train_bw_time": bw_time, "prec_train_fw_bw_time": fw_time + bw_time, } ) ================================================ FILE: pycls/utils/checkpoint.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Functions that handle saving and loading of checkpoints.""" import os import pycls.utils.distributed as du import torch from pycls.core.config import cfg # Common prefix for checkpoint file names _NAME_PREFIX = "model_epoch_" # Checkpoints directory name _DIR_NAME = "checkpoints" def get_checkpoint_dir(): """Retrieves the location for storing checkpoints.""" return os.path.join(cfg.OUT_DIR, _DIR_NAME) def get_checkpoint(epoch): """Retrieves the path to a checkpoint file.""" name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch) return os.path.join(get_checkpoint_dir(), name) def get_last_checkpoint(): """Retrieves the most recent checkpoint (highest epoch number).""" checkpoint_dir = get_checkpoint_dir() # Checkpoint file names are in lexicographic order checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f] last_checkpoint_name = sorted(checkpoints)[-1] return os.path.join(checkpoint_dir, last_checkpoint_name) def has_checkpoint(): """Determines if there are checkpoints available.""" checkpoint_dir = get_checkpoint_dir() if not os.path.exists(checkpoint_dir): return False return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir)) def is_checkpoint_epoch(cur_epoch): """Determines if a checkpoint should be saved on current epoch.""" return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0 def save_checkpoint(model, optimizer, epoch): """Saves a checkpoint.""" # Save checkpoints only from the master process if not du.is_master_proc(): return # Ensure that the checkpoint dir exists os.makedirs(get_checkpoint_dir(), exist_ok=True) # Omit the DDP wrapper in the multi-gpu setting sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict() # Record the state checkpoint = { "epoch": epoch, "model_state": sd, "optimizer_state": optimizer.state_dict(), "cfg": cfg.dump(), } # Write the checkpoint checkpoint_file = get_checkpoint(epoch + 1) torch.save(checkpoint, checkpoint_file) return checkpoint_file def load_checkpoint(checkpoint_file, model, optimizer=None): """Loads the checkpoint from the given file.""" assert os.path.exists(checkpoint_file), "Checkpoint '{}' not found".format( checkpoint_file ) # Load the checkpoint on CPU to avoid GPU mem spike checkpoint = torch.load(checkpoint_file, map_location="cpu") # Account for the DDP wrapper in the multi-gpu setting ms = model.module if cfg.NUM_GPUS > 1 else model ms.load_state_dict(checkpoint["model_state"]) # Load the optimizer state (commonly not done when fine-tuning) if optimizer: optimizer.load_state_dict(checkpoint["optimizer_state"]) return checkpoint["epoch"] ================================================ FILE: pycls/utils/distributed.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Distributed helpers.""" import torch from pycls.core.config import cfg def is_master_proc(): """Determines if the current process is the master process. Master process is responsible for logging, writing and loading checkpoints. In the multi GPU setting, we assign the master role to the rank 0 process. When training using a single GPU, there is only one training processes which is considered the master processes. """ return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0 def init_process_group(proc_rank, world_size): """Initializes the default process group.""" # Set the GPU to use torch.cuda.set_device(proc_rank) # Initialize the process group torch.distributed.init_process_group( backend=cfg.DIST_BACKEND, init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT), world_size=world_size, rank=proc_rank, ) def destroy_process_group(): """Destroys the default process group.""" torch.distributed.destroy_process_group() def scaled_all_reduce(tensors): """Performs the scaled all_reduce operation on the provided tensors. The input tensors are modified in-place. Currently supports only the sum reduction operator. The reduced values are scaled by the inverse size of the process group (equivalent to cfg.NUM_GPUS). """ # Queue the reductions reductions = [] for tensor in tensors: reduction = torch.distributed.all_reduce(tensor, async_op=True) reductions.append(reduction) # Wait for reductions to finish for reduction in reductions: reduction.wait() # Scale the results for tensor in tensors: tensor.mul_(1.0 / cfg.NUM_GPUS) return tensors ================================================ FILE: pycls/utils/error_handler.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Multiprocessing error handler.""" import os import signal import threading class ChildException(Exception): """Wraps an exception from a child process.""" def __init__(self, child_trace): super(ChildException, self).__init__(child_trace) class ErrorHandler(object): """Multiprocessing error handler (based on fairseq's). Listens for errors in child processes and propagates the tracebacks to the parent process. """ def __init__(self, error_queue): # Shared error queue self.error_queue = error_queue # Children processes sharing the error queue self.children_pids = [] # Start a thread listening to errors self.error_listener = threading.Thread(target=self.listen, daemon=True) self.error_listener.start() # Register the signal handler signal.signal(signal.SIGUSR1, self.signal_handler) def add_child(self, pid): """Registers a child process.""" self.children_pids.append(pid) def listen(self): """Listens for errors in the error queue.""" # Wait until there is an error in the queue child_trace = self.error_queue.get() # Put the error back for the signal handler self.error_queue.put(child_trace) # Invoke the signal handler os.kill(os.getpid(), signal.SIGUSR1) def signal_handler(self, _sig_num, _stack_frame): """Signal handler.""" # Kill children processes for pid in self.children_pids: os.kill(pid, signal.SIGINT) # Propagate the error from the child process raise ChildException(self.error_queue.get()) ================================================ FILE: pycls/utils/io.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """IO utilities (adapted from Detectron)""" import logging import os import re import sys from urllib import request as urlrequest logger = logging.getLogger(__name__) _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls" def cache_url(url_or_file, cache_dir): """Download the file specified by the URL to the cache_dir and return the path to the cached file. If the argument is not a URL, simply return it as is. """ is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None if not is_url: return url_or_file url = url_or_file assert url.startswith(_PYCLS_BASE_URL), ( "pycls only automatically caches URLs in the pycls S3 bucket: {}" ).format(_PYCLS_BASE_URL) cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir) if os.path.exists(cache_file_path): return cache_file_path cache_file_dir = os.path.dirname(cache_file_path) if not os.path.exists(cache_file_dir): os.makedirs(cache_file_dir) logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) download_url(url, cache_file_path) return cache_file_path def _progress_bar(count, total): """Report download progress. Credit: https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113 """ bar_len = 60 filled_len = int(round(bar_len * count / float(total))) percents = round(100.0 * count / float(total), 1) bar = "=" * filled_len + "-" * (bar_len - filled_len) sys.stdout.write( " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024) ) sys.stdout.flush() if count >= total: sys.stdout.write("\n") def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): """Download url and write it to dst_file_path. Credit: https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook """ req = urlrequest.Request(url) response = urlrequest.urlopen(req) total_size = response.info().get("Content-Length").strip() total_size = int(total_size) bytes_so_far = 0 with open(dst_file_path, "wb") as f: while 1: chunk = response.read(chunk_size) bytes_so_far += len(chunk) if not chunk: break if progress_hook: progress_hook(bytes_so_far, total_size) f.write(chunk) return bytes_so_far ================================================ FILE: pycls/utils/logging.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Logging.""" import builtins import decimal import logging import os import sys import pycls.utils.distributed as du import simplejson from pycls.core.config import cfg # Show filename and line number in logs _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s" # Log file name (for cfg.LOG_DEST = 'file') _LOG_FILE = "stdout.log" # Printed json stats lines will be tagged w/ this _TAG = "json_stats: " def _suppress_print(): """Suppresses printing from the current process.""" def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False): pass builtins.print = ignore def setup_logging(): """Sets up the logging.""" # Enable logging only for the master process if du.is_master_proc(): # Clear the root logger to prevent any existing logging config # (e.g. set by another module) from messing with our setup logging.root.handlers = [] # Construct logging configuration logging_config = {"level": logging.INFO, "format": _FORMAT} # Log either to stdout or to a file if cfg.LOG_DEST == "stdout": logging_config["stream"] = sys.stdout else: logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE) # Configure logging logging.basicConfig(**logging_config) else: _suppress_print() def get_logger(name): """Retrieves the logger.""" return logging.getLogger(name) def log_json_stats(stats): """Logs json stats.""" # Decimal + string workaround for having fixed len float vals in logs stats = { k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v for k, v in stats.items() } json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) logger = get_logger(__name__) logger.info("{:s}{:s}".format(_TAG, json_stats)) def load_json_stats(log_file): """Loads json_stats from a single log file.""" with open(log_file, "r") as f: lines = f.readlines() json_lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l] json_stats = [simplejson.loads(l) for l in json_lines] return json_stats def parse_json_stats(log, row_type, key): """Extract values corresponding to row_type/key out of log.""" vals = [row[key] for row in log if row["_type"] == row_type and key in row] if key == "iter" or key == "epoch": vals = [int(val.split("/")[0]) for val in vals] return vals def get_log_files(log_dir, name_filter=""): """Get all log files in directory containing subdirs of trained models.""" names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n] files = [os.path.join(log_dir, n, _LOG_FILE) for n in names] f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)] files, names = zip(*f_n_ps) return files, names ================================================ FILE: pycls/utils/lr_policy.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Learning rate policies.""" import numpy as np from pycls.core.config import cfg def lr_fun_steps(cur_epoch): """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps').""" ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1] return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind) def lr_fun_exp(cur_epoch): """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp').""" return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch) def lr_fun_cos(cur_epoch): """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos').""" base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch)) def get_lr_fun(): """Retrieves the specified lr policy function""" lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY if lr_fun not in globals(): raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY) return globals()[lr_fun] def get_epoch_lr(cur_epoch): """Retrieves the lr for the given epoch according to the policy.""" lr = get_lr_fun()(cur_epoch) # Linear warmup if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS: alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha lr *= warmup_factor return lr ================================================ FILE: pycls/utils/meters.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Meters.""" import datetime from collections import deque import numpy as np import pycls.utils.logging as lu import pycls.utils.metrics as metrics from pycls.core.config import cfg from pycls.utils.timer import Timer def eta_str(eta_td): """Converts an eta timedelta to a fixed-width string format.""" days = eta_td.days hrs, rem = divmod(eta_td.seconds, 3600) mins, secs = divmod(rem, 60) return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs) class ScalarMeter(object): """Measures a scalar value (adapted from Detectron).""" def __init__(self, window_size): self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 def reset(self): self.deque.clear() self.total = 0.0 self.count = 0 def add_value(self, value): self.deque.append(value) self.count += 1 self.total += value def get_win_median(self): return np.median(self.deque) def get_win_avg(self): return np.mean(self.deque) def get_global_avg(self): return self.total / self.count class TrainMeter(object): """Measures training stats.""" def __init__(self, epoch_iters): self.epoch_iters = epoch_iters self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters self.iter_timer = Timer() self.loss = ScalarMeter(cfg.LOG_PERIOD) self.loss_total = 0.0 self.lr = None # Current minibatch errors (smoothed over a window) self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) # Number of misclassified examples self.num_top1_mis = 0 self.num_top5_mis = 0 self.num_samples = 0 def reset(self, timer=False): if timer: self.iter_timer.reset() self.loss.reset() self.loss_total = 0.0 self.lr = None self.mb_top1_err.reset() self.mb_top5_err.reset() self.num_top1_mis = 0 self.num_top5_mis = 0 self.num_samples = 0 def iter_tic(self): self.iter_timer.tic() def iter_toc(self): self.iter_timer.toc() def update_stats(self, top1_err, top5_err, loss, lr, mb_size): # Current minibatch stats self.mb_top1_err.add_value(top1_err) self.mb_top5_err.add_value(top5_err) self.loss.add_value(loss) self.lr = lr # Aggregate stats self.num_top1_mis += top1_err * mb_size self.num_top5_mis += top5_err * mb_size self.loss_total += loss * mb_size self.num_samples += mb_size def get_iter_stats(self, cur_epoch, cur_iter): eta_sec = self.iter_timer.average_time * ( self.max_iter - (cur_epoch * self.epoch_iters + cur_iter + 1) ) eta_td = datetime.timedelta(seconds=int(eta_sec)) mem_usage = metrics.gpu_mem_usage() stats = { "_type": "train_iter", "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), "time_avg": self.iter_timer.average_time, "time_diff": self.iter_timer.diff, "eta": eta_str(eta_td), "top1_err": self.mb_top1_err.get_win_median(), "top5_err": self.mb_top5_err.get_win_median(), "loss": self.loss.get_win_median(), "lr": self.lr, "mem": int(np.ceil(mem_usage)), } return stats def log_iter_stats(self, cur_epoch, cur_iter): if (cur_iter + 1) % cfg.LOG_PERIOD != 0: return stats = self.get_iter_stats(cur_epoch, cur_iter) lu.log_json_stats(stats) def get_epoch_stats(self, cur_epoch): eta_sec = self.iter_timer.average_time * ( self.max_iter - (cur_epoch + 1) * self.epoch_iters ) eta_td = datetime.timedelta(seconds=int(eta_sec)) mem_usage = metrics.gpu_mem_usage() top1_err = self.num_top1_mis / self.num_samples top5_err = self.num_top5_mis / self.num_samples avg_loss = self.loss_total / self.num_samples stats = { "_type": "train_epoch", "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), "time_avg": self.iter_timer.average_time, "eta": eta_str(eta_td), "top1_err": top1_err, "top5_err": top5_err, "loss": avg_loss, "lr": self.lr, "mem": int(np.ceil(mem_usage)), } return stats def log_epoch_stats(self, cur_epoch): stats = self.get_epoch_stats(cur_epoch) lu.log_json_stats(stats) class TestMeter(object): """Measures testing stats.""" def __init__(self, max_iter): self.max_iter = max_iter self.iter_timer = Timer() # Current minibatch errors (smoothed over a window) self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) # Min errors (over the full test set) self.min_top1_err = 100.0 self.min_top5_err = 100.0 # Number of misclassified examples self.num_top1_mis = 0 self.num_top5_mis = 0 self.num_samples = 0 def reset(self, min_errs=False): if min_errs: self.min_top1_err = 100.0 self.min_top5_err = 100.0 self.iter_timer.reset() self.mb_top1_err.reset() self.mb_top5_err.reset() self.num_top1_mis = 0 self.num_top5_mis = 0 self.num_samples = 0 def iter_tic(self): self.iter_timer.tic() def iter_toc(self): self.iter_timer.toc() def update_stats(self, top1_err, top5_err, mb_size): self.mb_top1_err.add_value(top1_err) self.mb_top5_err.add_value(top5_err) self.num_top1_mis += top1_err * mb_size self.num_top5_mis += top5_err * mb_size self.num_samples += mb_size def get_iter_stats(self, cur_epoch, cur_iter): mem_usage = metrics.gpu_mem_usage() iter_stats = { "_type": "test_iter", "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), "iter": "{}/{}".format(cur_iter + 1, self.max_iter), "time_avg": self.iter_timer.average_time, "time_diff": self.iter_timer.diff, "top1_err": self.mb_top1_err.get_win_median(), "top5_err": self.mb_top5_err.get_win_median(), "mem": int(np.ceil(mem_usage)), } return iter_stats def log_iter_stats(self, cur_epoch, cur_iter): if (cur_iter + 1) % cfg.LOG_PERIOD != 0: return stats = self.get_iter_stats(cur_epoch, cur_iter) lu.log_json_stats(stats) def get_epoch_stats(self, cur_epoch): top1_err = self.num_top1_mis / self.num_samples top5_err = self.num_top5_mis / self.num_samples self.min_top1_err = min(self.min_top1_err, top1_err) self.min_top5_err = min(self.min_top5_err, top5_err) mem_usage = metrics.gpu_mem_usage() stats = { "_type": "test_epoch", "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), "time_avg": self.iter_timer.average_time, "top1_err": top1_err, "top5_err": top5_err, "min_top1_err": self.min_top1_err, "min_top5_err": self.min_top5_err, "mem": int(np.ceil(mem_usage)), } return stats def log_epoch_stats(self, cur_epoch): stats = self.get_epoch_stats(cur_epoch) lu.log_json_stats(stats) ================================================ FILE: pycls/utils/metrics.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Functions for computing metrics.""" import numpy as np import torch import torch.nn as nn from pycls.core.config import cfg # Number of bytes in a megabyte _B_IN_MB = 1024 * 1024 def topks_correct(preds, labels, ks): """Computes the number of top-k correct predictions for each k.""" assert preds.size(0) == labels.size( 0 ), "Batch dim of predictions and labels must match" # Find the top max_k predictions for each sample _top_max_k_vals, top_max_k_inds = torch.topk( preds, max(ks), dim=1, largest=True, sorted=True ) # (batch_size, max_k) -> (max_k, batch_size) top_max_k_inds = top_max_k_inds.t() # (batch_size, ) -> (max_k, batch_size) rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) # (i, j) = 1 if top i-th prediction for the j-th sample is correct top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) # Compute the number of topk correct predictions for each k topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks] return topks_correct def topk_errors(preds, labels, ks): """Computes the top-k error for each k.""" num_topks_correct = topks_correct(preds, labels, ks) return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] def topk_accuracies(preds, labels, ks): """Computes the top-k accuracy for each k.""" num_topks_correct = topks_correct(preds, labels, ks) return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] def params_count(model): """Computes the number of parameters.""" return np.sum([p.numel() for p in model.parameters()]).item() def flops_count(model): """Computes the number of flops statically.""" h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE count = 0 for n, m in model.named_modules(): if isinstance(m, nn.Conv2d): if "se." in n: count += m.in_channels * m.out_channels + m.bias.numel() continue h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1 w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1 count += np.prod([m.weight.numel(), h_out, w_out]) if ".proj" not in n: h, w = h_out, w_out elif isinstance(m, nn.MaxPool2d): h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1 w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1 elif isinstance(m, nn.Linear): count += m.in_features * m.out_features + m.bias.numel() return count.item() def acts_count(model): """Computes the number of activations statically.""" h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE count = 0 for n, m in model.named_modules(): if isinstance(m, nn.Conv2d): if "se." in n: count += m.out_channels continue h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1 w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1 count += np.prod([m.out_channels, h_out, w_out]) if ".proj" not in n: h, w = h_out, w_out elif isinstance(m, nn.MaxPool2d): h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1 w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1 elif isinstance(m, nn.Linear): count += m.out_features return count.item() def gpu_mem_usage(): """Computes the GPU memory usage for the current device (MB).""" mem_usage_bytes = torch.cuda.max_memory_allocated() return mem_usage_bytes / _B_IN_MB ================================================ FILE: pycls/utils/multiprocessing.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Multiprocessing helpers.""" import multiprocessing as mp import traceback import pycls.utils.distributed as du from pycls.utils.error_handler import ErrorHandler def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs): """Runs a function from a child process.""" try: # Initialize the process group du.init_process_group(proc_rank, world_size) # Run the function fun(*fun_args, **fun_kwargs) except KeyboardInterrupt: # Killed by the parent process pass except Exception: # Propagate exception to the parent process error_queue.put(traceback.format_exc()) finally: # Destroy the process group du.destroy_process_group() def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None): """Runs a function in a multi-proc setting.""" if fun_kwargs is None: fun_kwargs = {} # Handle errors from training subprocesses error_queue = mp.SimpleQueue() error_handler = ErrorHandler(error_queue) # Run each training subprocess ps = [] for i in range(num_proc): p_i = mp.Process( target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs) ) ps.append(p_i) p_i.start() error_handler.add_child(p_i.pid) # Wait for each subprocess to finish for p in ps: p.join() ================================================ FILE: pycls/utils/net.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Functions for manipulating networks.""" import itertools import math import torch import torch.nn as nn from pycls.core.config import cfg def init_weights(m): """Performs ResNet-style weight initialization.""" if isinstance(m, nn.Conv2d): # Note that there is no bias due to BN fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) elif isinstance(m, nn.BatchNorm2d): zero_init_gamma = ( hasattr(m, "final_bn") and m.final_bn and cfg.BN.ZERO_INIT_FINAL_GAMMA ) m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) m.bias.data.zero_() elif isinstance(m, nn.Linear): m.weight.data.normal_(mean=0.0, std=0.01) m.bias.data.zero_() @torch.no_grad() def compute_precise_bn_stats(model, loader): """Computes precise BN stats on training data.""" # Compute the number of minibatches to use num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader)) # Retrieve the BN layers bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] # Initialize stats storage mus = [torch.zeros_like(bn.running_mean) for bn in bns] sqs = [torch.zeros_like(bn.running_var) for bn in bns] # Remember momentum values moms = [bn.momentum for bn in bns] # Disable momentum for bn in bns: bn.momentum = 1.0 # Accumulate the stats across the data samples for inputs, _labels in itertools.islice(loader, num_iter): model(inputs.cuda()) # Accumulate the stats for each BN layer for i, bn in enumerate(bns): m, v = bn.running_mean, bn.running_var sqs[i] += (v + m * m) / num_iter mus[i] += m / num_iter # Set the stats and restore momentum values for i, bn in enumerate(bns): bn.running_var = sqs[i] - mus[i] * mus[i] bn.running_mean = mus[i] bn.momentum = moms[i] def reset_bn_stats(model): """Resets running BN stats.""" for m in model.modules(): if isinstance(m, torch.nn.BatchNorm2d): m.reset_running_stats() def drop_connect(x, drop_ratio): """Drop connect (adapted from DARTS).""" keep_ratio = 1.0 - drop_ratio mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) mask.bernoulli_(keep_ratio) x.div_(keep_ratio) x.mul_(mask) return x def get_flat_weights(model): """Gets all model weights as a single flat vector.""" return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0) def set_flat_weights(model, flat_weights): """Sets all model weights from a single flat vector.""" k = 0 for p in model.parameters(): n = p.data.numel() p.data.copy_(flat_weights[k : (k + n)].view_as(p.data)) k += n assert k == flat_weights.numel() ================================================ FILE: pycls/utils/plotting.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Plotting functions.""" import colorlover as cl import matplotlib.pyplot as plt import plotly.graph_objs as go import plotly.offline as offline import pycls.utils.logging as lu def get_plot_colors(max_colors, color_format="pyplot"): """Generate colors for plotting.""" colors = cl.scales["11"]["qual"]["Paired"] if max_colors > len(colors): colors = cl.to_rgb(cl.interp(colors, max_colors)) if color_format == "pyplot": return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)] return colors def prepare_plot_data(log_files, names, key="top1_err"): """Load logs and extract data for plotting error curves.""" plot_data = [] for file, name in zip(log_files, names): d, log = {}, lu.load_json_stats(file) for phase in ["train", "test"]: x = lu.parse_json_stats(log, phase + "_epoch", "epoch") y = lu.parse_json_stats(log, phase + "_epoch", key) d["x_" + phase], d["y_" + phase] = x, y d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name plot_data.append(d) assert len(plot_data) > 0, "No data to plot" return plot_data def plot_error_curves_plotly(log_files, names, filename, key="top1_err"): """Plot error curves using plotly and save to file.""" plot_data = prepare_plot_data(log_files, names, key) colors = get_plot_colors(len(plot_data), "plotly") # Prepare data for plots (3 sets, train duplicated w and w/o legend) data = [] for i, d in enumerate(plot_data): s = str(i) line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5} line_test = {"color": colors[i], "dash": "solid", "width": 1.5} data.append( go.Scatter( x=d["x_train"], y=d["y_train"], mode="lines", name=d["train_label"], line=line_train, legendgroup=s, visible=True, showlegend=False, ) ) data.append( go.Scatter( x=d["x_test"], y=d["y_test"], mode="lines", name=d["test_label"], line=line_test, legendgroup=s, visible=True, showlegend=True, ) ) data.append( go.Scatter( x=d["x_train"], y=d["y_train"], mode="lines", name=d["train_label"], line=line_train, legendgroup=s, visible=False, showlegend=True, ) ) # Prepare layout w ability to toggle 'all', 'train', 'test' titlefont = {"size": 18, "color": "#7f7f7f"} vis = [[True, True, False], [False, False, True], [False, True, False]] buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis]) buttons = [{"label": l, "args": v, "method": "update"} for l, v in buttons] layout = go.Layout( title=key + " vs. epoch
[dash=train, solid=test]", xaxis={"title": "epoch", "titlefont": titlefont}, yaxis={"title": key, "titlefont": titlefont}, showlegend=True, hoverlabel={"namelength": -1}, updatemenus=[ { "buttons": buttons, "direction": "down", "showactive": True, "x": 1.02, "xanchor": "left", "y": 1.08, "yanchor": "top", } ], ) # Create plotly plot offline.plot({"data": data, "layout": layout}, filename=filename) def plot_error_curves_pyplot(log_files, names, filename=None, key="top1_err"): """Plot error curves using matplotlib.pyplot and save to file.""" plot_data = prepare_plot_data(log_files, names, key) colors = get_plot_colors(len(names)) for ind, d in enumerate(plot_data): c, lbl = colors[ind], d["test_label"] plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8) plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl) plt.title(key + " vs. epoch\n[dash=train, solid=test]", fontsize=14) plt.xlabel("epoch", fontsize=14) plt.ylabel(key, fontsize=14) plt.grid(alpha=0.4) plt.legend() if filename: plt.savefig(filename) plt.clf() else: plt.show() ================================================ FILE: pycls/utils/timer.py ================================================ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Timer.""" import time class Timer(object): """A simple timer (adapted from Detectron).""" def __init__(self): self.reset() def tic(self): # using time.time instead of time.clock because time time.clock # does not normalize for multithreading self.start_time = time.time() def toc(self): self.diff = time.time() - self.start_time self.total_time += self.diff self.calls += 1 self.average_time = self.total_time / self.calls def reset(self): self.total_time = 0.0 self.calls = 0 self.start_time = 0.0 self.diff = 0.0 self.average_time = 0.0 ================================================ FILE: simplejson/__init__.py ================================================ r"""JSON (JavaScript Object Notation) is a subset of JavaScript syntax (ECMA-262 3rd edition) used as a lightweight data interchange format. :mod:`simplejson` exposes an API familiar to users of the standard library :mod:`marshal` and :mod:`pickle` modules. It is the externally maintained version of the :mod:`json` library contained in Python 2.6, but maintains compatibility back to Python 2.5 and (currently) has significant performance advantages, even without using the optional C extension for speedups. Encoding basic Python object hierarchies:: >>> import simplejson as json >>> json.dumps(['foo', {'bar': ('baz', None, 1.0, 2)}]) '["foo", {"bar": ["baz", null, 1.0, 2]}]' >>> print(json.dumps("\"foo\bar")) "\"foo\bar" >>> print(json.dumps(u'\u1234')) "\u1234" >>> print(json.dumps('\\')) "\\" >>> print(json.dumps({"c": 0, "b": 0, "a": 0}, sort_keys=True)) {"a": 0, "b": 0, "c": 0} >>> from simplejson.compat import StringIO >>> io = StringIO() >>> json.dump(['streaming API'], io) >>> io.getvalue() '["streaming API"]' Compact encoding:: >>> import simplejson as json >>> obj = [1,2,3,{'4': 5, '6': 7}] >>> json.dumps(obj, separators=(',',':'), sort_keys=True) '[1,2,3,{"4":5,"6":7}]' Pretty printing:: >>> import simplejson as json >>> print(json.dumps({'4': 5, '6': 7}, sort_keys=True, indent=' ')) { "4": 5, "6": 7 } Decoding JSON:: >>> import simplejson as json >>> obj = [u'foo', {u'bar': [u'baz', None, 1.0, 2]}] >>> json.loads('["foo", {"bar":["baz", null, 1.0, 2]}]') == obj True >>> json.loads('"\\"foo\\bar"') == u'"foo\x08ar' True >>> from simplejson.compat import StringIO >>> io = StringIO('["streaming API"]') >>> json.load(io)[0] == 'streaming API' True Specializing JSON object decoding:: >>> import simplejson as json >>> def as_complex(dct): ... if '__complex__' in dct: ... return complex(dct['real'], dct['imag']) ... return dct ... >>> json.loads('{"__complex__": true, "real": 1, "imag": 2}', ... object_hook=as_complex) (1+2j) >>> from decimal import Decimal >>> json.loads('1.1', parse_float=Decimal) == Decimal('1.1') True Specializing JSON object encoding:: >>> import simplejson as json >>> def encode_complex(obj): ... if isinstance(obj, complex): ... return [obj.real, obj.imag] ... raise TypeError('Object of type %s is not JSON serializable' % ... obj.__class__.__name__) ... >>> json.dumps(2 + 1j, default=encode_complex) '[2.0, 1.0]' >>> json.JSONEncoder(default=encode_complex).encode(2 + 1j) '[2.0, 1.0]' >>> ''.join(json.JSONEncoder(default=encode_complex).iterencode(2 + 1j)) '[2.0, 1.0]' Using simplejson.tool from the shell to validate and pretty-print:: $ echo '{"json":"obj"}' | python -m simplejson.tool { "json": "obj" } $ echo '{ 1.2:3.4}' | python -m simplejson.tool Expecting property name: line 1 column 3 (char 2) Parsing multiple documents serialized as JSON lines (newline-delimited JSON):: >>> import simplejson as json >>> def loads_lines(docs): ... for doc in docs.splitlines(): ... yield json.loads(doc) ... >>> sum(doc["count"] for doc in loads_lines('{"count":1}\n{"count":2}\n{"count":3}\n')) 6 Serializing multiple objects to JSON lines (newline-delimited JSON):: >>> import simplejson as json >>> def dumps_lines(objs): ... for obj in objs: ... yield json.dumps(obj, separators=(',',':')) + '\n' ... >>> ''.join(dumps_lines([{'count': 1}, {'count': 2}, {'count': 3}])) '{"count":1}\n{"count":2}\n{"count":3}\n' """ from __future__ import absolute_import __version__ = '3.17.2' __all__ = [ 'dump', 'dumps', 'load', 'loads', 'JSONDecoder', 'JSONDecodeError', 'JSONEncoder', 'OrderedDict', 'simple_first', 'RawJSON' ] __author__ = 'Bob Ippolito ' from decimal import Decimal from .errors import JSONDecodeError from .raw_json import RawJSON from .decoder import JSONDecoder from .encoder import JSONEncoder, JSONEncoderForHTML def _import_OrderedDict(): import collections try: return collections.OrderedDict except AttributeError: from . import ordered_dict return ordered_dict.OrderedDict OrderedDict = _import_OrderedDict() def _import_c_make_encoder(): try: from ._speedups import make_encoder return make_encoder except ImportError: return None _default_encoder = JSONEncoder( skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, encoding='utf-8', default=None, use_decimal=True, namedtuple_as_object=True, tuple_as_array=True, iterable_as_array=False, bigint_as_string=False, item_sort_key=None, for_json=False, ignore_nan=False, int_as_string_bitcount=None, ) def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, cls=None, indent=None, separators=None, encoding='utf-8', default=None, use_decimal=True, namedtuple_as_object=True, tuple_as_array=True, bigint_as_string=False, sort_keys=False, item_sort_key=None, for_json=False, ignore_nan=False, int_as_string_bitcount=None, iterable_as_array=False, **kw): """Serialize ``obj`` as a JSON formatted stream to ``fp`` (a ``.write()``-supporting file-like object). If *skipkeys* is true then ``dict`` keys that are not basic types (``str``, ``int``, ``long``, ``float``, ``bool``, ``None``) will be skipped instead of raising a ``TypeError``. If *ensure_ascii* is false (default: ``True``), then the output may contain non-ASCII characters, so long as they do not need to be escaped by JSON. When it is true, all non-ASCII characters are escaped. If *allow_nan* is false, then it will be a ``ValueError`` to serialize out of range ``float`` values (``nan``, ``inf``, ``-inf``) in strict compliance of the original JSON specification, instead of using the JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``). See *ignore_nan* for ECMA-262 compliant behavior. If *indent* is a string, then JSON array elements and object members will be pretty-printed with a newline followed by that string repeated for each level of nesting. ``None`` (the default) selects the most compact representation without any newlines. If specified, *separators* should be an ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most compact JSON representation, you should specify ``(',', ':')`` to eliminate whitespace. *encoding* is the character encoding for str instances, default is UTF-8. *default(obj)* is a function that should return a serializable version of obj or raise ``TypeError``. The default simply raises ``TypeError``. If *use_decimal* is true (default: ``True``) then decimal.Decimal will be natively serialized to JSON with full precision. If *namedtuple_as_object* is true (default: ``True``), :class:`tuple` subclasses with ``_asdict()`` methods will be encoded as JSON objects. If *tuple_as_array* is true (default: ``True``), :class:`tuple` (and subclasses) will be encoded as JSON arrays. If *iterable_as_array* is true (default: ``False``), any object not in the above table that implements ``__iter__()`` will be encoded as a JSON array. If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher or lower than -2**53 will be encoded as strings. This is to avoid the rounding that happens in Javascript otherwise. Note that this is still a lossy operation that will not round-trip correctly and should be used sparingly. If *int_as_string_bitcount* is a positive number (n), then int of size greater than or equal to 2**n or lower than or equal to -2**n will be encoded as strings. If specified, *item_sort_key* is a callable used to sort the items in each dictionary. This is useful if you want to sort items other than in alphabetical order by key. This option takes precedence over *sort_keys*. If *sort_keys* is true (default: ``False``), the output of dictionaries will be sorted by item. If *for_json* is true (default: ``False``), objects with a ``for_json()`` method will use the return value of that method for encoding as JSON instead of the object. If *ignore_nan* is true (default: ``False``), then out of range :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as ``null`` in compliance with the ECMA-262 specification. If true, this will override *allow_nan*. To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the ``.default()`` method to serialize additional types), specify it with the ``cls`` kwarg. NOTE: You should use *default* or *for_json* instead of subclassing whenever possible. """ # cached encoder if (not skipkeys and ensure_ascii and check_circular and allow_nan and cls is None and indent is None and separators is None and encoding == 'utf-8' and default is None and use_decimal and namedtuple_as_object and tuple_as_array and not iterable_as_array and not bigint_as_string and not sort_keys and not item_sort_key and not for_json and not ignore_nan and int_as_string_bitcount is None and not kw ): iterable = _default_encoder.iterencode(obj) else: if cls is None: cls = JSONEncoder iterable = cls(skipkeys=skipkeys, ensure_ascii=ensure_ascii, check_circular=check_circular, allow_nan=allow_nan, indent=indent, separators=separators, encoding=encoding, default=default, use_decimal=use_decimal, namedtuple_as_object=namedtuple_as_object, tuple_as_array=tuple_as_array, iterable_as_array=iterable_as_array, bigint_as_string=bigint_as_string, sort_keys=sort_keys, item_sort_key=item_sort_key, for_json=for_json, ignore_nan=ignore_nan, int_as_string_bitcount=int_as_string_bitcount, **kw).iterencode(obj) # could accelerate with writelines in some versions of Python, at # a debuggability cost for chunk in iterable: fp.write(chunk) def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, cls=None, indent=None, separators=None, encoding='utf-8', default=None, use_decimal=True, namedtuple_as_object=True, tuple_as_array=True, bigint_as_string=False, sort_keys=False, item_sort_key=None, for_json=False, ignore_nan=False, int_as_string_bitcount=None, iterable_as_array=False, **kw): """Serialize ``obj`` to a JSON formatted ``str``. If ``skipkeys`` is false then ``dict`` keys that are not basic types (``str``, ``int``, ``long``, ``float``, ``bool``, ``None``) will be skipped instead of raising a ``TypeError``. If *ensure_ascii* is false (default: ``True``), then the output may contain non-ASCII characters, so long as they do not need to be escaped by JSON. When it is true, all non-ASCII characters are escaped. If ``check_circular`` is false, then the circular reference check for container types will be skipped and a circular reference will result in an ``OverflowError`` (or worse). If ``allow_nan`` is false, then it will be a ``ValueError`` to serialize out of range ``float`` values (``nan``, ``inf``, ``-inf``) in strict compliance of the JSON specification, instead of using the JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``). If ``indent`` is a string, then JSON array elements and object members will be pretty-printed with a newline followed by that string repeated for each level of nesting. ``None`` (the default) selects the most compact representation without any newlines. For backwards compatibility with versions of simplejson earlier than 2.1.0, an integer is also accepted and is converted to a string with that many spaces. If specified, ``separators`` should be an ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most compact JSON representation, you should specify ``(',', ':')`` to eliminate whitespace. ``encoding`` is the character encoding for bytes instances, default is UTF-8. ``default(obj)`` is a function that should return a serializable version of obj or raise TypeError. The default simply raises TypeError. If *use_decimal* is true (default: ``True``) then decimal.Decimal will be natively serialized to JSON with full precision. If *namedtuple_as_object* is true (default: ``True``), :class:`tuple` subclasses with ``_asdict()`` methods will be encoded as JSON objects. If *tuple_as_array* is true (default: ``True``), :class:`tuple` (and subclasses) will be encoded as JSON arrays. If *iterable_as_array* is true (default: ``False``), any object not in the above table that implements ``__iter__()`` will be encoded as a JSON array. If *bigint_as_string* is true (not the default), ints 2**53 and higher or lower than -2**53 will be encoded as strings. This is to avoid the rounding that happens in Javascript otherwise. If *int_as_string_bitcount* is a positive number (n), then int of size greater than or equal to 2**n or lower than or equal to -2**n will be encoded as strings. If specified, *item_sort_key* is a callable used to sort the items in each dictionary. This is useful if you want to sort items other than in alphabetical order by key. This option takes precendence over *sort_keys*. If *sort_keys* is true (default: ``False``), the output of dictionaries will be sorted by item. If *for_json* is true (default: ``False``), objects with a ``for_json()`` method will use the return value of that method for encoding as JSON instead of the object. If *ignore_nan* is true (default: ``False``), then out of range :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as ``null`` in compliance with the ECMA-262 specification. If true, this will override *allow_nan*. To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the ``.default()`` method to serialize additional types), specify it with the ``cls`` kwarg. NOTE: You should use *default* instead of subclassing whenever possible. """ # cached encoder if (not skipkeys and ensure_ascii and check_circular and allow_nan and cls is None and indent is None and separators is None and encoding == 'utf-8' and default is None and use_decimal and namedtuple_as_object and tuple_as_array and not iterable_as_array and not bigint_as_string and not sort_keys and not item_sort_key and not for_json and not ignore_nan and int_as_string_bitcount is None and not kw ): return _default_encoder.encode(obj) if cls is None: cls = JSONEncoder return cls( skipkeys=skipkeys, ensure_ascii=ensure_ascii, check_circular=check_circular, allow_nan=allow_nan, indent=indent, separators=separators, encoding=encoding, default=default, use_decimal=use_decimal, namedtuple_as_object=namedtuple_as_object, tuple_as_array=tuple_as_array, iterable_as_array=iterable_as_array, bigint_as_string=bigint_as_string, sort_keys=sort_keys, item_sort_key=item_sort_key, for_json=for_json, ignore_nan=ignore_nan, int_as_string_bitcount=int_as_string_bitcount, **kw).encode(obj) _default_decoder = JSONDecoder(encoding=None, object_hook=None, object_pairs_hook=None) def load(fp, encoding=None, cls=None, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, object_pairs_hook=None, use_decimal=False, namedtuple_as_object=True, tuple_as_array=True, **kw): """Deserialize ``fp`` (a ``.read()``-supporting file-like object containing a JSON document as `str` or `bytes`) to a Python object. *encoding* determines the encoding used to interpret any `bytes` objects decoded by this instance (``'utf-8'`` by default). It has no effect when decoding `str` objects. *object_hook*, if specified, will be called with the result of every JSON object decoded and its return value will be used in place of the given :class:`dict`. This can be used to provide custom deserializations (e.g. to support JSON-RPC class hinting). *object_pairs_hook* is an optional function that will be called with the result of any object literal decode with an ordered list of pairs. The return value of *object_pairs_hook* will be used instead of the :class:`dict`. This feature can be used to implement custom decoders that rely on the order that the key and value pairs are decoded (for example, :func:`collections.OrderedDict` will remember the order of insertion). If *object_hook* is also defined, the *object_pairs_hook* takes priority. *parse_float*, if specified, will be called with the string of every JSON float to be decoded. By default, this is equivalent to ``float(num_str)``. This can be used to use another datatype or parser for JSON floats (e.g. :class:`decimal.Decimal`). *parse_int*, if specified, will be called with the string of every JSON int to be decoded. By default, this is equivalent to ``int(num_str)``. This can be used to use another datatype or parser for JSON integers (e.g. :class:`float`). *parse_constant*, if specified, will be called with one of the following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This can be used to raise an exception if invalid JSON numbers are encountered. If *use_decimal* is true (default: ``False``) then it implies parse_float=decimal.Decimal for parity with ``dump``. To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead of subclassing whenever possible. """ return loads(fp.read(), encoding=encoding, cls=cls, object_hook=object_hook, parse_float=parse_float, parse_int=parse_int, parse_constant=parse_constant, object_pairs_hook=object_pairs_hook, use_decimal=use_decimal, **kw) def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, object_pairs_hook=None, use_decimal=False, **kw): """Deserialize ``s`` (a ``str`` or ``unicode`` instance containing a JSON document) to a Python object. *encoding* determines the encoding used to interpret any :class:`bytes` objects decoded by this instance (``'utf-8'`` by default). It has no effect when decoding :class:`unicode` objects. *object_hook*, if specified, will be called with the result of every JSON object decoded and its return value will be used in place of the given :class:`dict`. This can be used to provide custom deserializations (e.g. to support JSON-RPC class hinting). *object_pairs_hook* is an optional function that will be called with the result of any object literal decode with an ordered list of pairs. The return value of *object_pairs_hook* will be used instead of the :class:`dict`. This feature can be used to implement custom decoders that rely on the order that the key and value pairs are decoded (for example, :func:`collections.OrderedDict` will remember the order of insertion). If *object_hook* is also defined, the *object_pairs_hook* takes priority. *parse_float*, if specified, will be called with the string of every JSON float to be decoded. By default, this is equivalent to ``float(num_str)``. This can be used to use another datatype or parser for JSON floats (e.g. :class:`decimal.Decimal`). *parse_int*, if specified, will be called with the string of every JSON int to be decoded. By default, this is equivalent to ``int(num_str)``. This can be used to use another datatype or parser for JSON integers (e.g. :class:`float`). *parse_constant*, if specified, will be called with one of the following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This can be used to raise an exception if invalid JSON numbers are encountered. If *use_decimal* is true (default: ``False``) then it implies parse_float=decimal.Decimal for parity with ``dump``. To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead of subclassing whenever possible. """ if (cls is None and encoding is None and object_hook is None and parse_int is None and parse_float is None and parse_constant is None and object_pairs_hook is None and not use_decimal and not kw): return _default_decoder.decode(s) if cls is None: cls = JSONDecoder if object_hook is not None: kw['object_hook'] = object_hook if object_pairs_hook is not None: kw['object_pairs_hook'] = object_pairs_hook if parse_float is not None: kw['parse_float'] = parse_float if parse_int is not None: kw['parse_int'] = parse_int if parse_constant is not None: kw['parse_constant'] = parse_constant if use_decimal: if parse_float is not None: raise TypeError("use_decimal=True implies parse_float=Decimal") kw['parse_float'] = Decimal return cls(encoding=encoding, **kw).decode(s) def _toggle_speedups(enabled): from . import decoder as dec from . import encoder as enc from . import scanner as scan c_make_encoder = _import_c_make_encoder() if enabled: dec.scanstring = dec.c_scanstring or dec.py_scanstring enc.c_make_encoder = c_make_encoder enc.encode_basestring_ascii = (enc.c_encode_basestring_ascii or enc.py_encode_basestring_ascii) scan.make_scanner = scan.c_make_scanner or scan.py_make_scanner else: dec.scanstring = dec.py_scanstring enc.c_make_encoder = None enc.encode_basestring_ascii = enc.py_encode_basestring_ascii scan.make_scanner = scan.py_make_scanner dec.make_scanner = scan.make_scanner global _default_decoder _default_decoder = JSONDecoder( encoding=None, object_hook=None, object_pairs_hook=None, ) global _default_encoder _default_encoder = JSONEncoder( skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, encoding='utf-8', default=None, ) def simple_first(kv): """Helper function to pass to item_sort_key to sort simple elements to the top, then container elements. """ return (isinstance(kv[1], (list, dict, tuple)), kv[0]) ================================================ FILE: simplejson/_speedups.c ================================================ /* -*- mode: C; c-file-style: "python"; c-basic-offset: 4 -*- */ #include "Python.h" #include "structmember.h" #if PY_MAJOR_VERSION >= 3 #define PyInt_FromSsize_t PyLong_FromSsize_t #define PyInt_AsSsize_t PyLong_AsSsize_t #define PyInt_Check(obj) 0 #define PyInt_CheckExact(obj) 0 #define JSON_UNICHR Py_UCS4 #define JSON_InternFromString PyUnicode_InternFromString #define PyString_GET_SIZE PyUnicode_GET_LENGTH #define PY2_UNUSED #define PY3_UNUSED UNUSED #else /* PY_MAJOR_VERSION >= 3 */ #define PY2_UNUSED UNUSED #define PY3_UNUSED #define PyBytes_Check PyString_Check #define PyUnicode_READY(obj) 0 #define PyUnicode_KIND(obj) (sizeof(Py_UNICODE)) #define PyUnicode_DATA(obj) ((void *)(PyUnicode_AS_UNICODE(obj))) #define PyUnicode_READ(kind, data, index) ((JSON_UNICHR)((const Py_UNICODE *)(data))[(index)]) #define PyUnicode_GET_LENGTH PyUnicode_GET_SIZE #define JSON_UNICHR Py_UNICODE #define JSON_InternFromString PyString_InternFromString #endif /* PY_MAJOR_VERSION < 3 */ #if PY_VERSION_HEX < 0x02070000 #if !defined(PyOS_string_to_double) #define PyOS_string_to_double json_PyOS_string_to_double static double json_PyOS_string_to_double(const char *s, char **endptr, PyObject *overflow_exception); static double json_PyOS_string_to_double(const char *s, char **endptr, PyObject *overflow_exception) { double x; assert(endptr == NULL); assert(overflow_exception == NULL); PyFPE_START_PROTECT("json_PyOS_string_to_double", return -1.0;) x = PyOS_ascii_atof(s); PyFPE_END_PROTECT(x) return x; } #endif #endif /* PY_VERSION_HEX < 0x02070000 */ #if PY_VERSION_HEX < 0x02060000 #if !defined(Py_TYPE) #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) #endif #if !defined(Py_SIZE) #define Py_SIZE(ob) (((PyVarObject*)(ob))->ob_size) #endif #if !defined(PyVarObject_HEAD_INIT) #define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, #endif #endif /* PY_VERSION_HEX < 0x02060000 */ #ifdef __GNUC__ #define UNUSED __attribute__((__unused__)) #else #define UNUSED #endif #define DEFAULT_ENCODING "utf-8" #define PyScanner_Check(op) PyObject_TypeCheck(op, &PyScannerType) #define PyScanner_CheckExact(op) (Py_TYPE(op) == &PyScannerType) #define PyEncoder_Check(op) PyObject_TypeCheck(op, &PyEncoderType) #define PyEncoder_CheckExact(op) (Py_TYPE(op) == &PyEncoderType) #define JSON_ALLOW_NAN 1 #define JSON_IGNORE_NAN 2 static PyObject *JSON_Infinity = NULL; static PyObject *JSON_NegInfinity = NULL; static PyObject *JSON_NaN = NULL; static PyObject *JSON_EmptyUnicode = NULL; #if PY_MAJOR_VERSION < 3 static PyObject *JSON_EmptyStr = NULL; #endif static PyTypeObject PyScannerType; static PyTypeObject PyEncoderType; typedef struct { PyObject *large_strings; /* A list of previously accumulated large strings */ PyObject *small_strings; /* Pending small strings */ } JSON_Accu; static int JSON_Accu_Init(JSON_Accu *acc); static int JSON_Accu_Accumulate(JSON_Accu *acc, PyObject *unicode); static PyObject * JSON_Accu_FinishAsList(JSON_Accu *acc); static void JSON_Accu_Destroy(JSON_Accu *acc); #define ERR_EXPECTING_VALUE "Expecting value" #define ERR_ARRAY_DELIMITER "Expecting ',' delimiter or ']'" #define ERR_ARRAY_VALUE_FIRST "Expecting value or ']'" #define ERR_OBJECT_DELIMITER "Expecting ',' delimiter or '}'" #define ERR_OBJECT_PROPERTY "Expecting property name enclosed in double quotes" #define ERR_OBJECT_PROPERTY_FIRST "Expecting property name enclosed in double quotes or '}'" #define ERR_OBJECT_PROPERTY_DELIMITER "Expecting ':' delimiter" #define ERR_STRING_UNTERMINATED "Unterminated string starting at" #define ERR_STRING_CONTROL "Invalid control character %r at" #define ERR_STRING_ESC1 "Invalid \\X escape sequence %r" #define ERR_STRING_ESC4 "Invalid \\uXXXX escape sequence" typedef struct _PyScannerObject { PyObject_HEAD PyObject *encoding; PyObject *strict_bool; int strict; PyObject *object_hook; PyObject *pairs_hook; PyObject *parse_float; PyObject *parse_int; PyObject *parse_constant; PyObject *memo; } PyScannerObject; static PyMemberDef scanner_members[] = { {"encoding", T_OBJECT, offsetof(PyScannerObject, encoding), READONLY, "encoding"}, {"strict", T_OBJECT, offsetof(PyScannerObject, strict_bool), READONLY, "strict"}, {"object_hook", T_OBJECT, offsetof(PyScannerObject, object_hook), READONLY, "object_hook"}, {"object_pairs_hook", T_OBJECT, offsetof(PyScannerObject, pairs_hook), READONLY, "object_pairs_hook"}, {"parse_float", T_OBJECT, offsetof(PyScannerObject, parse_float), READONLY, "parse_float"}, {"parse_int", T_OBJECT, offsetof(PyScannerObject, parse_int), READONLY, "parse_int"}, {"parse_constant", T_OBJECT, offsetof(PyScannerObject, parse_constant), READONLY, "parse_constant"}, {NULL} }; typedef struct _PyEncoderObject { PyObject_HEAD PyObject *markers; PyObject *defaultfn; PyObject *encoder; PyObject *indent; PyObject *key_separator; PyObject *item_separator; PyObject *sort_keys; PyObject *key_memo; PyObject *encoding; PyObject *Decimal; PyObject *skipkeys_bool; int skipkeys; int fast_encode; /* 0, JSON_ALLOW_NAN, JSON_IGNORE_NAN */ int allow_or_ignore_nan; int use_decimal; int namedtuple_as_object; int tuple_as_array; int iterable_as_array; PyObject *max_long_size; PyObject *min_long_size; PyObject *item_sort_key; PyObject *item_sort_kw; int for_json; } PyEncoderObject; static PyMemberDef encoder_members[] = { {"markers", T_OBJECT, offsetof(PyEncoderObject, markers), READONLY, "markers"}, {"default", T_OBJECT, offsetof(PyEncoderObject, defaultfn), READONLY, "default"}, {"encoder", T_OBJECT, offsetof(PyEncoderObject, encoder), READONLY, "encoder"}, {"encoding", T_OBJECT, offsetof(PyEncoderObject, encoder), READONLY, "encoding"}, {"indent", T_OBJECT, offsetof(PyEncoderObject, indent), READONLY, "indent"}, {"key_separator", T_OBJECT, offsetof(PyEncoderObject, key_separator), READONLY, "key_separator"}, {"item_separator", T_OBJECT, offsetof(PyEncoderObject, item_separator), READONLY, "item_separator"}, {"sort_keys", T_OBJECT, offsetof(PyEncoderObject, sort_keys), READONLY, "sort_keys"}, /* Python 2.5 does not support T_BOOl */ {"skipkeys", T_OBJECT, offsetof(PyEncoderObject, skipkeys_bool), READONLY, "skipkeys"}, {"key_memo", T_OBJECT, offsetof(PyEncoderObject, key_memo), READONLY, "key_memo"}, {"item_sort_key", T_OBJECT, offsetof(PyEncoderObject, item_sort_key), READONLY, "item_sort_key"}, {"max_long_size", T_OBJECT, offsetof(PyEncoderObject, max_long_size), READONLY, "max_long_size"}, {"min_long_size", T_OBJECT, offsetof(PyEncoderObject, min_long_size), READONLY, "min_long_size"}, {NULL} }; static PyObject * join_list_unicode(PyObject *lst); static PyObject * JSON_ParseEncoding(PyObject *encoding); static PyObject * maybe_quote_bigint(PyEncoderObject* s, PyObject *encoded, PyObject *obj); static Py_ssize_t ascii_char_size(JSON_UNICHR c); static Py_ssize_t ascii_escape_char(JSON_UNICHR c, char *output, Py_ssize_t chars); static PyObject * ascii_escape_unicode(PyObject *pystr); static PyObject * ascii_escape_str(PyObject *pystr); static PyObject * py_encode_basestring_ascii(PyObject* self UNUSED, PyObject *pystr); #if PY_MAJOR_VERSION < 3 static PyObject * join_list_string(PyObject *lst); static PyObject * scan_once_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr); static PyObject * scanstring_str(PyObject *pystr, Py_ssize_t end, char *encoding, int strict, Py_ssize_t *next_end_ptr); static PyObject * _parse_object_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr); #endif static PyObject * scanstring_unicode(PyObject *pystr, Py_ssize_t end, int strict, Py_ssize_t *next_end_ptr); static PyObject * scan_once_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr); static PyObject * _build_rval_index_tuple(PyObject *rval, Py_ssize_t idx); static PyObject * scanner_new(PyTypeObject *type, PyObject *args, PyObject *kwds); static void scanner_dealloc(PyObject *self); static int scanner_clear(PyObject *self); static PyObject * encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds); static void encoder_dealloc(PyObject *self); static int encoder_clear(PyObject *self); static int is_raw_json(PyObject *obj); static PyObject * encoder_stringify_key(PyEncoderObject *s, PyObject *key); static int encoder_listencode_list(PyEncoderObject *s, JSON_Accu *rval, PyObject *seq, Py_ssize_t indent_level); static int encoder_listencode_obj(PyEncoderObject *s, JSON_Accu *rval, PyObject *obj, Py_ssize_t indent_level); static int encoder_listencode_dict(PyEncoderObject *s, JSON_Accu *rval, PyObject *dct, Py_ssize_t indent_level); static PyObject * _encoded_const(PyObject *obj); static void raise_errmsg(char *msg, PyObject *s, Py_ssize_t end); static PyObject * encoder_encode_string(PyEncoderObject *s, PyObject *obj); static int _convertPyInt_AsSsize_t(PyObject *o, Py_ssize_t *size_ptr); static PyObject * _convertPyInt_FromSsize_t(Py_ssize_t *size_ptr); static PyObject * encoder_encode_float(PyEncoderObject *s, PyObject *obj); static int _is_namedtuple(PyObject *obj); static int _has_for_json_hook(PyObject *obj); static PyObject * moduleinit(void); #define S_CHAR(c) (c >= ' ' && c <= '~' && c != '\\' && c != '"') #define IS_WHITESPACE(c) (((c) == ' ') || ((c) == '\t') || ((c) == '\n') || ((c) == '\r')) #define MIN_EXPANSION 6 static PyObject* RawJSONType = NULL; static int is_raw_json(PyObject *obj) { return PyObject_IsInstance(obj, RawJSONType) ? 1 : 0; } static int JSON_Accu_Init(JSON_Accu *acc) { /* Lazily allocated */ acc->large_strings = NULL; acc->small_strings = PyList_New(0); if (acc->small_strings == NULL) return -1; return 0; } static int flush_accumulator(JSON_Accu *acc) { Py_ssize_t nsmall = PyList_GET_SIZE(acc->small_strings); if (nsmall) { int ret; PyObject *joined; if (acc->large_strings == NULL) { acc->large_strings = PyList_New(0); if (acc->large_strings == NULL) return -1; } #if PY_MAJOR_VERSION >= 3 joined = join_list_unicode(acc->small_strings); #else /* PY_MAJOR_VERSION >= 3 */ joined = join_list_string(acc->small_strings); #endif /* PY_MAJOR_VERSION < 3 */ if (joined == NULL) return -1; if (PyList_SetSlice(acc->small_strings, 0, nsmall, NULL)) { Py_DECREF(joined); return -1; } ret = PyList_Append(acc->large_strings, joined); Py_DECREF(joined); return ret; } return 0; } static int JSON_Accu_Accumulate(JSON_Accu *acc, PyObject *unicode) { Py_ssize_t nsmall; #if PY_MAJOR_VERSION >= 3 assert(PyUnicode_Check(unicode)); #else /* PY_MAJOR_VERSION >= 3 */ assert(PyString_Check(unicode) || PyUnicode_Check(unicode)); #endif /* PY_MAJOR_VERSION < 3 */ if (PyList_Append(acc->small_strings, unicode)) return -1; nsmall = PyList_GET_SIZE(acc->small_strings); /* Each item in a list of unicode objects has an overhead (in 64-bit * builds) of: * - 8 bytes for the list slot * - 56 bytes for the header of the unicode object * that is, 64 bytes. 100000 such objects waste more than 6MB * compared to a single concatenated string. */ if (nsmall < 100000) return 0; return flush_accumulator(acc); } static PyObject * JSON_Accu_FinishAsList(JSON_Accu *acc) { int ret; PyObject *res; ret = flush_accumulator(acc); Py_CLEAR(acc->small_strings); if (ret) { Py_CLEAR(acc->large_strings); return NULL; } res = acc->large_strings; acc->large_strings = NULL; if (res == NULL) return PyList_New(0); return res; } static void JSON_Accu_Destroy(JSON_Accu *acc) { Py_CLEAR(acc->small_strings); Py_CLEAR(acc->large_strings); } static int IS_DIGIT(JSON_UNICHR c) { return c >= '0' && c <= '9'; } static PyObject * maybe_quote_bigint(PyEncoderObject* s, PyObject *encoded, PyObject *obj) { if (s->max_long_size != Py_None && s->min_long_size != Py_None) { if (PyObject_RichCompareBool(obj, s->max_long_size, Py_GE) || PyObject_RichCompareBool(obj, s->min_long_size, Py_LE)) { #if PY_MAJOR_VERSION >= 3 PyObject* quoted = PyUnicode_FromFormat("\"%U\"", encoded); #else PyObject* quoted = PyString_FromFormat("\"%s\"", PyString_AsString(encoded)); #endif Py_DECREF(encoded); encoded = quoted; } } return encoded; } static int _is_namedtuple(PyObject *obj) { int rval = 0; PyObject *_asdict = PyObject_GetAttrString(obj, "_asdict"); if (_asdict == NULL) { PyErr_Clear(); return 0; } rval = PyCallable_Check(_asdict); Py_DECREF(_asdict); return rval; } static int _has_for_json_hook(PyObject *obj) { int rval = 0; PyObject *for_json = PyObject_GetAttrString(obj, "for_json"); if (for_json == NULL) { PyErr_Clear(); return 0; } rval = PyCallable_Check(for_json); Py_DECREF(for_json); return rval; } static int _convertPyInt_AsSsize_t(PyObject *o, Py_ssize_t *size_ptr) { /* PyObject to Py_ssize_t converter */ *size_ptr = PyInt_AsSsize_t(o); if (*size_ptr == -1 && PyErr_Occurred()) return 0; return 1; } static PyObject * _convertPyInt_FromSsize_t(Py_ssize_t *size_ptr) { /* Py_ssize_t to PyObject converter */ return PyInt_FromSsize_t(*size_ptr); } static Py_ssize_t ascii_escape_char(JSON_UNICHR c, char *output, Py_ssize_t chars) { /* Escape unicode code point c to ASCII escape sequences in char *output. output must have at least 12 bytes unused to accommodate an escaped surrogate pair "\uXXXX\uXXXX" */ if (S_CHAR(c)) { output[chars++] = (char)c; } else { output[chars++] = '\\'; switch (c) { case '\\': output[chars++] = (char)c; break; case '"': output[chars++] = (char)c; break; case '\b': output[chars++] = 'b'; break; case '\f': output[chars++] = 'f'; break; case '\n': output[chars++] = 'n'; break; case '\r': output[chars++] = 'r'; break; case '\t': output[chars++] = 't'; break; default: #if PY_MAJOR_VERSION >= 3 || defined(Py_UNICODE_WIDE) if (c >= 0x10000) { /* UTF-16 surrogate pair */ JSON_UNICHR v = c - 0x10000; c = 0xd800 | ((v >> 10) & 0x3ff); output[chars++] = 'u'; output[chars++] = "0123456789abcdef"[(c >> 12) & 0xf]; output[chars++] = "0123456789abcdef"[(c >> 8) & 0xf]; output[chars++] = "0123456789abcdef"[(c >> 4) & 0xf]; output[chars++] = "0123456789abcdef"[(c ) & 0xf]; c = 0xdc00 | (v & 0x3ff); output[chars++] = '\\'; } #endif output[chars++] = 'u'; output[chars++] = "0123456789abcdef"[(c >> 12) & 0xf]; output[chars++] = "0123456789abcdef"[(c >> 8) & 0xf]; output[chars++] = "0123456789abcdef"[(c >> 4) & 0xf]; output[chars++] = "0123456789abcdef"[(c ) & 0xf]; } } return chars; } static Py_ssize_t ascii_char_size(JSON_UNICHR c) { if (S_CHAR(c)) { return 1; } else if (c == '\\' || c == '"' || c == '\b' || c == '\f' || c == '\n' || c == '\r' || c == '\t') { return 2; } #if PY_MAJOR_VERSION >= 3 || defined(Py_UNICODE_WIDE) else if (c >= 0x10000U) { return 2 * MIN_EXPANSION; } #endif else { return MIN_EXPANSION; } } static PyObject * ascii_escape_unicode(PyObject *pystr) { /* Take a PyUnicode pystr and return a new ASCII-only escaped PyString */ Py_ssize_t i; Py_ssize_t input_chars = PyUnicode_GET_LENGTH(pystr); Py_ssize_t output_size = 2; Py_ssize_t chars; PY2_UNUSED int kind = PyUnicode_KIND(pystr); void *data = PyUnicode_DATA(pystr); PyObject *rval; char *output; output_size = 2; for (i = 0; i < input_chars; i++) { output_size += ascii_char_size(PyUnicode_READ(kind, data, i)); } #if PY_MAJOR_VERSION >= 3 rval = PyUnicode_New(output_size, 127); if (rval == NULL) { return NULL; } assert(PyUnicode_KIND(rval) == PyUnicode_1BYTE_KIND); output = (char *)PyUnicode_DATA(rval); #else rval = PyString_FromStringAndSize(NULL, output_size); if (rval == NULL) { return NULL; } output = PyString_AS_STRING(rval); #endif chars = 0; output[chars++] = '"'; for (i = 0; i < input_chars; i++) { chars = ascii_escape_char(PyUnicode_READ(kind, data, i), output, chars); } output[chars++] = '"'; assert(chars == output_size); return rval; } #if PY_MAJOR_VERSION >= 3 static PyObject * ascii_escape_str(PyObject *pystr) { PyObject *rval; PyObject *input = PyUnicode_DecodeUTF8(PyBytes_AS_STRING(pystr), PyBytes_GET_SIZE(pystr), NULL); if (input == NULL) return NULL; rval = ascii_escape_unicode(input); Py_DECREF(input); return rval; } #else /* PY_MAJOR_VERSION >= 3 */ static PyObject * ascii_escape_str(PyObject *pystr) { /* Take a PyString pystr and return a new ASCII-only escaped PyString */ Py_ssize_t i; Py_ssize_t input_chars; Py_ssize_t output_size; Py_ssize_t chars; PyObject *rval; char *output; char *input_str; input_chars = PyString_GET_SIZE(pystr); input_str = PyString_AS_STRING(pystr); output_size = 2; /* Fast path for a string that's already ASCII */ for (i = 0; i < input_chars; i++) { JSON_UNICHR c = (JSON_UNICHR)input_str[i]; if (c > 0x7f) { /* We hit a non-ASCII character, bail to unicode mode */ PyObject *uni; uni = PyUnicode_DecodeUTF8(input_str, input_chars, "strict"); if (uni == NULL) { return NULL; } rval = ascii_escape_unicode(uni); Py_DECREF(uni); return rval; } output_size += ascii_char_size(c); } rval = PyString_FromStringAndSize(NULL, output_size); if (rval == NULL) { return NULL; } chars = 0; output = PyString_AS_STRING(rval); output[chars++] = '"'; for (i = 0; i < input_chars; i++) { chars = ascii_escape_char((JSON_UNICHR)input_str[i], output, chars); } output[chars++] = '"'; assert(chars == output_size); return rval; } #endif /* PY_MAJOR_VERSION < 3 */ static PyObject * encoder_stringify_key(PyEncoderObject *s, PyObject *key) { if (PyUnicode_Check(key)) { Py_INCREF(key); return key; } #if PY_MAJOR_VERSION >= 3 else if (PyBytes_Check(key) && s->encoding != NULL) { const char *encoding = PyUnicode_AsUTF8(s->encoding); if (encoding == NULL) return NULL; return PyUnicode_Decode( PyBytes_AS_STRING(key), PyBytes_GET_SIZE(key), encoding, NULL); } #else /* PY_MAJOR_VERSION >= 3 */ else if (PyString_Check(key)) { Py_INCREF(key); return key; } #endif /* PY_MAJOR_VERSION < 3 */ else if (PyFloat_Check(key)) { return encoder_encode_float(s, key); } else if (key == Py_True || key == Py_False || key == Py_None) { /* This must come before the PyInt_Check because True and False are also 1 and 0.*/ return _encoded_const(key); } else if (PyInt_Check(key) || PyLong_Check(key)) { if (!(PyInt_CheckExact(key) || PyLong_CheckExact(key))) { /* See #118, do not trust custom str/repr */ PyObject *res; PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyLong_Type, key, NULL); if (tmp == NULL) { return NULL; } res = PyObject_Str(tmp); Py_DECREF(tmp); return res; } else { return PyObject_Str(key); } } else if (s->use_decimal && PyObject_TypeCheck(key, (PyTypeObject *)s->Decimal)) { return PyObject_Str(key); } if (s->skipkeys) { Py_INCREF(Py_None); return Py_None; } PyErr_Format(PyExc_TypeError, "keys must be str, int, float, bool or None, " "not %.100s", key->ob_type->tp_name); return NULL; } static PyObject * encoder_dict_iteritems(PyEncoderObject *s, PyObject *dct) { PyObject *items; PyObject *iter = NULL; PyObject *lst = NULL; PyObject *item = NULL; PyObject *kstr = NULL; PyObject *sortfun = NULL; PyObject *sortres; static PyObject *sortargs = NULL; if (sortargs == NULL) { sortargs = PyTuple_New(0); if (sortargs == NULL) return NULL; } if (PyDict_CheckExact(dct)) items = PyDict_Items(dct); else items = PyMapping_Items(dct); if (items == NULL) return NULL; iter = PyObject_GetIter(items); Py_DECREF(items); if (iter == NULL) return NULL; if (s->item_sort_kw == Py_None) return iter; lst = PyList_New(0); if (lst == NULL) goto bail; while ((item = PyIter_Next(iter))) { PyObject *key, *value; if (!PyTuple_Check(item) || Py_SIZE(item) != 2) { PyErr_SetString(PyExc_ValueError, "items must return 2-tuples"); goto bail; } key = PyTuple_GET_ITEM(item, 0); if (key == NULL) goto bail; #if PY_MAJOR_VERSION < 3 else if (PyString_Check(key)) { /* item can be added as-is */ } #endif /* PY_MAJOR_VERSION < 3 */ else if (PyUnicode_Check(key)) { /* item can be added as-is */ } else { PyObject *tpl; kstr = encoder_stringify_key(s, key); if (kstr == NULL) goto bail; else if (kstr == Py_None) { /* skipkeys */ Py_DECREF(kstr); continue; } value = PyTuple_GET_ITEM(item, 1); if (value == NULL) goto bail; tpl = PyTuple_Pack(2, kstr, value); if (tpl == NULL) goto bail; Py_CLEAR(kstr); Py_DECREF(item); item = tpl; } if (PyList_Append(lst, item)) goto bail; Py_DECREF(item); } Py_CLEAR(iter); if (PyErr_Occurred()) goto bail; sortfun = PyObject_GetAttrString(lst, "sort"); if (sortfun == NULL) goto bail; sortres = PyObject_Call(sortfun, sortargs, s->item_sort_kw); if (!sortres) goto bail; Py_DECREF(sortres); Py_CLEAR(sortfun); iter = PyObject_GetIter(lst); Py_CLEAR(lst); return iter; bail: Py_XDECREF(sortfun); Py_XDECREF(kstr); Py_XDECREF(item); Py_XDECREF(lst); Py_XDECREF(iter); return NULL; } /* Use JSONDecodeError exception to raise a nice looking ValueError subclass */ static PyObject *JSONDecodeError = NULL; static void raise_errmsg(char *msg, PyObject *s, Py_ssize_t end) { PyObject *exc = PyObject_CallFunction(JSONDecodeError, "(zOO&)", msg, s, _convertPyInt_FromSsize_t, &end); if (exc) { PyErr_SetObject(JSONDecodeError, exc); Py_DECREF(exc); } } static PyObject * join_list_unicode(PyObject *lst) { /* return u''.join(lst) */ return PyUnicode_Join(JSON_EmptyUnicode, lst); } #if PY_MAJOR_VERSION >= 3 #define join_list_string join_list_unicode #else /* PY_MAJOR_VERSION >= 3 */ static PyObject * join_list_string(PyObject *lst) { /* return ''.join(lst) */ static PyObject *joinfn = NULL; if (joinfn == NULL) { joinfn = PyObject_GetAttrString(JSON_EmptyStr, "join"); if (joinfn == NULL) return NULL; } return PyObject_CallFunctionObjArgs(joinfn, lst, NULL); } #endif /* PY_MAJOR_VERSION < 3 */ static PyObject * _build_rval_index_tuple(PyObject *rval, Py_ssize_t idx) { /* return (rval, idx) tuple, stealing reference to rval */ PyObject *tpl; PyObject *pyidx; /* steal a reference to rval, returns (rval, idx) */ if (rval == NULL) { assert(PyErr_Occurred()); return NULL; } pyidx = PyInt_FromSsize_t(idx); if (pyidx == NULL) { Py_DECREF(rval); return NULL; } tpl = PyTuple_New(2); if (tpl == NULL) { Py_DECREF(pyidx); Py_DECREF(rval); return NULL; } PyTuple_SET_ITEM(tpl, 0, rval); PyTuple_SET_ITEM(tpl, 1, pyidx); return tpl; } #define APPEND_OLD_CHUNK \ if (chunk != NULL) { \ if (chunks == NULL) { \ chunks = PyList_New(0); \ if (chunks == NULL) { \ goto bail; \ } \ } \ if (PyList_Append(chunks, chunk)) { \ goto bail; \ } \ Py_CLEAR(chunk); \ } #if PY_MAJOR_VERSION < 3 static PyObject * scanstring_str(PyObject *pystr, Py_ssize_t end, char *encoding, int strict, Py_ssize_t *next_end_ptr) { /* Read the JSON string from PyString pystr. end is the index of the first character after the quote. encoding is the encoding of pystr (must be an ASCII superset) if strict is zero then literal control characters are allowed *next_end_ptr is a return-by-reference index of the character after the end quote Return value is a new PyString (if ASCII-only) or PyUnicode */ PyObject *rval; Py_ssize_t len = PyString_GET_SIZE(pystr); Py_ssize_t begin = end - 1; Py_ssize_t next = begin; int has_unicode = 0; char *buf = PyString_AS_STRING(pystr); PyObject *chunks = NULL; PyObject *chunk = NULL; PyObject *strchunk = NULL; if (len == end) { raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin); goto bail; } else if (end < 0 || len < end) { PyErr_SetString(PyExc_ValueError, "end is out of bounds"); goto bail; } while (1) { /* Find the end of the string or the next escape */ Py_UNICODE c = 0; for (next = end; next < len; next++) { c = (unsigned char)buf[next]; if (c == '"' || c == '\\') { break; } else if (strict && c <= 0x1f) { raise_errmsg(ERR_STRING_CONTROL, pystr, next); goto bail; } else if (c > 0x7f) { has_unicode = 1; } } if (!(c == '"' || c == '\\')) { raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin); goto bail; } /* Pick up this chunk if it's not zero length */ if (next != end) { APPEND_OLD_CHUNK strchunk = PyString_FromStringAndSize(&buf[end], next - end); if (strchunk == NULL) { goto bail; } if (has_unicode) { chunk = PyUnicode_FromEncodedObject(strchunk, encoding, NULL); Py_DECREF(strchunk); if (chunk == NULL) { goto bail; } } else { chunk = strchunk; } } next++; if (c == '"') { end = next; break; } if (next == len) { raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin); goto bail; } c = buf[next]; if (c != 'u') { /* Non-unicode backslash escapes */ end = next + 1; switch (c) { case '"': break; case '\\': break; case '/': break; case 'b': c = '\b'; break; case 'f': c = '\f'; break; case 'n': c = '\n'; break; case 'r': c = '\r'; break; case 't': c = '\t'; break; default: c = 0; } if (c == 0) { raise_errmsg(ERR_STRING_ESC1, pystr, end - 2); goto bail; } } else { c = 0; next++; end = next + 4; if (end >= len) { raise_errmsg(ERR_STRING_ESC4, pystr, next - 1); goto bail; } /* Decode 4 hex digits */ for (; next < end; next++) { JSON_UNICHR digit = (JSON_UNICHR)buf[next]; c <<= 4; switch (digit) { case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': c |= (digit - '0'); break; case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': c |= (digit - 'a' + 10); break; case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': c |= (digit - 'A' + 10); break; default: raise_errmsg(ERR_STRING_ESC4, pystr, end - 5); goto bail; } } #if defined(Py_UNICODE_WIDE) /* Surrogate pair */ if ((c & 0xfc00) == 0xd800) { if (end + 6 < len && buf[next] == '\\' && buf[next+1] == 'u') { JSON_UNICHR c2 = 0; end += 6; /* Decode 4 hex digits */ for (next += 2; next < end; next++) { c2 <<= 4; JSON_UNICHR digit = buf[next]; switch (digit) { case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': c2 |= (digit - '0'); break; case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': c2 |= (digit - 'a' + 10); break; case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': c2 |= (digit - 'A' + 10); break; default: raise_errmsg(ERR_STRING_ESC4, pystr, end - 5); goto bail; } } if ((c2 & 0xfc00) != 0xdc00) { /* not a low surrogate, rewind */ end -= 6; next = end; } else { c = 0x10000 + (((c - 0xd800) << 10) | (c2 - 0xdc00)); } } } #endif /* Py_UNICODE_WIDE */ } if (c > 0x7f) { has_unicode = 1; } APPEND_OLD_CHUNK if (has_unicode) { chunk = PyUnicode_FromOrdinal(c); if (chunk == NULL) { goto bail; } } else { char c_char = Py_CHARMASK(c); chunk = PyString_FromStringAndSize(&c_char, 1); if (chunk == NULL) { goto bail; } } } if (chunks == NULL) { if (chunk != NULL) rval = chunk; else { rval = JSON_EmptyStr; Py_INCREF(rval); } } else { APPEND_OLD_CHUNK rval = join_list_string(chunks); if (rval == NULL) { goto bail; } Py_CLEAR(chunks); } *next_end_ptr = end; return rval; bail: *next_end_ptr = -1; Py_XDECREF(chunk); Py_XDECREF(chunks); return NULL; } #endif /* PY_MAJOR_VERSION < 3 */ static PyObject * scanstring_unicode(PyObject *pystr, Py_ssize_t end, int strict, Py_ssize_t *next_end_ptr) { /* Read the JSON string from PyUnicode pystr. end is the index of the first character after the quote. if strict is zero then literal control characters are allowed *next_end_ptr is a return-by-reference index of the character after the end quote Return value is a new PyUnicode */ PyObject *rval; Py_ssize_t begin = end - 1; Py_ssize_t next = begin; PY2_UNUSED int kind = PyUnicode_KIND(pystr); Py_ssize_t len = PyUnicode_GET_LENGTH(pystr); void *buf = PyUnicode_DATA(pystr); PyObject *chunks = NULL; PyObject *chunk = NULL; if (len == end) { raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin); goto bail; } else if (end < 0 || len < end) { PyErr_SetString(PyExc_ValueError, "end is out of bounds"); goto bail; } while (1) { /* Find the end of the string or the next escape */ JSON_UNICHR c = 0; for (next = end; next < len; next++) { c = PyUnicode_READ(kind, buf, next); if (c == '"' || c == '\\') { break; } else if (strict && c <= 0x1f) { raise_errmsg(ERR_STRING_CONTROL, pystr, next); goto bail; } } if (!(c == '"' || c == '\\')) { raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin); goto bail; } /* Pick up this chunk if it's not zero length */ if (next != end) { APPEND_OLD_CHUNK #if PY_MAJOR_VERSION < 3 chunk = PyUnicode_FromUnicode(&((const Py_UNICODE *)buf)[end], next - end); #else chunk = PyUnicode_Substring(pystr, end, next); #endif if (chunk == NULL) { goto bail; } } next++; if (c == '"') { end = next; break; } if (next == len) { raise_errmsg(ERR_STRING_UNTERMINATED, pystr, begin); goto bail; } c = PyUnicode_READ(kind, buf, next); if (c != 'u') { /* Non-unicode backslash escapes */ end = next + 1; switch (c) { case '"': break; case '\\': break; case '/': break; case 'b': c = '\b'; break; case 'f': c = '\f'; break; case 'n': c = '\n'; break; case 'r': c = '\r'; break; case 't': c = '\t'; break; default: c = 0; } if (c == 0) { raise_errmsg(ERR_STRING_ESC1, pystr, end - 2); goto bail; } } else { c = 0; next++; end = next + 4; if (end >= len) { raise_errmsg(ERR_STRING_ESC4, pystr, next - 1); goto bail; } /* Decode 4 hex digits */ for (; next < end; next++) { JSON_UNICHR digit = PyUnicode_READ(kind, buf, next); c <<= 4; switch (digit) { case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': c |= (digit - '0'); break; case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': c |= (digit - 'a' + 10); break; case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': c |= (digit - 'A' + 10); break; default: raise_errmsg(ERR_STRING_ESC4, pystr, end - 5); goto bail; } } #if PY_MAJOR_VERSION >= 3 || defined(Py_UNICODE_WIDE) /* Surrogate pair */ if ((c & 0xfc00) == 0xd800) { JSON_UNICHR c2 = 0; if (end + 6 < len && PyUnicode_READ(kind, buf, next) == '\\' && PyUnicode_READ(kind, buf, next + 1) == 'u') { end += 6; /* Decode 4 hex digits */ for (next += 2; next < end; next++) { JSON_UNICHR digit = PyUnicode_READ(kind, buf, next); c2 <<= 4; switch (digit) { case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': c2 |= (digit - '0'); break; case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': c2 |= (digit - 'a' + 10); break; case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': c2 |= (digit - 'A' + 10); break; default: raise_errmsg(ERR_STRING_ESC4, pystr, end - 5); goto bail; } } if ((c2 & 0xfc00) != 0xdc00) { /* not a low surrogate, rewind */ end -= 6; next = end; } else { c = 0x10000 + (((c - 0xd800) << 10) | (c2 - 0xdc00)); } } } #endif } APPEND_OLD_CHUNK chunk = PyUnicode_FromOrdinal(c); if (chunk == NULL) { goto bail; } } if (chunks == NULL) { if (chunk != NULL) rval = chunk; else { rval = JSON_EmptyUnicode; Py_INCREF(rval); } } else { APPEND_OLD_CHUNK rval = join_list_unicode(chunks); if (rval == NULL) { goto bail; } Py_CLEAR(chunks); } *next_end_ptr = end; return rval; bail: *next_end_ptr = -1; Py_XDECREF(chunk); Py_XDECREF(chunks); return NULL; } PyDoc_STRVAR(pydoc_scanstring, "scanstring(basestring, end, encoding, strict=True) -> (str, end)\n" "\n" "Scan the string s for a JSON string. End is the index of the\n" "character in s after the quote that started the JSON string.\n" "Unescapes all valid JSON string escape sequences and raises ValueError\n" "on attempt to decode an invalid string. If strict is False then literal\n" "control characters are allowed in the string.\n" "\n" "Returns a tuple of the decoded string and the index of the character in s\n" "after the end quote." ); static PyObject * py_scanstring(PyObject* self UNUSED, PyObject *args) { PyObject *pystr; PyObject *rval; Py_ssize_t end; Py_ssize_t next_end = -1; char *encoding = NULL; int strict = 1; if (!PyArg_ParseTuple(args, "OO&|zi:scanstring", &pystr, _convertPyInt_AsSsize_t, &end, &encoding, &strict)) { return NULL; } if (encoding == NULL) { encoding = DEFAULT_ENCODING; } if (PyUnicode_Check(pystr)) { if (PyUnicode_READY(pystr)) return NULL; rval = scanstring_unicode(pystr, end, strict, &next_end); } #if PY_MAJOR_VERSION < 3 /* Using a bytes input is unsupported for scanning in Python 3. It is coerced to str in the decoder before it gets here. */ else if (PyString_Check(pystr)) { rval = scanstring_str(pystr, end, encoding, strict, &next_end); } #endif else { PyErr_Format(PyExc_TypeError, "first argument must be a string, not %.80s", Py_TYPE(pystr)->tp_name); return NULL; } return _build_rval_index_tuple(rval, next_end); } PyDoc_STRVAR(pydoc_encode_basestring_ascii, "encode_basestring_ascii(basestring) -> str\n" "\n" "Return an ASCII-only JSON representation of a Python string" ); static PyObject * py_encode_basestring_ascii(PyObject* self UNUSED, PyObject *pystr) { /* Return an ASCII-only JSON representation of a Python string */ /* METH_O */ if (PyBytes_Check(pystr)) { return ascii_escape_str(pystr); } else if (PyUnicode_Check(pystr)) { if (PyUnicode_READY(pystr)) return NULL; return ascii_escape_unicode(pystr); } else { PyErr_Format(PyExc_TypeError, "first argument must be a string, not %.80s", Py_TYPE(pystr)->tp_name); return NULL; } } static void scanner_dealloc(PyObject *self) { /* bpo-31095: UnTrack is needed before calling any callbacks */ PyObject_GC_UnTrack(self); scanner_clear(self); Py_TYPE(self)->tp_free(self); } static int scanner_traverse(PyObject *self, visitproc visit, void *arg) { PyScannerObject *s; assert(PyScanner_Check(self)); s = (PyScannerObject *)self; Py_VISIT(s->encoding); Py_VISIT(s->strict_bool); Py_VISIT(s->object_hook); Py_VISIT(s->pairs_hook); Py_VISIT(s->parse_float); Py_VISIT(s->parse_int); Py_VISIT(s->parse_constant); Py_VISIT(s->memo); return 0; } static int scanner_clear(PyObject *self) { PyScannerObject *s; assert(PyScanner_Check(self)); s = (PyScannerObject *)self; Py_CLEAR(s->encoding); Py_CLEAR(s->strict_bool); Py_CLEAR(s->object_hook); Py_CLEAR(s->pairs_hook); Py_CLEAR(s->parse_float); Py_CLEAR(s->parse_int); Py_CLEAR(s->parse_constant); Py_CLEAR(s->memo); return 0; } #if PY_MAJOR_VERSION < 3 static PyObject * _parse_object_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr) { /* Read a JSON object from PyString pystr. idx is the index of the first character after the opening curly brace. *next_idx_ptr is a return-by-reference index to the first character after the closing curly brace. Returns a new PyObject (usually a dict, but object_hook or object_pairs_hook can change that) */ char *str = PyString_AS_STRING(pystr); Py_ssize_t end_idx = PyString_GET_SIZE(pystr) - 1; PyObject *rval = NULL; PyObject *pairs = NULL; PyObject *item; PyObject *key = NULL; PyObject *val = NULL; char *encoding = PyString_AS_STRING(s->encoding); int has_pairs_hook = (s->pairs_hook != Py_None); int did_parse = 0; Py_ssize_t next_idx; if (has_pairs_hook) { pairs = PyList_New(0); if (pairs == NULL) return NULL; } else { rval = PyDict_New(); if (rval == NULL) return NULL; } /* skip whitespace after { */ while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++; /* only loop if the object is non-empty */ if (idx <= end_idx && str[idx] != '}') { int trailing_delimiter = 0; while (idx <= end_idx) { PyObject *memokey; trailing_delimiter = 0; /* read key */ if (str[idx] != '"') { raise_errmsg(ERR_OBJECT_PROPERTY, pystr, idx); goto bail; } key = scanstring_str(pystr, idx + 1, encoding, s->strict, &next_idx); if (key == NULL) goto bail; memokey = PyDict_GetItem(s->memo, key); if (memokey != NULL) { Py_INCREF(memokey); Py_DECREF(key); key = memokey; } else { if (PyDict_SetItem(s->memo, key, key) < 0) goto bail; } idx = next_idx; /* skip whitespace between key and : delimiter, read :, skip whitespace */ while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++; if (idx > end_idx || str[idx] != ':') { raise_errmsg(ERR_OBJECT_PROPERTY_DELIMITER, pystr, idx); goto bail; } idx++; while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++; /* read any JSON data type */ val = scan_once_str(s, pystr, idx, &next_idx); if (val == NULL) goto bail; if (has_pairs_hook) { item = PyTuple_Pack(2, key, val); if (item == NULL) goto bail; Py_CLEAR(key); Py_CLEAR(val); if (PyList_Append(pairs, item) == -1) { Py_DECREF(item); goto bail; } Py_DECREF(item); } else { if (PyDict_SetItem(rval, key, val) < 0) goto bail; Py_CLEAR(key); Py_CLEAR(val); } idx = next_idx; /* skip whitespace before } or , */ while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++; /* bail if the object is closed or we didn't get the , delimiter */ did_parse = 1; if (idx > end_idx) break; if (str[idx] == '}') { break; } else if (str[idx] != ',') { raise_errmsg(ERR_OBJECT_DELIMITER, pystr, idx); goto bail; } idx++; /* skip whitespace after , delimiter */ while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++; trailing_delimiter = 1; } if (trailing_delimiter) { raise_errmsg(ERR_OBJECT_PROPERTY, pystr, idx); goto bail; } } /* verify that idx < end_idx, str[idx] should be '}' */ if (idx > end_idx || str[idx] != '}') { if (did_parse) { raise_errmsg(ERR_OBJECT_DELIMITER, pystr, idx); } else { raise_errmsg(ERR_OBJECT_PROPERTY_FIRST, pystr, idx); } goto bail; } /* if pairs_hook is not None: rval = object_pairs_hook(pairs) */ if (s->pairs_hook != Py_None) { val = PyObject_CallFunctionObjArgs(s->pairs_hook, pairs, NULL); if (val == NULL) goto bail; Py_DECREF(pairs); *next_idx_ptr = idx + 1; return val; } /* if object_hook is not None: rval = object_hook(rval) */ if (s->object_hook != Py_None) { val = PyObject_CallFunctionObjArgs(s->object_hook, rval, NULL); if (val == NULL) goto bail; Py_DECREF(rval); rval = val; val = NULL; } *next_idx_ptr = idx + 1; return rval; bail: Py_XDECREF(rval); Py_XDECREF(key); Py_XDECREF(val); Py_XDECREF(pairs); return NULL; } #endif /* PY_MAJOR_VERSION < 3 */ static PyObject * _parse_object_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr) { /* Read a JSON object from PyUnicode pystr. idx is the index of the first character after the opening curly brace. *next_idx_ptr is a return-by-reference index to the first character after the closing curly brace. Returns a new PyObject (usually a dict, but object_hook can change that) */ void *str = PyUnicode_DATA(pystr); Py_ssize_t end_idx = PyUnicode_GET_LENGTH(pystr) - 1; PY2_UNUSED int kind = PyUnicode_KIND(pystr); PyObject *rval = NULL; PyObject *pairs = NULL; PyObject *item; PyObject *key = NULL; PyObject *val = NULL; int has_pairs_hook = (s->pairs_hook != Py_None); int did_parse = 0; Py_ssize_t next_idx; if (has_pairs_hook) { pairs = PyList_New(0); if (pairs == NULL) return NULL; } else { rval = PyDict_New(); if (rval == NULL) return NULL; } /* skip whitespace after { */ while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++; /* only loop if the object is non-empty */ if (idx <= end_idx && PyUnicode_READ(kind, str, idx) != '}') { int trailing_delimiter = 0; while (idx <= end_idx) { PyObject *memokey; trailing_delimiter = 0; /* read key */ if (PyUnicode_READ(kind, str, idx) != '"') { raise_errmsg(ERR_OBJECT_PROPERTY, pystr, idx); goto bail; } key = scanstring_unicode(pystr, idx + 1, s->strict, &next_idx); if (key == NULL) goto bail; memokey = PyDict_GetItem(s->memo, key); if (memokey != NULL) { Py_INCREF(memokey); Py_DECREF(key); key = memokey; } else { if (PyDict_SetItem(s->memo, key, key) < 0) goto bail; } idx = next_idx; /* skip whitespace between key and : delimiter, read :, skip whitespace */ while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++; if (idx > end_idx || PyUnicode_READ(kind, str, idx) != ':') { raise_errmsg(ERR_OBJECT_PROPERTY_DELIMITER, pystr, idx); goto bail; } idx++; while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++; /* read any JSON term */ val = scan_once_unicode(s, pystr, idx, &next_idx); if (val == NULL) goto bail; if (has_pairs_hook) { item = PyTuple_Pack(2, key, val); if (item == NULL) goto bail; Py_CLEAR(key); Py_CLEAR(val); if (PyList_Append(pairs, item) == -1) { Py_DECREF(item); goto bail; } Py_DECREF(item); } else { if (PyDict_SetItem(rval, key, val) < 0) goto bail; Py_CLEAR(key); Py_CLEAR(val); } idx = next_idx; /* skip whitespace before } or , */ while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++; /* bail if the object is closed or we didn't get the , delimiter */ did_parse = 1; if (idx > end_idx) break; if (PyUnicode_READ(kind, str, idx) == '}') { break; } else if (PyUnicode_READ(kind, str, idx) != ',') { raise_errmsg(ERR_OBJECT_DELIMITER, pystr, idx); goto bail; } idx++; /* skip whitespace after , delimiter */ while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++; trailing_delimiter = 1; } if (trailing_delimiter) { raise_errmsg(ERR_OBJECT_PROPERTY, pystr, idx); goto bail; } } /* verify that idx < end_idx, str[idx] should be '}' */ if (idx > end_idx || PyUnicode_READ(kind, str, idx) != '}') { if (did_parse) { raise_errmsg(ERR_OBJECT_DELIMITER, pystr, idx); } else { raise_errmsg(ERR_OBJECT_PROPERTY_FIRST, pystr, idx); } goto bail; } /* if pairs_hook is not None: rval = object_pairs_hook(pairs) */ if (s->pairs_hook != Py_None) { val = PyObject_CallFunctionObjArgs(s->pairs_hook, pairs, NULL); if (val == NULL) goto bail; Py_DECREF(pairs); *next_idx_ptr = idx + 1; return val; } /* if object_hook is not None: rval = object_hook(rval) */ if (s->object_hook != Py_None) { val = PyObject_CallFunctionObjArgs(s->object_hook, rval, NULL); if (val == NULL) goto bail; Py_DECREF(rval); rval = val; val = NULL; } *next_idx_ptr = idx + 1; return rval; bail: Py_XDECREF(rval); Py_XDECREF(key); Py_XDECREF(val); Py_XDECREF(pairs); return NULL; } #if PY_MAJOR_VERSION < 3 static PyObject * _parse_array_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr) { /* Read a JSON array from PyString pystr. idx is the index of the first character after the opening brace. *next_idx_ptr is a return-by-reference index to the first character after the closing brace. Returns a new PyList */ char *str = PyString_AS_STRING(pystr); Py_ssize_t end_idx = PyString_GET_SIZE(pystr) - 1; PyObject *val = NULL; PyObject *rval = PyList_New(0); Py_ssize_t next_idx; if (rval == NULL) return NULL; /* skip whitespace after [ */ while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++; /* only loop if the array is non-empty */ if (idx <= end_idx && str[idx] != ']') { int trailing_delimiter = 0; while (idx <= end_idx) { trailing_delimiter = 0; /* read any JSON term and de-tuplefy the (rval, idx) */ val = scan_once_str(s, pystr, idx, &next_idx); if (val == NULL) { goto bail; } if (PyList_Append(rval, val) == -1) goto bail; Py_CLEAR(val); idx = next_idx; /* skip whitespace between term and , */ while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++; /* bail if the array is closed or we didn't get the , delimiter */ if (idx > end_idx) break; if (str[idx] == ']') { break; } else if (str[idx] != ',') { raise_errmsg(ERR_ARRAY_DELIMITER, pystr, idx); goto bail; } idx++; /* skip whitespace after , */ while (idx <= end_idx && IS_WHITESPACE(str[idx])) idx++; trailing_delimiter = 1; } if (trailing_delimiter) { raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx); goto bail; } } /* verify that idx < end_idx, str[idx] should be ']' */ if (idx > end_idx || str[idx] != ']') { if (PyList_GET_SIZE(rval)) { raise_errmsg(ERR_ARRAY_DELIMITER, pystr, idx); } else { raise_errmsg(ERR_ARRAY_VALUE_FIRST, pystr, idx); } goto bail; } *next_idx_ptr = idx + 1; return rval; bail: Py_XDECREF(val); Py_DECREF(rval); return NULL; } #endif /* PY_MAJOR_VERSION < 3 */ static PyObject * _parse_array_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr) { /* Read a JSON array from PyString pystr. idx is the index of the first character after the opening brace. *next_idx_ptr is a return-by-reference index to the first character after the closing brace. Returns a new PyList */ PY2_UNUSED int kind = PyUnicode_KIND(pystr); void *str = PyUnicode_DATA(pystr); Py_ssize_t end_idx = PyUnicode_GET_LENGTH(pystr) - 1; PyObject *val = NULL; PyObject *rval = PyList_New(0); Py_ssize_t next_idx; if (rval == NULL) return NULL; /* skip whitespace after [ */ while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++; /* only loop if the array is non-empty */ if (idx <= end_idx && PyUnicode_READ(kind, str, idx) != ']') { int trailing_delimiter = 0; while (idx <= end_idx) { trailing_delimiter = 0; /* read any JSON term */ val = scan_once_unicode(s, pystr, idx, &next_idx); if (val == NULL) { goto bail; } if (PyList_Append(rval, val) == -1) goto bail; Py_CLEAR(val); idx = next_idx; /* skip whitespace between term and , */ while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++; /* bail if the array is closed or we didn't get the , delimiter */ if (idx > end_idx) break; if (PyUnicode_READ(kind, str, idx) == ']') { break; } else if (PyUnicode_READ(kind, str, idx) != ',') { raise_errmsg(ERR_ARRAY_DELIMITER, pystr, idx); goto bail; } idx++; /* skip whitespace after , */ while (idx <= end_idx && IS_WHITESPACE(PyUnicode_READ(kind, str, idx))) idx++; trailing_delimiter = 1; } if (trailing_delimiter) { raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx); goto bail; } } /* verify that idx < end_idx, str[idx] should be ']' */ if (idx > end_idx || PyUnicode_READ(kind, str, idx) != ']') { if (PyList_GET_SIZE(rval)) { raise_errmsg(ERR_ARRAY_DELIMITER, pystr, idx); } else { raise_errmsg(ERR_ARRAY_VALUE_FIRST, pystr, idx); } goto bail; } *next_idx_ptr = idx + 1; return rval; bail: Py_XDECREF(val); Py_DECREF(rval); return NULL; } static PyObject * _parse_constant(PyScannerObject *s, PyObject *constant, Py_ssize_t idx, Py_ssize_t *next_idx_ptr) { /* Read a JSON constant from PyString pystr. constant is the Python string that was found ("NaN", "Infinity", "-Infinity"). idx is the index of the first character of the constant *next_idx_ptr is a return-by-reference index to the first character after the constant. Returns the result of parse_constant */ PyObject *rval; /* rval = parse_constant(constant) */ rval = PyObject_CallFunctionObjArgs(s->parse_constant, constant, NULL); idx += PyString_GET_SIZE(constant); *next_idx_ptr = idx; return rval; } #if PY_MAJOR_VERSION < 3 static PyObject * _match_number_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t start, Py_ssize_t *next_idx_ptr) { /* Read a JSON number from PyString pystr. idx is the index of the first character of the number *next_idx_ptr is a return-by-reference index to the first character after the number. Returns a new PyObject representation of that number: PyInt, PyLong, or PyFloat. May return other types if parse_int or parse_float are set */ char *str = PyString_AS_STRING(pystr); Py_ssize_t end_idx = PyString_GET_SIZE(pystr) - 1; Py_ssize_t idx = start; int is_float = 0; PyObject *rval; PyObject *numstr; /* read a sign if it's there, make sure it's not the end of the string */ if (str[idx] == '-') { if (idx >= end_idx) { raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx); return NULL; } idx++; } /* read as many integer digits as we find as long as it doesn't start with 0 */ if (str[idx] >= '1' && str[idx] <= '9') { idx++; while (idx <= end_idx && str[idx] >= '0' && str[idx] <= '9') idx++; } /* if it starts with 0 we only expect one integer digit */ else if (str[idx] == '0') { idx++; } /* no integer digits, error */ else { raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx); return NULL; } /* if the next char is '.' followed by a digit then read all float digits */ if (idx < end_idx && str[idx] == '.' && str[idx + 1] >= '0' && str[idx + 1] <= '9') { is_float = 1; idx += 2; while (idx <= end_idx && str[idx] >= '0' && str[idx] <= '9') idx++; } /* if the next char is 'e' or 'E' then maybe read the exponent (or backtrack) */ if (idx < end_idx && (str[idx] == 'e' || str[idx] == 'E')) { /* save the index of the 'e' or 'E' just in case we need to backtrack */ Py_ssize_t e_start = idx; idx++; /* read an exponent sign if present */ if (idx < end_idx && (str[idx] == '-' || str[idx] == '+')) idx++; /* read all digits */ while (idx <= end_idx && str[idx] >= '0' && str[idx] <= '9') idx++; /* if we got a digit, then parse as float. if not, backtrack */ if (str[idx - 1] >= '0' && str[idx - 1] <= '9') { is_float = 1; } else { idx = e_start; } } /* copy the section we determined to be a number */ numstr = PyString_FromStringAndSize(&str[start], idx - start); if (numstr == NULL) return NULL; if (is_float) { /* parse as a float using a fast path if available, otherwise call user defined method */ if (s->parse_float != (PyObject *)&PyFloat_Type) { rval = PyObject_CallFunctionObjArgs(s->parse_float, numstr, NULL); } else { /* rval = PyFloat_FromDouble(PyOS_ascii_atof(PyString_AS_STRING(numstr))); */ double d = PyOS_string_to_double(PyString_AS_STRING(numstr), NULL, NULL); if (d == -1.0 && PyErr_Occurred()) return NULL; rval = PyFloat_FromDouble(d); } } else { /* parse as an int using a fast path if available, otherwise call user defined method */ if (s->parse_int != (PyObject *)&PyInt_Type) { rval = PyObject_CallFunctionObjArgs(s->parse_int, numstr, NULL); } else { rval = PyInt_FromString(PyString_AS_STRING(numstr), NULL, 10); } } Py_DECREF(numstr); *next_idx_ptr = idx; return rval; } #endif /* PY_MAJOR_VERSION < 3 */ static PyObject * _match_number_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t start, Py_ssize_t *next_idx_ptr) { /* Read a JSON number from PyUnicode pystr. idx is the index of the first character of the number *next_idx_ptr is a return-by-reference index to the first character after the number. Returns a new PyObject representation of that number: PyInt, PyLong, or PyFloat. May return other types if parse_int or parse_float are set */ PY2_UNUSED int kind = PyUnicode_KIND(pystr); void *str = PyUnicode_DATA(pystr); Py_ssize_t end_idx = PyUnicode_GET_LENGTH(pystr) - 1; Py_ssize_t idx = start; int is_float = 0; JSON_UNICHR c; PyObject *rval; PyObject *numstr; /* read a sign if it's there, make sure it's not the end of the string */ if (PyUnicode_READ(kind, str, idx) == '-') { if (idx >= end_idx) { raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx); return NULL; } idx++; } /* read as many integer digits as we find as long as it doesn't start with 0 */ c = PyUnicode_READ(kind, str, idx); if (c == '0') { /* if it starts with 0 we only expect one integer digit */ idx++; } else if (IS_DIGIT(c)) { idx++; while (idx <= end_idx && IS_DIGIT(PyUnicode_READ(kind, str, idx))) { idx++; } } else { /* no integer digits, error */ raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx); return NULL; } /* if the next char is '.' followed by a digit then read all float digits */ if (idx < end_idx && PyUnicode_READ(kind, str, idx) == '.' && IS_DIGIT(PyUnicode_READ(kind, str, idx + 1))) { is_float = 1; idx += 2; while (idx <= end_idx && IS_DIGIT(PyUnicode_READ(kind, str, idx))) idx++; } /* if the next char is 'e' or 'E' then maybe read the exponent (or backtrack) */ if (idx < end_idx && (PyUnicode_READ(kind, str, idx) == 'e' || PyUnicode_READ(kind, str, idx) == 'E')) { Py_ssize_t e_start = idx; idx++; /* read an exponent sign if present */ if (idx < end_idx && (PyUnicode_READ(kind, str, idx) == '-' || PyUnicode_READ(kind, str, idx) == '+')) idx++; /* read all digits */ while (idx <= end_idx && IS_DIGIT(PyUnicode_READ(kind, str, idx))) idx++; /* if we got a digit, then parse as float. if not, backtrack */ if (IS_DIGIT(PyUnicode_READ(kind, str, idx - 1))) { is_float = 1; } else { idx = e_start; } } /* copy the section we determined to be a number */ #if PY_MAJOR_VERSION >= 3 numstr = PyUnicode_Substring(pystr, start, idx); #else numstr = PyUnicode_FromUnicode(&((Py_UNICODE *)str)[start], idx - start); #endif if (numstr == NULL) return NULL; if (is_float) { /* parse as a float using a fast path if available, otherwise call user defined method */ if (s->parse_float != (PyObject *)&PyFloat_Type) { rval = PyObject_CallFunctionObjArgs(s->parse_float, numstr, NULL); } else { #if PY_MAJOR_VERSION >= 3 rval = PyFloat_FromString(numstr); #else rval = PyFloat_FromString(numstr, NULL); #endif } } else { /* no fast path for unicode -> int, just call */ rval = PyObject_CallFunctionObjArgs(s->parse_int, numstr, NULL); } Py_DECREF(numstr); *next_idx_ptr = idx; return rval; } #if PY_MAJOR_VERSION < 3 static PyObject * scan_once_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr) { /* Read one JSON term (of any kind) from PyString pystr. idx is the index of the first character of the term *next_idx_ptr is a return-by-reference index to the first character after the number. Returns a new PyObject representation of the term. */ char *str = PyString_AS_STRING(pystr); Py_ssize_t length = PyString_GET_SIZE(pystr); PyObject *rval = NULL; int fallthrough = 0; if (idx < 0 || idx >= length) { raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx); return NULL; } switch (str[idx]) { case '"': /* string */ rval = scanstring_str(pystr, idx + 1, PyString_AS_STRING(s->encoding), s->strict, next_idx_ptr); break; case '{': /* object */ if (Py_EnterRecursiveCall(" while decoding a JSON object " "from a string")) return NULL; rval = _parse_object_str(s, pystr, idx + 1, next_idx_ptr); Py_LeaveRecursiveCall(); break; case '[': /* array */ if (Py_EnterRecursiveCall(" while decoding a JSON array " "from a string")) return NULL; rval = _parse_array_str(s, pystr, idx + 1, next_idx_ptr); Py_LeaveRecursiveCall(); break; case 'n': /* null */ if ((idx + 3 < length) && str[idx + 1] == 'u' && str[idx + 2] == 'l' && str[idx + 3] == 'l') { Py_INCREF(Py_None); *next_idx_ptr = idx + 4; rval = Py_None; } else fallthrough = 1; break; case 't': /* true */ if ((idx + 3 < length) && str[idx + 1] == 'r' && str[idx + 2] == 'u' && str[idx + 3] == 'e') { Py_INCREF(Py_True); *next_idx_ptr = idx + 4; rval = Py_True; } else fallthrough = 1; break; case 'f': /* false */ if ((idx + 4 < length) && str[idx + 1] == 'a' && str[idx + 2] == 'l' && str[idx + 3] == 's' && str[idx + 4] == 'e') { Py_INCREF(Py_False); *next_idx_ptr = idx + 5; rval = Py_False; } else fallthrough = 1; break; case 'N': /* NaN */ if ((idx + 2 < length) && str[idx + 1] == 'a' && str[idx + 2] == 'N') { rval = _parse_constant(s, JSON_NaN, idx, next_idx_ptr); } else fallthrough = 1; break; case 'I': /* Infinity */ if ((idx + 7 < length) && str[idx + 1] == 'n' && str[idx + 2] == 'f' && str[idx + 3] == 'i' && str[idx + 4] == 'n' && str[idx + 5] == 'i' && str[idx + 6] == 't' && str[idx + 7] == 'y') { rval = _parse_constant(s, JSON_Infinity, idx, next_idx_ptr); } else fallthrough = 1; break; case '-': /* -Infinity */ if ((idx + 8 < length) && str[idx + 1] == 'I' && str[idx + 2] == 'n' && str[idx + 3] == 'f' && str[idx + 4] == 'i' && str[idx + 5] == 'n' && str[idx + 6] == 'i' && str[idx + 7] == 't' && str[idx + 8] == 'y') { rval = _parse_constant(s, JSON_NegInfinity, idx, next_idx_ptr); } else fallthrough = 1; break; default: fallthrough = 1; } /* Didn't find a string, object, array, or named constant. Look for a number. */ if (fallthrough) rval = _match_number_str(s, pystr, idx, next_idx_ptr); return rval; } #endif /* PY_MAJOR_VERSION < 3 */ static PyObject * scan_once_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *next_idx_ptr) { /* Read one JSON term (of any kind) from PyUnicode pystr. idx is the index of the first character of the term *next_idx_ptr is a return-by-reference index to the first character after the number. Returns a new PyObject representation of the term. */ PY2_UNUSED int kind = PyUnicode_KIND(pystr); void *str = PyUnicode_DATA(pystr); Py_ssize_t length = PyUnicode_GET_LENGTH(pystr); PyObject *rval = NULL; int fallthrough = 0; if (idx < 0 || idx >= length) { raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx); return NULL; } switch (PyUnicode_READ(kind, str, idx)) { case '"': /* string */ rval = scanstring_unicode(pystr, idx + 1, s->strict, next_idx_ptr); break; case '{': /* object */ if (Py_EnterRecursiveCall(" while decoding a JSON object " "from a unicode string")) return NULL; rval = _parse_object_unicode(s, pystr, idx + 1, next_idx_ptr); Py_LeaveRecursiveCall(); break; case '[': /* array */ if (Py_EnterRecursiveCall(" while decoding a JSON array " "from a unicode string")) return NULL; rval = _parse_array_unicode(s, pystr, idx + 1, next_idx_ptr); Py_LeaveRecursiveCall(); break; case 'n': /* null */ if ((idx + 3 < length) && PyUnicode_READ(kind, str, idx + 1) == 'u' && PyUnicode_READ(kind, str, idx + 2) == 'l' && PyUnicode_READ(kind, str, idx + 3) == 'l') { Py_INCREF(Py_None); *next_idx_ptr = idx + 4; rval = Py_None; } else fallthrough = 1; break; case 't': /* true */ if ((idx + 3 < length) && PyUnicode_READ(kind, str, idx + 1) == 'r' && PyUnicode_READ(kind, str, idx + 2) == 'u' && PyUnicode_READ(kind, str, idx + 3) == 'e') { Py_INCREF(Py_True); *next_idx_ptr = idx + 4; rval = Py_True; } else fallthrough = 1; break; case 'f': /* false */ if ((idx + 4 < length) && PyUnicode_READ(kind, str, idx + 1) == 'a' && PyUnicode_READ(kind, str, idx + 2) == 'l' && PyUnicode_READ(kind, str, idx + 3) == 's' && PyUnicode_READ(kind, str, idx + 4) == 'e') { Py_INCREF(Py_False); *next_idx_ptr = idx + 5; rval = Py_False; } else fallthrough = 1; break; case 'N': /* NaN */ if ((idx + 2 < length) && PyUnicode_READ(kind, str, idx + 1) == 'a' && PyUnicode_READ(kind, str, idx + 2) == 'N') { rval = _parse_constant(s, JSON_NaN, idx, next_idx_ptr); } else fallthrough = 1; break; case 'I': /* Infinity */ if ((idx + 7 < length) && PyUnicode_READ(kind, str, idx + 1) == 'n' && PyUnicode_READ(kind, str, idx + 2) == 'f' && PyUnicode_READ(kind, str, idx + 3) == 'i' && PyUnicode_READ(kind, str, idx + 4) == 'n' && PyUnicode_READ(kind, str, idx + 5) == 'i' && PyUnicode_READ(kind, str, idx + 6) == 't' && PyUnicode_READ(kind, str, idx + 7) == 'y') { rval = _parse_constant(s, JSON_Infinity, idx, next_idx_ptr); } else fallthrough = 1; break; case '-': /* -Infinity */ if ((idx + 8 < length) && PyUnicode_READ(kind, str, idx + 1) == 'I' && PyUnicode_READ(kind, str, idx + 2) == 'n' && PyUnicode_READ(kind, str, idx + 3) == 'f' && PyUnicode_READ(kind, str, idx + 4) == 'i' && PyUnicode_READ(kind, str, idx + 5) == 'n' && PyUnicode_READ(kind, str, idx + 6) == 'i' && PyUnicode_READ(kind, str, idx + 7) == 't' && PyUnicode_READ(kind, str, idx + 8) == 'y') { rval = _parse_constant(s, JSON_NegInfinity, idx, next_idx_ptr); } else fallthrough = 1; break; default: fallthrough = 1; } /* Didn't find a string, object, array, or named constant. Look for a number. */ if (fallthrough) rval = _match_number_unicode(s, pystr, idx, next_idx_ptr); return rval; } static PyObject * scanner_call(PyObject *self, PyObject *args, PyObject *kwds) { /* Python callable interface to scan_once_{str,unicode} */ PyObject *pystr; PyObject *rval; Py_ssize_t idx; Py_ssize_t next_idx = -1; static char *kwlist[] = {"string", "idx", NULL}; PyScannerObject *s; assert(PyScanner_Check(self)); s = (PyScannerObject *)self; if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO&:scan_once", kwlist, &pystr, _convertPyInt_AsSsize_t, &idx)) return NULL; if (PyUnicode_Check(pystr)) { if (PyUnicode_READY(pystr)) return NULL; rval = scan_once_unicode(s, pystr, idx, &next_idx); } #if PY_MAJOR_VERSION < 3 else if (PyString_Check(pystr)) { rval = scan_once_str(s, pystr, idx, &next_idx); } #endif /* PY_MAJOR_VERSION < 3 */ else { PyErr_Format(PyExc_TypeError, "first argument must be a string, not %.80s", Py_TYPE(pystr)->tp_name); return NULL; } PyDict_Clear(s->memo); return _build_rval_index_tuple(rval, next_idx); } static PyObject * JSON_ParseEncoding(PyObject *encoding) { if (encoding == Py_None) return JSON_InternFromString(DEFAULT_ENCODING); #if PY_MAJOR_VERSION >= 3 if (PyUnicode_Check(encoding)) { if (PyUnicode_AsUTF8(encoding) == NULL) { return NULL; } Py_INCREF(encoding); return encoding; } #else /* PY_MAJOR_VERSION >= 3 */ if (PyString_Check(encoding)) { Py_INCREF(encoding); return encoding; } if (PyUnicode_Check(encoding)) return PyUnicode_AsEncodedString(encoding, NULL, NULL); #endif /* PY_MAJOR_VERSION >= 3 */ PyErr_SetString(PyExc_TypeError, "encoding must be a string"); return NULL; } static PyObject * scanner_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { /* Initialize Scanner object */ PyObject *ctx; static char *kwlist[] = {"context", NULL}; PyScannerObject *s; PyObject *encoding; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:make_scanner", kwlist, &ctx)) return NULL; s = (PyScannerObject *)type->tp_alloc(type, 0); if (s == NULL) return NULL; if (s->memo == NULL) { s->memo = PyDict_New(); if (s->memo == NULL) goto bail; } encoding = PyObject_GetAttrString(ctx, "encoding"); if (encoding == NULL) goto bail; s->encoding = JSON_ParseEncoding(encoding); Py_XDECREF(encoding); if (s->encoding == NULL) goto bail; /* All of these will fail "gracefully" so we don't need to verify them */ s->strict_bool = PyObject_GetAttrString(ctx, "strict"); if (s->strict_bool == NULL) goto bail; s->strict = PyObject_IsTrue(s->strict_bool); if (s->strict < 0) goto bail; s->object_hook = PyObject_GetAttrString(ctx, "object_hook"); if (s->object_hook == NULL) goto bail; s->pairs_hook = PyObject_GetAttrString(ctx, "object_pairs_hook"); if (s->pairs_hook == NULL) goto bail; s->parse_float = PyObject_GetAttrString(ctx, "parse_float"); if (s->parse_float == NULL) goto bail; s->parse_int = PyObject_GetAttrString(ctx, "parse_int"); if (s->parse_int == NULL) goto bail; s->parse_constant = PyObject_GetAttrString(ctx, "parse_constant"); if (s->parse_constant == NULL) goto bail; return (PyObject *)s; bail: Py_DECREF(s); return NULL; } PyDoc_STRVAR(scanner_doc, "JSON scanner object"); static PyTypeObject PyScannerType = { PyVarObject_HEAD_INIT(NULL, 0) "simplejson._speedups.Scanner", /* tp_name */ sizeof(PyScannerObject), /* tp_basicsize */ 0, /* tp_itemsize */ scanner_dealloc, /* tp_dealloc */ 0, /* tp_print */ 0, /* tp_getattr */ 0, /* tp_setattr */ 0, /* tp_compare */ 0, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ scanner_call, /* tp_call */ 0, /* tp_str */ 0,/* PyObject_GenericGetAttr, */ /* tp_getattro */ 0,/* PyObject_GenericSetAttr, */ /* tp_setattro */ 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */ scanner_doc, /* tp_doc */ scanner_traverse, /* tp_traverse */ scanner_clear, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ 0, /* tp_methods */ scanner_members, /* tp_members */ 0, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ 0, /* tp_init */ 0,/* PyType_GenericAlloc, */ /* tp_alloc */ scanner_new, /* tp_new */ 0,/* PyObject_GC_Del, */ /* tp_free */ }; static PyObject * encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { static char *kwlist[] = { "markers", "default", "encoder", "indent", "key_separator", "item_separator", "sort_keys", "skipkeys", "allow_nan", "key_memo", "use_decimal", "namedtuple_as_object", "tuple_as_array", "int_as_string_bitcount", "item_sort_key", "encoding", "for_json", "ignore_nan", "Decimal", "iterable_as_array", NULL}; PyEncoderObject *s; PyObject *markers, *defaultfn, *encoder, *indent, *key_separator; PyObject *item_separator, *sort_keys, *skipkeys, *allow_nan, *key_memo; PyObject *use_decimal, *namedtuple_as_object, *tuple_as_array, *iterable_as_array; PyObject *int_as_string_bitcount, *item_sort_key, *encoding, *for_json; PyObject *ignore_nan, *Decimal; int is_true; if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOOOOOOO:make_encoder", kwlist, &markers, &defaultfn, &encoder, &indent, &key_separator, &item_separator, &sort_keys, &skipkeys, &allow_nan, &key_memo, &use_decimal, &namedtuple_as_object, &tuple_as_array, &int_as_string_bitcount, &item_sort_key, &encoding, &for_json, &ignore_nan, &Decimal, &iterable_as_array)) return NULL; s = (PyEncoderObject *)type->tp_alloc(type, 0); if (s == NULL) return NULL; Py_INCREF(markers); s->markers = markers; Py_INCREF(defaultfn); s->defaultfn = defaultfn; Py_INCREF(encoder); s->encoder = encoder; #if PY_MAJOR_VERSION >= 3 if (encoding == Py_None) { s->encoding = NULL; } else #endif /* PY_MAJOR_VERSION >= 3 */ { s->encoding = JSON_ParseEncoding(encoding); if (s->encoding == NULL) goto bail; } Py_INCREF(indent); s->indent = indent; Py_INCREF(key_separator); s->key_separator = key_separator; Py_INCREF(item_separator); s->item_separator = item_separator; Py_INCREF(skipkeys); s->skipkeys_bool = skipkeys; s->skipkeys = PyObject_IsTrue(skipkeys); if (s->skipkeys < 0) goto bail; Py_INCREF(key_memo); s->key_memo = key_memo; s->fast_encode = (PyCFunction_Check(s->encoder) && PyCFunction_GetFunction(s->encoder) == (PyCFunction)py_encode_basestring_ascii); is_true = PyObject_IsTrue(ignore_nan); if (is_true < 0) goto bail; s->allow_or_ignore_nan = is_true ? JSON_IGNORE_NAN : 0; is_true = PyObject_IsTrue(allow_nan); if (is_true < 0) goto bail; s->allow_or_ignore_nan |= is_true ? JSON_ALLOW_NAN : 0; s->use_decimal = PyObject_IsTrue(use_decimal); if (s->use_decimal < 0) goto bail; s->namedtuple_as_object = PyObject_IsTrue(namedtuple_as_object); if (s->namedtuple_as_object < 0) goto bail; s->tuple_as_array = PyObject_IsTrue(tuple_as_array); if (s->tuple_as_array < 0) goto bail; s->iterable_as_array = PyObject_IsTrue(iterable_as_array); if (s->iterable_as_array < 0) goto bail; if (PyInt_Check(int_as_string_bitcount) || PyLong_Check(int_as_string_bitcount)) { static const unsigned long long_long_bitsize = SIZEOF_LONG_LONG * 8; long int_as_string_bitcount_val = PyLong_AsLong(int_as_string_bitcount); if (int_as_string_bitcount_val > 0 && int_as_string_bitcount_val < (long)long_long_bitsize) { s->max_long_size = PyLong_FromUnsignedLongLong(1ULL << (int)int_as_string_bitcount_val); s->min_long_size = PyLong_FromLongLong(-1LL << (int)int_as_string_bitcount_val); if (s->min_long_size == NULL || s->max_long_size == NULL) { goto bail; } } else { PyErr_Format(PyExc_TypeError, "int_as_string_bitcount (%ld) must be greater than 0 and less than the number of bits of a `long long` type (%lu bits)", int_as_string_bitcount_val, long_long_bitsize); goto bail; } } else if (int_as_string_bitcount == Py_None) { Py_INCREF(Py_None); s->max_long_size = Py_None; Py_INCREF(Py_None); s->min_long_size = Py_None; } else { PyErr_SetString(PyExc_TypeError, "int_as_string_bitcount must be None or an integer"); goto bail; } if (item_sort_key != Py_None) { if (!PyCallable_Check(item_sort_key)) { PyErr_SetString(PyExc_TypeError, "item_sort_key must be None or callable"); goto bail; } } else { is_true = PyObject_IsTrue(sort_keys); if (is_true < 0) goto bail; if (is_true) { static PyObject *itemgetter0 = NULL; if (!itemgetter0) { PyObject *operator = PyImport_ImportModule("operator"); if (!operator) goto bail; itemgetter0 = PyObject_CallMethod(operator, "itemgetter", "i", 0); Py_DECREF(operator); } item_sort_key = itemgetter0; if (!item_sort_key) goto bail; } } if (item_sort_key == Py_None) { Py_INCREF(Py_None); s->item_sort_kw = Py_None; } else { s->item_sort_kw = PyDict_New(); if (s->item_sort_kw == NULL) goto bail; if (PyDict_SetItemString(s->item_sort_kw, "key", item_sort_key)) goto bail; } Py_INCREF(sort_keys); s->sort_keys = sort_keys; Py_INCREF(item_sort_key); s->item_sort_key = item_sort_key; Py_INCREF(Decimal); s->Decimal = Decimal; s->for_json = PyObject_IsTrue(for_json); if (s->for_json < 0) goto bail; return (PyObject *)s; bail: Py_DECREF(s); return NULL; } static PyObject * encoder_call(PyObject *self, PyObject *args, PyObject *kwds) { /* Python callable interface to encode_listencode_obj */ static char *kwlist[] = {"obj", "_current_indent_level", NULL}; PyObject *obj; Py_ssize_t indent_level; PyEncoderObject *s; JSON_Accu rval; assert(PyEncoder_Check(self)); s = (PyEncoderObject *)self; if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO&:_iterencode", kwlist, &obj, _convertPyInt_AsSsize_t, &indent_level)) return NULL; if (JSON_Accu_Init(&rval)) return NULL; if (encoder_listencode_obj(s, &rval, obj, indent_level)) { JSON_Accu_Destroy(&rval); return NULL; } return JSON_Accu_FinishAsList(&rval); } static PyObject * _encoded_const(PyObject *obj) { /* Return the JSON string representation of None, True, False */ if (obj == Py_None) { static PyObject *s_null = NULL; if (s_null == NULL) { s_null = JSON_InternFromString("null"); } Py_INCREF(s_null); return s_null; } else if (obj == Py_True) { static PyObject *s_true = NULL; if (s_true == NULL) { s_true = JSON_InternFromString("true"); } Py_INCREF(s_true); return s_true; } else if (obj == Py_False) { static PyObject *s_false = NULL; if (s_false == NULL) { s_false = JSON_InternFromString("false"); } Py_INCREF(s_false); return s_false; } else { PyErr_SetString(PyExc_ValueError, "not a const"); return NULL; } } static PyObject * encoder_encode_float(PyEncoderObject *s, PyObject *obj) { /* Return the JSON representation of a PyFloat */ double i = PyFloat_AS_DOUBLE(obj); if (!Py_IS_FINITE(i)) { if (!s->allow_or_ignore_nan) { PyErr_SetString(PyExc_ValueError, "Out of range float values are not JSON compliant"); return NULL; } if (s->allow_or_ignore_nan & JSON_IGNORE_NAN) { return _encoded_const(Py_None); } /* JSON_ALLOW_NAN is set */ else if (i > 0) { Py_INCREF(JSON_Infinity); return JSON_Infinity; } else if (i < 0) { Py_INCREF(JSON_NegInfinity); return JSON_NegInfinity; } else { Py_INCREF(JSON_NaN); return JSON_NaN; } } /* Use a better float format here? */ if (PyFloat_CheckExact(obj)) { return PyObject_Repr(obj); } else { /* See #118, do not trust custom str/repr */ PyObject *res; PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyFloat_Type, obj, NULL); if (tmp == NULL) { return NULL; } res = PyObject_Repr(tmp); Py_DECREF(tmp); return res; } } static PyObject * encoder_encode_string(PyEncoderObject *s, PyObject *obj) { /* Return the JSON representation of a string */ PyObject *encoded; if (s->fast_encode) { return py_encode_basestring_ascii(NULL, obj); } encoded = PyObject_CallFunctionObjArgs(s->encoder, obj, NULL); if (encoded != NULL && #if PY_MAJOR_VERSION < 3 !PyString_Check(encoded) && #endif /* PY_MAJOR_VERSION < 3 */ !PyUnicode_Check(encoded)) { PyErr_Format(PyExc_TypeError, "encoder() must return a string, not %.80s", Py_TYPE(encoded)->tp_name); Py_DECREF(encoded); return NULL; } return encoded; } static int _steal_accumulate(JSON_Accu *accu, PyObject *stolen) { /* Append stolen and then decrement its reference count */ int rval = JSON_Accu_Accumulate(accu, stolen); Py_DECREF(stolen); return rval; } static int encoder_listencode_obj(PyEncoderObject *s, JSON_Accu *rval, PyObject *obj, Py_ssize_t indent_level) { /* Encode Python object obj to a JSON term, rval is a PyList */ int rv = -1; do { if (obj == Py_None || obj == Py_True || obj == Py_False) { PyObject *cstr = _encoded_const(obj); if (cstr != NULL) rv = _steal_accumulate(rval, cstr); } else if ((PyBytes_Check(obj) && s->encoding != NULL) || PyUnicode_Check(obj)) { PyObject *encoded = encoder_encode_string(s, obj); if (encoded != NULL) rv = _steal_accumulate(rval, encoded); } else if (PyInt_Check(obj) || PyLong_Check(obj)) { PyObject *encoded; if (PyInt_CheckExact(obj) || PyLong_CheckExact(obj)) { encoded = PyObject_Str(obj); } else { /* See #118, do not trust custom str/repr */ PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyLong_Type, obj, NULL); if (tmp == NULL) { encoded = NULL; } else { encoded = PyObject_Str(tmp); Py_DECREF(tmp); } } if (encoded != NULL) { encoded = maybe_quote_bigint(s, encoded, obj); if (encoded == NULL) break; rv = _steal_accumulate(rval, encoded); } } else if (PyFloat_Check(obj)) { PyObject *encoded = encoder_encode_float(s, obj); if (encoded != NULL) rv = _steal_accumulate(rval, encoded); } else if (s->for_json && _has_for_json_hook(obj)) { PyObject *newobj; if (Py_EnterRecursiveCall(" while encoding a JSON object")) return rv; newobj = PyObject_CallMethod(obj, "for_json", NULL); if (newobj != NULL) { rv = encoder_listencode_obj(s, rval, newobj, indent_level); Py_DECREF(newobj); } Py_LeaveRecursiveCall(); } else if (s->namedtuple_as_object && _is_namedtuple(obj)) { PyObject *newobj; if (Py_EnterRecursiveCall(" while encoding a JSON object")) return rv; newobj = PyObject_CallMethod(obj, "_asdict", NULL); if (newobj != NULL) { rv = encoder_listencode_dict(s, rval, newobj, indent_level); Py_DECREF(newobj); } Py_LeaveRecursiveCall(); } else if (PyList_Check(obj) || (s->tuple_as_array && PyTuple_Check(obj))) { if (Py_EnterRecursiveCall(" while encoding a JSON object")) return rv; rv = encoder_listencode_list(s, rval, obj, indent_level); Py_LeaveRecursiveCall(); } else if (PyDict_Check(obj)) { if (Py_EnterRecursiveCall(" while encoding a JSON object")) return rv; rv = encoder_listencode_dict(s, rval, obj, indent_level); Py_LeaveRecursiveCall(); } else if (s->use_decimal && PyObject_TypeCheck(obj, (PyTypeObject *)s->Decimal)) { PyObject *encoded = PyObject_Str(obj); if (encoded != NULL) rv = _steal_accumulate(rval, encoded); } else if (is_raw_json(obj)) { PyObject *encoded = PyObject_GetAttrString(obj, "encoded_json"); if (encoded != NULL) rv = _steal_accumulate(rval, encoded); } else { PyObject *ident = NULL; PyObject *newobj; if (s->iterable_as_array) { newobj = PyObject_GetIter(obj); if (newobj == NULL) PyErr_Clear(); else { rv = encoder_listencode_list(s, rval, newobj, indent_level); Py_DECREF(newobj); break; } } if (s->markers != Py_None) { int has_key; ident = PyLong_FromVoidPtr(obj); if (ident == NULL) break; has_key = PyDict_Contains(s->markers, ident); if (has_key) { if (has_key != -1) PyErr_SetString(PyExc_ValueError, "Circular reference detected"); Py_DECREF(ident); break; } if (PyDict_SetItem(s->markers, ident, obj)) { Py_DECREF(ident); break; } } if (Py_EnterRecursiveCall(" while encoding a JSON object")) return rv; newobj = PyObject_CallFunctionObjArgs(s->defaultfn, obj, NULL); if (newobj == NULL) { Py_XDECREF(ident); Py_LeaveRecursiveCall(); break; } rv = encoder_listencode_obj(s, rval, newobj, indent_level); Py_LeaveRecursiveCall(); Py_DECREF(newobj); if (rv) { Py_XDECREF(ident); rv = -1; } else if (ident != NULL) { if (PyDict_DelItem(s->markers, ident)) { Py_XDECREF(ident); rv = -1; } Py_XDECREF(ident); } } } while (0); return rv; } static int encoder_listencode_dict(PyEncoderObject *s, JSON_Accu *rval, PyObject *dct, Py_ssize_t indent_level) { /* Encode Python dict dct a JSON term */ static PyObject *open_dict = NULL; static PyObject *close_dict = NULL; static PyObject *empty_dict = NULL; PyObject *kstr = NULL; PyObject *ident = NULL; PyObject *iter = NULL; PyObject *item = NULL; PyObject *items = NULL; PyObject *encoded = NULL; Py_ssize_t idx; if (open_dict == NULL || close_dict == NULL || empty_dict == NULL) { open_dict = JSON_InternFromString("{"); close_dict = JSON_InternFromString("}"); empty_dict = JSON_InternFromString("{}"); if (open_dict == NULL || close_dict == NULL || empty_dict == NULL) return -1; } if (PyDict_Size(dct) == 0) return JSON_Accu_Accumulate(rval, empty_dict); if (s->markers != Py_None) { int has_key; ident = PyLong_FromVoidPtr(dct); if (ident == NULL) goto bail; has_key = PyDict_Contains(s->markers, ident); if (has_key) { if (has_key != -1) PyErr_SetString(PyExc_ValueError, "Circular reference detected"); goto bail; } if (PyDict_SetItem(s->markers, ident, dct)) { goto bail; } } if (JSON_Accu_Accumulate(rval, open_dict)) goto bail; if (s->indent != Py_None) { /* TODO: DOES NOT RUN */ indent_level += 1; /* newline_indent = '\n' + (_indent * _current_indent_level) separator = _item_separator + newline_indent buf += newline_indent */ } iter = encoder_dict_iteritems(s, dct); if (iter == NULL) goto bail; idx = 0; while ((item = PyIter_Next(iter))) { PyObject *encoded, *key, *value; if (!PyTuple_Check(item) || Py_SIZE(item) != 2) { PyErr_SetString(PyExc_ValueError, "items must return 2-tuples"); goto bail; } key = PyTuple_GET_ITEM(item, 0); if (key == NULL) goto bail; value = PyTuple_GET_ITEM(item, 1); if (value == NULL) goto bail; encoded = PyDict_GetItem(s->key_memo, key); if (encoded != NULL) { Py_INCREF(encoded); } else { kstr = encoder_stringify_key(s, key); if (kstr == NULL) goto bail; else if (kstr == Py_None) { /* skipkeys */ Py_DECREF(item); Py_DECREF(kstr); continue; } } if (idx) { if (JSON_Accu_Accumulate(rval, s->item_separator)) goto bail; } if (encoded == NULL) { encoded = encoder_encode_string(s, kstr); Py_CLEAR(kstr); if (encoded == NULL) goto bail; if (PyDict_SetItem(s->key_memo, key, encoded)) goto bail; } if (JSON_Accu_Accumulate(rval, encoded)) { goto bail; } Py_CLEAR(encoded); if (JSON_Accu_Accumulate(rval, s->key_separator)) goto bail; if (encoder_listencode_obj(s, rval, value, indent_level)) goto bail; Py_CLEAR(item); idx += 1; } Py_CLEAR(iter); if (PyErr_Occurred()) goto bail; if (ident != NULL) { if (PyDict_DelItem(s->markers, ident)) goto bail; Py_CLEAR(ident); } if (s->indent != Py_None) { /* TODO: DOES NOT RUN */ indent_level -= 1; /* yield '\n' + (_indent * _current_indent_level) */ } if (JSON_Accu_Accumulate(rval, close_dict)) goto bail; return 0; bail: Py_XDECREF(encoded); Py_XDECREF(items); Py_XDECREF(item); Py_XDECREF(iter); Py_XDECREF(kstr); Py_XDECREF(ident); return -1; } static int encoder_listencode_list(PyEncoderObject *s, JSON_Accu *rval, PyObject *seq, Py_ssize_t indent_level) { /* Encode Python list seq to a JSON term */ static PyObject *open_array = NULL; static PyObject *close_array = NULL; static PyObject *empty_array = NULL; PyObject *ident = NULL; PyObject *iter = NULL; PyObject *obj = NULL; int is_true; int i = 0; if (open_array == NULL || close_array == NULL || empty_array == NULL) { open_array = JSON_InternFromString("["); close_array = JSON_InternFromString("]"); empty_array = JSON_InternFromString("[]"); if (open_array == NULL || close_array == NULL || empty_array == NULL) return -1; } ident = NULL; is_true = PyObject_IsTrue(seq); if (is_true == -1) return -1; else if (is_true == 0) return JSON_Accu_Accumulate(rval, empty_array); if (s->markers != Py_None) { int has_key; ident = PyLong_FromVoidPtr(seq); if (ident == NULL) goto bail; has_key = PyDict_Contains(s->markers, ident); if (has_key) { if (has_key != -1) PyErr_SetString(PyExc_ValueError, "Circular reference detected"); goto bail; } if (PyDict_SetItem(s->markers, ident, seq)) { goto bail; } } iter = PyObject_GetIter(seq); if (iter == NULL) goto bail; if (JSON_Accu_Accumulate(rval, open_array)) goto bail; if (s->indent != Py_None) { /* TODO: DOES NOT RUN */ indent_level += 1; /* newline_indent = '\n' + (_indent * _current_indent_level) separator = _item_separator + newline_indent buf += newline_indent */ } while ((obj = PyIter_Next(iter))) { if (i) { if (JSON_Accu_Accumulate(rval, s->item_separator)) goto bail; } if (encoder_listencode_obj(s, rval, obj, indent_level)) goto bail; i++; Py_CLEAR(obj); } Py_CLEAR(iter); if (PyErr_Occurred()) goto bail; if (ident != NULL) { if (PyDict_DelItem(s->markers, ident)) goto bail; Py_CLEAR(ident); } if (s->indent != Py_None) { /* TODO: DOES NOT RUN */ indent_level -= 1; /* yield '\n' + (_indent * _current_indent_level) */ } if (JSON_Accu_Accumulate(rval, close_array)) goto bail; return 0; bail: Py_XDECREF(obj); Py_XDECREF(iter); Py_XDECREF(ident); return -1; } static void encoder_dealloc(PyObject *self) { /* bpo-31095: UnTrack is needed before calling any callbacks */ PyObject_GC_UnTrack(self); encoder_clear(self); Py_TYPE(self)->tp_free(self); } static int encoder_traverse(PyObject *self, visitproc visit, void *arg) { PyEncoderObject *s; assert(PyEncoder_Check(self)); s = (PyEncoderObject *)self; Py_VISIT(s->markers); Py_VISIT(s->defaultfn); Py_VISIT(s->encoder); Py_VISIT(s->encoding); Py_VISIT(s->indent); Py_VISIT(s->key_separator); Py_VISIT(s->item_separator); Py_VISIT(s->key_memo); Py_VISIT(s->sort_keys); Py_VISIT(s->item_sort_kw); Py_VISIT(s->item_sort_key); Py_VISIT(s->max_long_size); Py_VISIT(s->min_long_size); Py_VISIT(s->Decimal); return 0; } static int encoder_clear(PyObject *self) { /* Deallocate Encoder */ PyEncoderObject *s; assert(PyEncoder_Check(self)); s = (PyEncoderObject *)self; Py_CLEAR(s->markers); Py_CLEAR(s->defaultfn); Py_CLEAR(s->encoder); Py_CLEAR(s->encoding); Py_CLEAR(s->indent); Py_CLEAR(s->key_separator); Py_CLEAR(s->item_separator); Py_CLEAR(s->key_memo); Py_CLEAR(s->skipkeys_bool); Py_CLEAR(s->sort_keys); Py_CLEAR(s->item_sort_kw); Py_CLEAR(s->item_sort_key); Py_CLEAR(s->max_long_size); Py_CLEAR(s->min_long_size); Py_CLEAR(s->Decimal); return 0; } PyDoc_STRVAR(encoder_doc, "_iterencode(obj, _current_indent_level) -> iterable"); static PyTypeObject PyEncoderType = { PyVarObject_HEAD_INIT(NULL, 0) "simplejson._speedups.Encoder", /* tp_name */ sizeof(PyEncoderObject), /* tp_basicsize */ 0, /* tp_itemsize */ encoder_dealloc, /* tp_dealloc */ 0, /* tp_print */ 0, /* tp_getattr */ 0, /* tp_setattr */ 0, /* tp_compare */ 0, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ encoder_call, /* tp_call */ 0, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */ encoder_doc, /* tp_doc */ encoder_traverse, /* tp_traverse */ encoder_clear, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ 0, /* tp_methods */ encoder_members, /* tp_members */ 0, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ 0, /* tp_init */ 0, /* tp_alloc */ encoder_new, /* tp_new */ 0, /* tp_free */ }; static PyMethodDef speedups_methods[] = { {"encode_basestring_ascii", (PyCFunction)py_encode_basestring_ascii, METH_O, pydoc_encode_basestring_ascii}, {"scanstring", (PyCFunction)py_scanstring, METH_VARARGS, pydoc_scanstring}, {NULL, NULL, 0, NULL} }; PyDoc_STRVAR(module_doc, "simplejson speedups\n"); #if PY_MAJOR_VERSION >= 3 static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "_speedups", /* m_name */ module_doc, /* m_doc */ -1, /* m_size */ speedups_methods, /* m_methods */ NULL, /* m_reload */ NULL, /* m_traverse */ NULL, /* m_clear*/ NULL, /* m_free */ }; #endif PyObject * import_dependency(char *module_name, char *attr_name) { PyObject *rval; PyObject *module = PyImport_ImportModule(module_name); if (module == NULL) return NULL; rval = PyObject_GetAttrString(module, attr_name); Py_DECREF(module); return rval; } static int init_constants(void) { JSON_NaN = JSON_InternFromString("NaN"); if (JSON_NaN == NULL) return 0; JSON_Infinity = JSON_InternFromString("Infinity"); if (JSON_Infinity == NULL) return 0; JSON_NegInfinity = JSON_InternFromString("-Infinity"); if (JSON_NegInfinity == NULL) return 0; #if PY_MAJOR_VERSION >= 3 JSON_EmptyUnicode = PyUnicode_New(0, 127); #else /* PY_MAJOR_VERSION >= 3 */ JSON_EmptyStr = PyString_FromString(""); if (JSON_EmptyStr == NULL) return 0; JSON_EmptyUnicode = PyUnicode_FromUnicode(NULL, 0); #endif /* PY_MAJOR_VERSION >= 3 */ if (JSON_EmptyUnicode == NULL) return 0; return 1; } static PyObject * moduleinit(void) { PyObject *m; if (PyType_Ready(&PyScannerType) < 0) return NULL; if (PyType_Ready(&PyEncoderType) < 0) return NULL; if (!init_constants()) return NULL; #if PY_MAJOR_VERSION >= 3 m = PyModule_Create(&moduledef); #else m = Py_InitModule3("_speedups", speedups_methods, module_doc); #endif Py_INCREF((PyObject*)&PyScannerType); PyModule_AddObject(m, "make_scanner", (PyObject*)&PyScannerType); Py_INCREF((PyObject*)&PyEncoderType); PyModule_AddObject(m, "make_encoder", (PyObject*)&PyEncoderType); RawJSONType = import_dependency("simplejson.raw_json", "RawJSON"); if (RawJSONType == NULL) return NULL; JSONDecodeError = import_dependency("simplejson.errors", "JSONDecodeError"); if (JSONDecodeError == NULL) return NULL; return m; } #if PY_MAJOR_VERSION >= 3 PyMODINIT_FUNC PyInit__speedups(void) { return moduleinit(); } #else void init_speedups(void) { moduleinit(); } #endif ================================================ FILE: simplejson/compat.py ================================================ """Python 3 compatibility shims """ import sys if sys.version_info[0] < 3: PY3 = False def b(s): return s try: from cStringIO import StringIO except ImportError: from StringIO import StringIO BytesIO = StringIO text_type = unicode binary_type = str string_types = (basestring,) integer_types = (int, long) unichr = unichr reload_module = reload else: PY3 = True if sys.version_info[:2] >= (3, 4): from importlib import reload as reload_module else: from imp import reload as reload_module def b(s): return bytes(s, 'latin1') from io import StringIO, BytesIO text_type = str binary_type = bytes string_types = (str,) integer_types = (int,) unichr = chr long_type = integer_types[-1] ================================================ FILE: simplejson/decoder.py ================================================ """Implementation of JSONDecoder """ from __future__ import absolute_import import re import sys import struct from .compat import PY3, unichr from .scanner import make_scanner, JSONDecodeError def _import_c_scanstring(): try: from ._speedups import scanstring return scanstring except ImportError: return None c_scanstring = _import_c_scanstring() # NOTE (3.1.0): JSONDecodeError may still be imported from this module for # compatibility, but it was never in the __all__ __all__ = ['JSONDecoder'] FLAGS = re.VERBOSE | re.MULTILINE | re.DOTALL def _floatconstants(): if sys.version_info < (2, 6): _BYTES = '7FF80000000000007FF0000000000000'.decode('hex') nan, inf = struct.unpack('>dd', _BYTES) else: nan = float('nan') inf = float('inf') return nan, inf, -inf NaN, PosInf, NegInf = _floatconstants() _CONSTANTS = { '-Infinity': NegInf, 'Infinity': PosInf, 'NaN': NaN, } STRINGCHUNK = re.compile(r'(.*?)(["\\\x00-\x1f])', FLAGS) BACKSLASH = { '"': u'"', '\\': u'\\', '/': u'/', 'b': u'\b', 'f': u'\f', 'n': u'\n', 'r': u'\r', 't': u'\t', } DEFAULT_ENCODING = "utf-8" def py_scanstring(s, end, encoding=None, strict=True, _b=BACKSLASH, _m=STRINGCHUNK.match, _join=u''.join, _PY3=PY3, _maxunicode=sys.maxunicode): """Scan the string s for a JSON string. End is the index of the character in s after the quote that started the JSON string. Unescapes all valid JSON string escape sequences and raises ValueError on attempt to decode an invalid string. If strict is False then literal control characters are allowed in the string. Returns a tuple of the decoded string and the index of the character in s after the end quote.""" if encoding is None: encoding = DEFAULT_ENCODING chunks = [] _append = chunks.append begin = end - 1 while 1: chunk = _m(s, end) if chunk is None: raise JSONDecodeError( "Unterminated string starting at", s, begin) end = chunk.end() content, terminator = chunk.groups() # Content is contains zero or more unescaped string characters if content: if not _PY3 and not isinstance(content, unicode): content = unicode(content, encoding) _append(content) # Terminator is the end of string, a literal control character, # or a backslash denoting that an escape sequence follows if terminator == '"': break elif terminator != '\\': if strict: msg = "Invalid control character %r at" raise JSONDecodeError(msg, s, end) else: _append(terminator) continue try: esc = s[end] except IndexError: raise JSONDecodeError( "Unterminated string starting at", s, begin) # If not a unicode escape sequence, must be in the lookup table if esc != 'u': try: char = _b[esc] except KeyError: msg = "Invalid \\X escape sequence %r" raise JSONDecodeError(msg, s, end) end += 1 else: # Unicode escape sequence msg = "Invalid \\uXXXX escape sequence" esc = s[end + 1:end + 5] escX = esc[1:2] if len(esc) != 4 or escX == 'x' or escX == 'X': raise JSONDecodeError(msg, s, end - 1) try: uni = int(esc, 16) except ValueError: raise JSONDecodeError(msg, s, end - 1) end += 5 # Check for surrogate pair on UCS-4 systems # Note that this will join high/low surrogate pairs # but will also pass unpaired surrogates through if (_maxunicode > 65535 and uni & 0xfc00 == 0xd800 and s[end:end + 2] == '\\u'): esc2 = s[end + 2:end + 6] escX = esc2[1:2] if len(esc2) == 4 and not (escX == 'x' or escX == 'X'): try: uni2 = int(esc2, 16) except ValueError: raise JSONDecodeError(msg, s, end) if uni2 & 0xfc00 == 0xdc00: uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00)) end += 6 char = unichr(uni) # Append the unescaped character _append(char) return _join(chunks), end # Use speedup if available scanstring = c_scanstring or py_scanstring WHITESPACE = re.compile(r'[ \t\n\r]*', FLAGS) WHITESPACE_STR = ' \t\n\r' def JSONObject(state, encoding, strict, scan_once, object_hook, object_pairs_hook, memo=None, _w=WHITESPACE.match, _ws=WHITESPACE_STR): (s, end) = state # Backwards compatibility if memo is None: memo = {} memo_get = memo.setdefault pairs = [] # Use a slice to prevent IndexError from being raised, the following # check will raise a more specific ValueError if the string is empty nextchar = s[end:end + 1] # Normally we expect nextchar == '"' if nextchar != '"': if nextchar in _ws: end = _w(s, end).end() nextchar = s[end:end + 1] # Trivial empty object if nextchar == '}': if object_pairs_hook is not None: result = object_pairs_hook(pairs) return result, end + 1 pairs = {} if object_hook is not None: pairs = object_hook(pairs) return pairs, end + 1 elif nextchar != '"': raise JSONDecodeError( "Expecting property name enclosed in double quotes", s, end) end += 1 while True: key, end = scanstring(s, end, encoding, strict) key = memo_get(key, key) # To skip some function call overhead we optimize the fast paths where # the JSON key separator is ": " or just ":". if s[end:end + 1] != ':': end = _w(s, end).end() if s[end:end + 1] != ':': raise JSONDecodeError("Expecting ':' delimiter", s, end) end += 1 try: if s[end] in _ws: end += 1 if s[end] in _ws: end = _w(s, end + 1).end() except IndexError: pass value, end = scan_once(s, end) pairs.append((key, value)) try: nextchar = s[end] if nextchar in _ws: end = _w(s, end + 1).end() nextchar = s[end] except IndexError: nextchar = '' end += 1 if nextchar == '}': break elif nextchar != ',': raise JSONDecodeError("Expecting ',' delimiter or '}'", s, end - 1) try: nextchar = s[end] if nextchar in _ws: end += 1 nextchar = s[end] if nextchar in _ws: end = _w(s, end + 1).end() nextchar = s[end] except IndexError: nextchar = '' end += 1 if nextchar != '"': raise JSONDecodeError( "Expecting property name enclosed in double quotes", s, end - 1) if object_pairs_hook is not None: result = object_pairs_hook(pairs) return result, end pairs = dict(pairs) if object_hook is not None: pairs = object_hook(pairs) return pairs, end def JSONArray(state, scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): (s, end) = state values = [] nextchar = s[end:end + 1] if nextchar in _ws: end = _w(s, end + 1).end() nextchar = s[end:end + 1] # Look-ahead for trivial empty array if nextchar == ']': return values, end + 1 elif nextchar == '': raise JSONDecodeError("Expecting value or ']'", s, end) _append = values.append while True: value, end = scan_once(s, end) _append(value) nextchar = s[end:end + 1] if nextchar in _ws: end = _w(s, end + 1).end() nextchar = s[end:end + 1] end += 1 if nextchar == ']': break elif nextchar != ',': raise JSONDecodeError("Expecting ',' delimiter or ']'", s, end - 1) try: if s[end] in _ws: end += 1 if s[end] in _ws: end = _w(s, end + 1).end() except IndexError: pass return values, end class JSONDecoder(object): """Simple JSON decoder Performs the following translations in decoding by default: +---------------+-------------------+ | JSON | Python | +===============+===================+ | object | dict | +---------------+-------------------+ | array | list | +---------------+-------------------+ | string | str, unicode | +---------------+-------------------+ | number (int) | int, long | +---------------+-------------------+ | number (real) | float | +---------------+-------------------+ | true | True | +---------------+-------------------+ | false | False | +---------------+-------------------+ | null | None | +---------------+-------------------+ It also understands ``NaN``, ``Infinity``, and ``-Infinity`` as their corresponding ``float`` values, which is outside the JSON spec. """ def __init__(self, encoding=None, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, strict=True, object_pairs_hook=None): """ *encoding* determines the encoding used to interpret any :class:`str` objects decoded by this instance (``'utf-8'`` by default). It has no effect when decoding :class:`unicode` objects. Note that currently only encodings that are a superset of ASCII work, strings of other encodings should be passed in as :class:`unicode`. *object_hook*, if specified, will be called with the result of every JSON object decoded and its return value will be used in place of the given :class:`dict`. This can be used to provide custom deserializations (e.g. to support JSON-RPC class hinting). *object_pairs_hook* is an optional function that will be called with the result of any object literal decode with an ordered list of pairs. The return value of *object_pairs_hook* will be used instead of the :class:`dict`. This feature can be used to implement custom decoders that rely on the order that the key and value pairs are decoded (for example, :func:`collections.OrderedDict` will remember the order of insertion). If *object_hook* is also defined, the *object_pairs_hook* takes priority. *parse_float*, if specified, will be called with the string of every JSON float to be decoded. By default, this is equivalent to ``float(num_str)``. This can be used to use another datatype or parser for JSON floats (e.g. :class:`decimal.Decimal`). *parse_int*, if specified, will be called with the string of every JSON int to be decoded. By default, this is equivalent to ``int(num_str)``. This can be used to use another datatype or parser for JSON integers (e.g. :class:`float`). *parse_constant*, if specified, will be called with one of the following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This can be used to raise an exception if invalid JSON numbers are encountered. *strict* controls the parser's behavior when it encounters an invalid control character in a string. The default setting of ``True`` means that unescaped control characters are parse errors, if ``False`` then control characters will be allowed in strings. """ if encoding is None: encoding = DEFAULT_ENCODING self.encoding = encoding self.object_hook = object_hook self.object_pairs_hook = object_pairs_hook self.parse_float = parse_float or float self.parse_int = parse_int or int self.parse_constant = parse_constant or _CONSTANTS.__getitem__ self.strict = strict self.parse_object = JSONObject self.parse_array = JSONArray self.parse_string = scanstring self.memo = {} self.scan_once = make_scanner(self) def decode(self, s, _w=WHITESPACE.match, _PY3=PY3): """Return the Python representation of ``s`` (a ``str`` or ``unicode`` instance containing a JSON document) """ if _PY3 and isinstance(s, bytes): s = str(s, self.encoding) obj, end = self.raw_decode(s) end = _w(s, end).end() if end != len(s): raise JSONDecodeError("Extra data", s, end, len(s)) return obj def raw_decode(self, s, idx=0, _w=WHITESPACE.match, _PY3=PY3): """Decode a JSON document from ``s`` (a ``str`` or ``unicode`` beginning with a JSON document) and return a 2-tuple of the Python representation and the index in ``s`` where the document ended. Optionally, ``idx`` can be used to specify an offset in ``s`` where the JSON document begins. This can be used to decode a JSON document from a string that may have extraneous data at the end. """ if idx < 0: # Ensure that raw_decode bails on negative indexes, the regex # would otherwise mask this behavior. #98 raise JSONDecodeError('Expecting value', s, idx) if _PY3 and not isinstance(s, str): raise TypeError("Input string must be text, not bytes") # strip UTF-8 bom if len(s) > idx: ord0 = ord(s[idx]) if ord0 == 0xfeff: idx += 1 elif ord0 == 0xef and s[idx:idx + 3] == '\xef\xbb\xbf': idx += 3 return self.scan_once(s, idx=_w(s, idx).end()) ================================================ FILE: simplejson/encoder.py ================================================ """Implementation of JSONEncoder """ from __future__ import absolute_import import re from operator import itemgetter # Do not import Decimal directly to avoid reload issues import decimal from .compat import unichr, binary_type, text_type, string_types, integer_types, PY3 def _import_speedups(): try: from . import _speedups return _speedups.encode_basestring_ascii, _speedups.make_encoder except ImportError: return None, None c_encode_basestring_ascii, c_make_encoder = _import_speedups() from .decoder import PosInf from .raw_json import RawJSON ESCAPE = re.compile(r'[\x00-\x1f\\"]') ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])') HAS_UTF8 = re.compile(r'[\x80-\xff]') ESCAPE_DCT = { '\\': '\\\\', '"': '\\"', '\b': '\\b', '\f': '\\f', '\n': '\\n', '\r': '\\r', '\t': '\\t', } for i in range(0x20): #ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i)) ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,)) FLOAT_REPR = repr def encode_basestring(s, _PY3=PY3, _q=u'"'): """Return a JSON representation of a Python string """ if _PY3: if isinstance(s, bytes): s = str(s, 'utf-8') elif type(s) is not str: # convert an str subclass instance to exact str # raise a TypeError otherwise s = str.__str__(s) else: if isinstance(s, str) and HAS_UTF8.search(s) is not None: s = unicode(s, 'utf-8') elif type(s) not in (str, unicode): # convert an str subclass instance to exact str # convert a unicode subclass instance to exact unicode # raise a TypeError otherwise if isinstance(s, str): s = str.__str__(s) else: s = unicode.__getnewargs__(s)[0] def replace(match): return ESCAPE_DCT[match.group(0)] return _q + ESCAPE.sub(replace, s) + _q def py_encode_basestring_ascii(s, _PY3=PY3): """Return an ASCII-only JSON representation of a Python string """ if _PY3: if isinstance(s, bytes): s = str(s, 'utf-8') elif type(s) is not str: # convert an str subclass instance to exact str # raise a TypeError otherwise s = str.__str__(s) else: if isinstance(s, str) and HAS_UTF8.search(s) is not None: s = unicode(s, 'utf-8') elif type(s) not in (str, unicode): # convert an str subclass instance to exact str # convert a unicode subclass instance to exact unicode # raise a TypeError otherwise if isinstance(s, str): s = str.__str__(s) else: s = unicode.__getnewargs__(s)[0] def replace(match): s = match.group(0) try: return ESCAPE_DCT[s] except KeyError: n = ord(s) if n < 0x10000: #return '\\u{0:04x}'.format(n) return '\\u%04x' % (n,) else: # surrogate pair n -= 0x10000 s1 = 0xd800 | ((n >> 10) & 0x3ff) s2 = 0xdc00 | (n & 0x3ff) #return '\\u{0:04x}\\u{1:04x}'.format(s1, s2) return '\\u%04x\\u%04x' % (s1, s2) return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"' encode_basestring_ascii = ( c_encode_basestring_ascii or py_encode_basestring_ascii) class JSONEncoder(object): """Extensible JSON encoder for Python data structures. Supports the following objects and types by default: +-------------------+---------------+ | Python | JSON | +===================+===============+ | dict, namedtuple | object | +-------------------+---------------+ | list, tuple | array | +-------------------+---------------+ | str, unicode | string | +-------------------+---------------+ | int, long, float | number | +-------------------+---------------+ | True | true | +-------------------+---------------+ | False | false | +-------------------+---------------+ | None | null | +-------------------+---------------+ To extend this to recognize other objects, subclass and implement a ``.default()`` method with another method that returns a serializable object for ``o`` if possible, otherwise it should call the superclass implementation (to raise ``TypeError``). """ item_separator = ', ' key_separator = ': ' def __init__(self, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, encoding='utf-8', default=None, use_decimal=True, namedtuple_as_object=True, tuple_as_array=True, bigint_as_string=False, item_sort_key=None, for_json=False, ignore_nan=False, int_as_string_bitcount=None, iterable_as_array=False): """Constructor for JSONEncoder, with sensible defaults. If skipkeys is false, then it is a TypeError to attempt encoding of keys that are not str, int, long, float or None. If skipkeys is True, such items are simply skipped. If ensure_ascii is true, the output is guaranteed to be str objects with all incoming unicode characters escaped. If ensure_ascii is false, the output will be unicode object. If check_circular is true, then lists, dicts, and custom encoded objects will be checked for circular references during encoding to prevent an infinite recursion (which would cause an OverflowError). Otherwise, no such check takes place. If allow_nan is true, then NaN, Infinity, and -Infinity will be encoded as such. This behavior is not JSON specification compliant, but is consistent with most JavaScript based encoders and decoders. Otherwise, it will be a ValueError to encode such floats. If sort_keys is true, then the output of dictionaries will be sorted by key; this is useful for regression tests to ensure that JSON serializations can be compared on a day-to-day basis. If indent is a string, then JSON array elements and object members will be pretty-printed with a newline followed by that string repeated for each level of nesting. ``None`` (the default) selects the most compact representation without any newlines. For backwards compatibility with versions of simplejson earlier than 2.1.0, an integer is also accepted and is converted to a string with that many spaces. If specified, separators should be an (item_separator, key_separator) tuple. The default is (', ', ': ') if *indent* is ``None`` and (',', ': ') otherwise. To get the most compact JSON representation, you should specify (',', ':') to eliminate whitespace. If specified, default is a function that gets called for objects that can't otherwise be serialized. It should return a JSON encodable version of the object or raise a ``TypeError``. If encoding is not None, then all input strings will be transformed into unicode using that encoding prior to JSON-encoding. The default is UTF-8. If use_decimal is true (default: ``True``), ``decimal.Decimal`` will be supported directly by the encoder. For the inverse, decode JSON with ``parse_float=decimal.Decimal``. If namedtuple_as_object is true (the default), objects with ``_asdict()`` methods will be encoded as JSON objects. If tuple_as_array is true (the default), tuple (and subclasses) will be encoded as JSON arrays. If *iterable_as_array* is true (default: ``False``), any object not in the above table that implements ``__iter__()`` will be encoded as a JSON array. If bigint_as_string is true (not the default), ints 2**53 and higher or lower than -2**53 will be encoded as strings. This is to avoid the rounding that happens in Javascript otherwise. If int_as_string_bitcount is a positive number (n), then int of size greater than or equal to 2**n or lower than or equal to -2**n will be encoded as strings. If specified, item_sort_key is a callable used to sort the items in each dictionary. This is useful if you want to sort items other than in alphabetical order by key. If for_json is true (not the default), objects with a ``for_json()`` method will use the return value of that method for encoding as JSON instead of the object. If *ignore_nan* is true (default: ``False``), then out of range :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as ``null`` in compliance with the ECMA-262 specification. If true, this will override *allow_nan*. """ self.skipkeys = skipkeys self.ensure_ascii = ensure_ascii self.check_circular = check_circular self.allow_nan = allow_nan self.sort_keys = sort_keys self.use_decimal = use_decimal self.namedtuple_as_object = namedtuple_as_object self.tuple_as_array = tuple_as_array self.iterable_as_array = iterable_as_array self.bigint_as_string = bigint_as_string self.item_sort_key = item_sort_key self.for_json = for_json self.ignore_nan = ignore_nan self.int_as_string_bitcount = int_as_string_bitcount if indent is not None and not isinstance(indent, string_types): indent = indent * ' ' self.indent = indent if separators is not None: self.item_separator, self.key_separator = separators elif indent is not None: self.item_separator = ',' if default is not None: self.default = default self.encoding = encoding def default(self, o): """Implement this method in a subclass such that it returns a serializable object for ``o``, or calls the base implementation (to raise a ``TypeError``). For example, to support arbitrary iterators, you could implement default like this:: def default(self, o): try: iterable = iter(o) except TypeError: pass else: return list(iterable) return JSONEncoder.default(self, o) """ raise TypeError('Object of type %s is not JSON serializable' % o.__class__.__name__) def encode(self, o): """Return a JSON string representation of a Python data structure. >>> from simplejson import JSONEncoder >>> JSONEncoder().encode({"foo": ["bar", "baz"]}) '{"foo": ["bar", "baz"]}' """ # This is for extremely simple cases and benchmarks. if isinstance(o, binary_type): _encoding = self.encoding if (_encoding is not None and not (_encoding == 'utf-8')): o = text_type(o, _encoding) if isinstance(o, string_types): if self.ensure_ascii: return encode_basestring_ascii(o) else: return encode_basestring(o) # This doesn't pass the iterator directly to ''.join() because the # exceptions aren't as detailed. The list call should be roughly # equivalent to the PySequence_Fast that ''.join() would do. chunks = self.iterencode(o, _one_shot=True) if not isinstance(chunks, (list, tuple)): chunks = list(chunks) if self.ensure_ascii: return ''.join(chunks) else: return u''.join(chunks) def iterencode(self, o, _one_shot=False): """Encode the given object and yield each string representation as available. For example:: for chunk in JSONEncoder().iterencode(bigobject): mysocket.write(chunk) """ if self.check_circular: markers = {} else: markers = None if self.ensure_ascii: _encoder = encode_basestring_ascii else: _encoder = encode_basestring if self.encoding != 'utf-8' and self.encoding is not None: def _encoder(o, _orig_encoder=_encoder, _encoding=self.encoding): if isinstance(o, binary_type): o = text_type(o, _encoding) return _orig_encoder(o) def floatstr(o, allow_nan=self.allow_nan, ignore_nan=self.ignore_nan, _repr=FLOAT_REPR, _inf=PosInf, _neginf=-PosInf): # Check for specials. Note that this type of test is processor # and/or platform-specific, so do tests which don't depend on # the internals. if o != o: text = 'NaN' elif o == _inf: text = 'Infinity' elif o == _neginf: text = '-Infinity' else: if type(o) != float: # See #118, do not trust custom str/repr o = float(o) return _repr(o) if ignore_nan: text = 'null' elif not allow_nan: raise ValueError( "Out of range float values are not JSON compliant: " + repr(o)) return text key_memo = {} int_as_string_bitcount = ( 53 if self.bigint_as_string else self.int_as_string_bitcount) if (_one_shot and c_make_encoder is not None and self.indent is None): _iterencode = c_make_encoder( markers, self.default, _encoder, self.indent, self.key_separator, self.item_separator, self.sort_keys, self.skipkeys, self.allow_nan, key_memo, self.use_decimal, self.namedtuple_as_object, self.tuple_as_array, int_as_string_bitcount, self.item_sort_key, self.encoding, self.for_json, self.ignore_nan, decimal.Decimal, self.iterable_as_array) else: _iterencode = _make_iterencode( markers, self.default, _encoder, self.indent, floatstr, self.key_separator, self.item_separator, self.sort_keys, self.skipkeys, _one_shot, self.use_decimal, self.namedtuple_as_object, self.tuple_as_array, int_as_string_bitcount, self.item_sort_key, self.encoding, self.for_json, self.iterable_as_array, Decimal=decimal.Decimal) try: return _iterencode(o, 0) finally: key_memo.clear() class JSONEncoderForHTML(JSONEncoder): """An encoder that produces JSON safe to embed in HTML. To embed JSON content in, say, a script tag on a web page, the characters &, < and > should be escaped. They cannot be escaped with the usual entities (e.g. &) because they are not expanded within ' self.assertEqual( r'"\u003c/script\u003e\u003cscript\u003e' r'alert(\"gotcha\")\u003c/script\u003e"', self.encoder.encode(bad_string)) self.assertEqual( bad_string, self.decoder.decode( self.encoder.encode(bad_string))) ================================================ FILE: simplejson/tests/test_errors.py ================================================ import sys, pickle from unittest import TestCase import simplejson as json from simplejson.compat import text_type, b class TestErrors(TestCase): def test_string_keys_error(self): data = [{'a': 'A', 'b': (2, 4), 'c': 3.0, ('d',): 'D tuple'}] try: json.dumps(data) except TypeError: err = sys.exc_info()[1] else: self.fail('Expected TypeError') self.assertEqual(str(err), 'keys must be str, int, float, bool or None, not tuple') def test_not_serializable(self): try: json.dumps(json) except TypeError: err = sys.exc_info()[1] else: self.fail('Expected TypeError') self.assertEqual(str(err), 'Object of type module is not JSON serializable') def test_decode_error(self): err = None try: json.loads('{}\na\nb') except json.JSONDecodeError: err = sys.exc_info()[1] else: self.fail('Expected JSONDecodeError') self.assertEqual(err.lineno, 2) self.assertEqual(err.colno, 1) self.assertEqual(err.endlineno, 3) self.assertEqual(err.endcolno, 2) def test_scan_error(self): err = None for t in (text_type, b): try: json.loads(t('{"asdf": "')) except json.JSONDecodeError: err = sys.exc_info()[1] else: self.fail('Expected JSONDecodeError') self.assertEqual(err.lineno, 1) self.assertEqual(err.colno, 10) def test_error_is_pickable(self): err = None try: json.loads('{}\na\nb') except json.JSONDecodeError: err = sys.exc_info()[1] else: self.fail('Expected JSONDecodeError') s = pickle.dumps(err) e = pickle.loads(s) self.assertEqual(err.msg, e.msg) self.assertEqual(err.doc, e.doc) self.assertEqual(err.pos, e.pos) self.assertEqual(err.end, e.end) ================================================ FILE: simplejson/tests/test_fail.py ================================================ import sys from unittest import TestCase import simplejson as json # 2007-10-05 JSONDOCS = [ # http://json.org/JSON_checker/test/fail1.json '"A JSON payload should be an object or array, not a string."', # http://json.org/JSON_checker/test/fail2.json '["Unclosed array"', # http://json.org/JSON_checker/test/fail3.json '{unquoted_key: "keys must be quoted"}', # http://json.org/JSON_checker/test/fail4.json '["extra comma",]', # http://json.org/JSON_checker/test/fail5.json '["double extra comma",,]', # http://json.org/JSON_checker/test/fail6.json '[ , "<-- missing value"]', # http://json.org/JSON_checker/test/fail7.json '["Comma after the close"],', # http://json.org/JSON_checker/test/fail8.json '["Extra close"]]', # http://json.org/JSON_checker/test/fail9.json '{"Extra comma": true,}', # http://json.org/JSON_checker/test/fail10.json '{"Extra value after close": true} "misplaced quoted value"', # http://json.org/JSON_checker/test/fail11.json '{"Illegal expression": 1 + 2}', # http://json.org/JSON_checker/test/fail12.json '{"Illegal invocation": alert()}', # http://json.org/JSON_checker/test/fail13.json '{"Numbers cannot have leading zeroes": 013}', # http://json.org/JSON_checker/test/fail14.json '{"Numbers cannot be hex": 0x14}', # http://json.org/JSON_checker/test/fail15.json '["Illegal backslash escape: \\x15"]', # http://json.org/JSON_checker/test/fail16.json '[\\naked]', # http://json.org/JSON_checker/test/fail17.json '["Illegal backslash escape: \\017"]', # http://json.org/JSON_checker/test/fail18.json '[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', # http://json.org/JSON_checker/test/fail19.json '{"Missing colon" null}', # http://json.org/JSON_checker/test/fail20.json '{"Double colon":: null}', # http://json.org/JSON_checker/test/fail21.json '{"Comma instead of colon", null}', # http://json.org/JSON_checker/test/fail22.json '["Colon instead of comma": false]', # http://json.org/JSON_checker/test/fail23.json '["Bad value", truth]', # http://json.org/JSON_checker/test/fail24.json "['single quote']", # http://json.org/JSON_checker/test/fail25.json '["\ttab\tcharacter\tin\tstring\t"]', # http://json.org/JSON_checker/test/fail26.json '["tab\\ character\\ in\\ string\\ "]', # http://json.org/JSON_checker/test/fail27.json '["line\nbreak"]', # http://json.org/JSON_checker/test/fail28.json '["line\\\nbreak"]', # http://json.org/JSON_checker/test/fail29.json '[0e]', # http://json.org/JSON_checker/test/fail30.json '[0e+]', # http://json.org/JSON_checker/test/fail31.json '[0e+-1]', # http://json.org/JSON_checker/test/fail32.json '{"Comma instead if closing brace": true,', # http://json.org/JSON_checker/test/fail33.json '["mismatch"}', # http://code.google.com/p/simplejson/issues/detail?id=3 u'["A\u001FZ control characters in string"]', # misc based on coverage '{', '{]', '{"foo": "bar"]', '{"foo": "bar"', 'nul', 'nulx', '-', '-x', '-e', '-e0', '-Infinite', '-Inf', 'Infinit', 'Infinite', 'NaM', 'NuN', 'falsy', 'fal', 'trug', 'tru', '1e', '1ex', '1e-', '1e-x', ] SKIPS = { 1: "why not have a string payload?", 18: "spec doesn't specify any nesting limitations", } class TestFail(TestCase): def test_failures(self): for idx, doc in enumerate(JSONDOCS): idx = idx + 1 if idx in SKIPS: json.loads(doc) continue try: json.loads(doc) except json.JSONDecodeError: pass else: self.fail("Expected failure for fail%d.json: %r" % (idx, doc)) def test_array_decoder_issue46(self): # http://code.google.com/p/simplejson/issues/detail?id=46 for doc in [u'[,]', '[,]']: try: json.loads(doc) except json.JSONDecodeError: e = sys.exc_info()[1] self.assertEqual(e.pos, 1) self.assertEqual(e.lineno, 1) self.assertEqual(e.colno, 2) except Exception: e = sys.exc_info()[1] self.fail("Unexpected exception raised %r %s" % (e, e)) else: self.fail("Unexpected success parsing '[,]'") def test_truncated_input(self): test_cases = [ ('', 'Expecting value', 0), ('[', "Expecting value or ']'", 1), ('[42', "Expecting ',' delimiter", 3), ('[42,', 'Expecting value', 4), ('["', 'Unterminated string starting at', 1), ('["spam', 'Unterminated string starting at', 1), ('["spam"', "Expecting ',' delimiter", 7), ('["spam",', 'Expecting value', 8), ('{', 'Expecting property name enclosed in double quotes', 1), ('{"', 'Unterminated string starting at', 1), ('{"spam', 'Unterminated string starting at', 1), ('{"spam"', "Expecting ':' delimiter", 7), ('{"spam":', 'Expecting value', 8), ('{"spam":42', "Expecting ',' delimiter", 10), ('{"spam":42,', 'Expecting property name enclosed in double quotes', 11), ('"', 'Unterminated string starting at', 0), ('"spam', 'Unterminated string starting at', 0), ('[,', "Expecting value", 1), ] for data, msg, idx in test_cases: try: json.loads(data) except json.JSONDecodeError: e = sys.exc_info()[1] self.assertEqual( e.msg[:len(msg)], msg, "%r doesn't start with %r for %r" % (e.msg, msg, data)) self.assertEqual( e.pos, idx, "pos %r != %r for %r" % (e.pos, idx, data)) except Exception: e = sys.exc_info()[1] self.fail("Unexpected exception raised %r %s" % (e, e)) else: self.fail("Unexpected success parsing '%r'" % (data,)) ================================================ FILE: simplejson/tests/test_float.py ================================================ import math from unittest import TestCase from simplejson.compat import long_type, text_type import simplejson as json from simplejson.decoder import NaN, PosInf, NegInf class TestFloat(TestCase): def test_degenerates_allow(self): for inf in (PosInf, NegInf): self.assertEqual(json.loads(json.dumps(inf)), inf) # Python 2.5 doesn't have math.isnan nan = json.loads(json.dumps(NaN)) self.assertTrue((0 + nan) != nan) def test_degenerates_ignore(self): for f in (PosInf, NegInf, NaN): self.assertEqual(json.loads(json.dumps(f, ignore_nan=True)), None) def test_degenerates_deny(self): for f in (PosInf, NegInf, NaN): self.assertRaises(ValueError, json.dumps, f, allow_nan=False) def test_floats(self): for num in [1617161771.7650001, math.pi, math.pi**100, math.pi**-100, 3.1]: self.assertEqual(float(json.dumps(num)), num) self.assertEqual(json.loads(json.dumps(num)), num) self.assertEqual(json.loads(text_type(json.dumps(num))), num) def test_ints(self): for num in [1, long_type(1), 1<<32, 1<<64]: self.assertEqual(json.dumps(num), str(num)) self.assertEqual(int(json.dumps(num)), num) self.assertEqual(json.loads(json.dumps(num)), num) self.assertEqual(json.loads(text_type(json.dumps(num))), num) ================================================ FILE: simplejson/tests/test_for_json.py ================================================ import unittest import simplejson as json class ForJson(object): def for_json(self): return {'for_json': 1} class NestedForJson(object): def for_json(self): return {'nested': ForJson()} class ForJsonList(object): def for_json(self): return ['list'] class DictForJson(dict): def for_json(self): return {'alpha': 1} class ListForJson(list): def for_json(self): return ['list'] class TestForJson(unittest.TestCase): def assertRoundTrip(self, obj, other, for_json=True): if for_json is None: # None will use the default s = json.dumps(obj) else: s = json.dumps(obj, for_json=for_json) self.assertEqual( json.loads(s), other) def test_for_json_encodes_stand_alone_object(self): self.assertRoundTrip( ForJson(), ForJson().for_json()) def test_for_json_encodes_object_nested_in_dict(self): self.assertRoundTrip( {'hooray': ForJson()}, {'hooray': ForJson().for_json()}) def test_for_json_encodes_object_nested_in_list_within_dict(self): self.assertRoundTrip( {'list': [0, ForJson(), 2, 3]}, {'list': [0, ForJson().for_json(), 2, 3]}) def test_for_json_encodes_object_nested_within_object(self): self.assertRoundTrip( NestedForJson(), {'nested': {'for_json': 1}}) def test_for_json_encodes_list(self): self.assertRoundTrip( ForJsonList(), ForJsonList().for_json()) def test_for_json_encodes_list_within_object(self): self.assertRoundTrip( {'nested': ForJsonList()}, {'nested': ForJsonList().for_json()}) def test_for_json_encodes_dict_subclass(self): self.assertRoundTrip( DictForJson(a=1), DictForJson(a=1).for_json()) def test_for_json_encodes_list_subclass(self): self.assertRoundTrip( ListForJson(['l']), ListForJson(['l']).for_json()) def test_for_json_ignored_if_not_true_with_dict_subclass(self): for for_json in (None, False): self.assertRoundTrip( DictForJson(a=1), {'a': 1}, for_json=for_json) def test_for_json_ignored_if_not_true_with_list_subclass(self): for for_json in (None, False): self.assertRoundTrip( ListForJson(['l']), ['l'], for_json=for_json) def test_raises_typeerror_if_for_json_not_true_with_object(self): self.assertRaises(TypeError, json.dumps, ForJson()) self.assertRaises(TypeError, json.dumps, ForJson(), for_json=False) ================================================ FILE: simplejson/tests/test_indent.py ================================================ from unittest import TestCase import textwrap import simplejson as json from simplejson.compat import StringIO class TestIndent(TestCase): def test_indent(self): h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 'i-vhbjkhnth', {'nifty': 87}, {'field': 'yes', 'morefield': False} ] expect = textwrap.dedent("""\ [ \t[ \t\t"blorpie" \t], \t[ \t\t"whoops" \t], \t[], \t"d-shtaeou", \t"d-nthiouh", \t"i-vhbjkhnth", \t{ \t\t"nifty": 87 \t}, \t{ \t\t"field": "yes", \t\t"morefield": false \t} ]""") d1 = json.dumps(h) d2 = json.dumps(h, indent='\t', sort_keys=True, separators=(',', ': ')) d3 = json.dumps(h, indent=' ', sort_keys=True, separators=(',', ': ')) d4 = json.dumps(h, indent=2, sort_keys=True, separators=(',', ': ')) h1 = json.loads(d1) h2 = json.loads(d2) h3 = json.loads(d3) h4 = json.loads(d4) self.assertEqual(h1, h) self.assertEqual(h2, h) self.assertEqual(h3, h) self.assertEqual(h4, h) self.assertEqual(d3, expect.replace('\t', ' ')) self.assertEqual(d4, expect.replace('\t', ' ')) # NOTE: Python 2.4 textwrap.dedent converts tabs to spaces, # so the following is expected to fail. Python 2.4 is not a # supported platform in simplejson 2.1.0+. self.assertEqual(d2, expect) def test_indent0(self): h = {3: 1} def check(indent, expected): d1 = json.dumps(h, indent=indent) self.assertEqual(d1, expected) sio = StringIO() json.dump(h, sio, indent=indent) self.assertEqual(sio.getvalue(), expected) # indent=0 should emit newlines check(0, '{\n"3": 1\n}') # indent=None is more compact check(None, '{"3": 1}') def test_separators(self): lst = [1,2,3,4] expect = '[\n1,\n2,\n3,\n4\n]' expect_spaces = '[\n1, \n2, \n3, \n4\n]' # Ensure that separators still works self.assertEqual( expect_spaces, json.dumps(lst, indent=0, separators=(', ', ': '))) # Force the new defaults self.assertEqual( expect, json.dumps(lst, indent=0, separators=(',', ': '))) # Added in 2.1.4 self.assertEqual( expect, json.dumps(lst, indent=0)) ================================================ FILE: simplejson/tests/test_item_sort_key.py ================================================ from unittest import TestCase import simplejson as json from operator import itemgetter class TestItemSortKey(TestCase): def test_simple_first(self): a = {'a': 1, 'c': 5, 'jack': 'jill', 'pick': 'axe', 'array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} self.assertEqual( '{"a": 1, "c": 5, "crate": "dog", "jack": "jill", "pick": "axe", "zeak": "oh", "array": [1, 5, 6, 9], "tuple": [83, 12, 3]}', json.dumps(a, item_sort_key=json.simple_first)) def test_case(self): a = {'a': 1, 'c': 5, 'Jack': 'jill', 'pick': 'axe', 'Array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} self.assertEqual( '{"Array": [1, 5, 6, 9], "Jack": "jill", "a": 1, "c": 5, "crate": "dog", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', json.dumps(a, item_sort_key=itemgetter(0))) self.assertEqual( '{"a": 1, "Array": [1, 5, 6, 9], "c": 5, "crate": "dog", "Jack": "jill", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', json.dumps(a, item_sort_key=lambda kv: kv[0].lower())) def test_item_sort_key_value(self): # https://github.com/simplejson/simplejson/issues/173 a = {'a': 1, 'b': 0} self.assertEqual( '{"b": 0, "a": 1}', json.dumps(a, item_sort_key=lambda kv: kv[1])) ================================================ FILE: simplejson/tests/test_iterable.py ================================================ import unittest from simplejson.compat import StringIO import simplejson as json def iter_dumps(obj, **kw): return ''.join(json.JSONEncoder(**kw).iterencode(obj)) def sio_dump(obj, **kw): sio = StringIO() json.dumps(obj, **kw) return sio.getvalue() class TestIterable(unittest.TestCase): def test_iterable(self): for l in ([], [1], [1, 2], [1, 2, 3]): for opts in [{}, {'indent': 2}]: for dumps in (json.dumps, iter_dumps, sio_dump): expect = dumps(l, **opts) default_expect = dumps(sum(l), **opts) # Default is False self.assertRaises(TypeError, dumps, iter(l), **opts) self.assertRaises(TypeError, dumps, iter(l), iterable_as_array=False, **opts) self.assertEqual(expect, dumps(iter(l), iterable_as_array=True, **opts)) # Ensure that the "default" gets called self.assertEqual(default_expect, dumps(iter(l), default=sum, **opts)) self.assertEqual(default_expect, dumps(iter(l), iterable_as_array=False, default=sum, **opts)) # Ensure that the "default" does not get called self.assertEqual( expect, dumps(iter(l), iterable_as_array=True, default=sum, **opts)) ================================================ FILE: simplejson/tests/test_namedtuple.py ================================================ from __future__ import absolute_import import unittest import simplejson as json from simplejson.compat import StringIO try: from collections import namedtuple except ImportError: class Value(tuple): def __new__(cls, *args): return tuple.__new__(cls, args) def _asdict(self): return {'value': self[0]} class Point(tuple): def __new__(cls, *args): return tuple.__new__(cls, args) def _asdict(self): return {'x': self[0], 'y': self[1]} else: Value = namedtuple('Value', ['value']) Point = namedtuple('Point', ['x', 'y']) class DuckValue(object): def __init__(self, *args): self.value = Value(*args) def _asdict(self): return self.value._asdict() class DuckPoint(object): def __init__(self, *args): self.point = Point(*args) def _asdict(self): return self.point._asdict() class DeadDuck(object): _asdict = None class DeadDict(dict): _asdict = None CONSTRUCTORS = [ lambda v: v, lambda v: [v], lambda v: [{'key': v}], ] class TestNamedTuple(unittest.TestCase): def test_namedtuple_dumps(self): for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: d = v._asdict() self.assertEqual(d, json.loads(json.dumps(v))) self.assertEqual( d, json.loads(json.dumps(v, namedtuple_as_object=True))) self.assertEqual(d, json.loads(json.dumps(v, tuple_as_array=False))) self.assertEqual( d, json.loads(json.dumps(v, namedtuple_as_object=True, tuple_as_array=False))) def test_namedtuple_dumps_false(self): for v in [Value(1), Point(1, 2)]: l = list(v) self.assertEqual( l, json.loads(json.dumps(v, namedtuple_as_object=False))) self.assertRaises(TypeError, json.dumps, v, tuple_as_array=False, namedtuple_as_object=False) def test_namedtuple_dump(self): for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: d = v._asdict() sio = StringIO() json.dump(v, sio) self.assertEqual(d, json.loads(sio.getvalue())) sio = StringIO() json.dump(v, sio, namedtuple_as_object=True) self.assertEqual( d, json.loads(sio.getvalue())) sio = StringIO() json.dump(v, sio, tuple_as_array=False) self.assertEqual(d, json.loads(sio.getvalue())) sio = StringIO() json.dump(v, sio, namedtuple_as_object=True, tuple_as_array=False) self.assertEqual( d, json.loads(sio.getvalue())) def test_namedtuple_dump_false(self): for v in [Value(1), Point(1, 2)]: l = list(v) sio = StringIO() json.dump(v, sio, namedtuple_as_object=False) self.assertEqual( l, json.loads(sio.getvalue())) self.assertRaises(TypeError, json.dump, v, StringIO(), tuple_as_array=False, namedtuple_as_object=False) def test_asdict_not_callable_dump(self): for f in CONSTRUCTORS: self.assertRaises(TypeError, json.dump, f(DeadDuck()), StringIO(), namedtuple_as_object=True) sio = StringIO() json.dump(f(DeadDict()), sio, namedtuple_as_object=True) self.assertEqual( json.dumps(f({})), sio.getvalue()) def test_asdict_not_callable_dumps(self): for f in CONSTRUCTORS: self.assertRaises(TypeError, json.dumps, f(DeadDuck()), namedtuple_as_object=True) self.assertEqual( json.dumps(f({})), json.dumps(f(DeadDict()), namedtuple_as_object=True)) ================================================ FILE: simplejson/tests/test_pass1.py ================================================ from unittest import TestCase import simplejson as json # from http://json.org/JSON_checker/test/pass1.json JSON = r''' [ "JSON Test Pattern pass1", {"object with 1 member":["array with 1 element"]}, {}, [], -42, true, false, null, { "integer": 1234567890, "real": -9876.543210, "e": 0.123456789e-12, "E": 1.234567890E+34, "": 23456789012E66, "zero": 0, "one": 1, "space": " ", "quote": "\"", "backslash": "\\", "controls": "\b\f\n\r\t", "slash": "/ & \/", "alpha": "abcdefghijklmnopqrstuvwyz", "ALPHA": "ABCDEFGHIJKLMNOPQRSTUVWYZ", "digit": "0123456789", "special": "`1~!@#$%^&*()_+-={':[,]}|;.?", "hex": "\u0123\u4567\u89AB\uCDEF\uabcd\uef4A", "true": true, "false": false, "null": null, "array":[ ], "object":{ }, "address": "50 St. James Street", "url": "http://www.JSON.org/", "comment": "// /* */": " ", " s p a c e d " :[1,2 , 3 , 4 , 5 , 6 ,7 ],"compact": [1,2,3,4,5,6,7], "jsontext": "{\"object with 1 member\":[\"array with 1 element\"]}", "quotes": "" \u0022 %22 0x22 034 "", "\/\\\"\uCAFE\uBABE\uAB98\uFCDE\ubcda\uef4A\b\f\n\r\t`1~!@#$%^&*()_+-=[]{}|;:',./<>?" : "A key can be any string" }, 0.5 ,98.6 , 99.44 , 1066, 1e1, 0.1e1, 1e-1, 1e00,2e+00,2e-00 ,"rosebud"] ''' class TestPass1(TestCase): def test_parse(self): # test in/out equivalence and parsing res = json.loads(JSON) out = json.dumps(res) self.assertEqual(res, json.loads(out)) ================================================ FILE: simplejson/tests/test_pass2.py ================================================ from unittest import TestCase import simplejson as json # from http://json.org/JSON_checker/test/pass2.json JSON = r''' [[[[[[[[[[[[[[[[[[["Not too deep"]]]]]]]]]]]]]]]]]]] ''' class TestPass2(TestCase): def test_parse(self): # test in/out equivalence and parsing res = json.loads(JSON) out = json.dumps(res) self.assertEqual(res, json.loads(out)) ================================================ FILE: simplejson/tests/test_pass3.py ================================================ from unittest import TestCase import simplejson as json # from http://json.org/JSON_checker/test/pass3.json JSON = r''' { "JSON Test Pattern pass3": { "The outermost value": "must be an object or array.", "In this test": "It is an object." } } ''' class TestPass3(TestCase): def test_parse(self): # test in/out equivalence and parsing res = json.loads(JSON) out = json.dumps(res) self.assertEqual(res, json.loads(out)) ================================================ FILE: simplejson/tests/test_raw_json.py ================================================ import unittest import simplejson as json dct1 = { 'key1': 'value1' } dct2 = { 'key2': 'value2', 'd1': dct1 } dct3 = { 'key2': 'value2', 'd1': json.dumps(dct1) } dct4 = { 'key2': 'value2', 'd1': json.RawJSON(json.dumps(dct1)) } class TestRawJson(unittest.TestCase): def test_normal_str(self): self.assertNotEqual(json.dumps(dct2), json.dumps(dct3)) def test_raw_json_str(self): self.assertEqual(json.dumps(dct2), json.dumps(dct4)) self.assertEqual(dct2, json.loads(json.dumps(dct4))) def test_list(self): self.assertEqual( json.dumps([dct2]), json.dumps([json.RawJSON(json.dumps(dct2))])) self.assertEqual( [dct2], json.loads(json.dumps([json.RawJSON(json.dumps(dct2))]))) def test_direct(self): self.assertEqual( json.dumps(dct2), json.dumps(json.RawJSON(json.dumps(dct2)))) self.assertEqual( dct2, json.loads(json.dumps(json.RawJSON(json.dumps(dct2))))) ================================================ FILE: simplejson/tests/test_recursion.py ================================================ from unittest import TestCase import simplejson as json class JSONTestObject: pass class RecursiveJSONEncoder(json.JSONEncoder): recurse = False def default(self, o): if o is JSONTestObject: if self.recurse: return [JSONTestObject] else: return 'JSONTestObject' return json.JSONEncoder.default(o) class TestRecursion(TestCase): def test_listrecursion(self): x = [] x.append(x) try: json.dumps(x) except ValueError: pass else: self.fail("didn't raise ValueError on list recursion") x = [] y = [x] x.append(y) try: json.dumps(x) except ValueError: pass else: self.fail("didn't raise ValueError on alternating list recursion") y = [] x = [y, y] # ensure that the marker is cleared json.dumps(x) def test_dictrecursion(self): x = {} x["test"] = x try: json.dumps(x) except ValueError: pass else: self.fail("didn't raise ValueError on dict recursion") x = {} y = {"a": x, "b": x} # ensure that the marker is cleared json.dumps(y) def test_defaultrecursion(self): enc = RecursiveJSONEncoder() self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"') enc.recurse = True try: enc.encode(JSONTestObject) except ValueError: pass else: self.fail("didn't raise ValueError on default recursion") ================================================ FILE: simplejson/tests/test_scanstring.py ================================================ import sys from unittest import TestCase import simplejson as json import simplejson.decoder from simplejson.compat import b, PY3 class TestScanString(TestCase): # The bytes type is intentionally not used in most of these tests # under Python 3 because the decoder immediately coerces to str before # calling scanstring. In Python 2 we are testing the code paths # for both unicode and str. # # The reason this is done is because Python 3 would require # entirely different code paths for parsing bytes and str. # def test_py_scanstring(self): self._test_scanstring(simplejson.decoder.py_scanstring) def test_c_scanstring(self): if not simplejson.decoder.c_scanstring: return self._test_scanstring(simplejson.decoder.c_scanstring) self.assertTrue(isinstance(simplejson.decoder.c_scanstring('""', 0)[0], str)) def _test_scanstring(self, scanstring): if sys.maxunicode == 65535: self.assertEqual( scanstring(u'"z\U0001d120x"', 1, None, True), (u'z\U0001d120x', 6)) else: self.assertEqual( scanstring(u'"z\U0001d120x"', 1, None, True), (u'z\U0001d120x', 5)) self.assertEqual( scanstring('"\\u007b"', 1, None, True), (u'{', 8)) self.assertEqual( scanstring('"A JSON payload should be an object or array, not a string."', 1, None, True), (u'A JSON payload should be an object or array, not a string.', 60)) self.assertEqual( scanstring('["Unclosed array"', 2, None, True), (u'Unclosed array', 17)) self.assertEqual( scanstring('["extra comma",]', 2, None, True), (u'extra comma', 14)) self.assertEqual( scanstring('["double extra comma",,]', 2, None, True), (u'double extra comma', 21)) self.assertEqual( scanstring('["Comma after the close"],', 2, None, True), (u'Comma after the close', 24)) self.assertEqual( scanstring('["Extra close"]]', 2, None, True), (u'Extra close', 14)) self.assertEqual( scanstring('{"Extra comma": true,}', 2, None, True), (u'Extra comma', 14)) self.assertEqual( scanstring('{"Extra value after close": true} "misplaced quoted value"', 2, None, True), (u'Extra value after close', 26)) self.assertEqual( scanstring('{"Illegal expression": 1 + 2}', 2, None, True), (u'Illegal expression', 21)) self.assertEqual( scanstring('{"Illegal invocation": alert()}', 2, None, True), (u'Illegal invocation', 21)) self.assertEqual( scanstring('{"Numbers cannot have leading zeroes": 013}', 2, None, True), (u'Numbers cannot have leading zeroes', 37)) self.assertEqual( scanstring('{"Numbers cannot be hex": 0x14}', 2, None, True), (u'Numbers cannot be hex', 24)) self.assertEqual( scanstring('[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', 21, None, True), (u'Too deep', 30)) self.assertEqual( scanstring('{"Missing colon" null}', 2, None, True), (u'Missing colon', 16)) self.assertEqual( scanstring('{"Double colon":: null}', 2, None, True), (u'Double colon', 15)) self.assertEqual( scanstring('{"Comma instead of colon", null}', 2, None, True), (u'Comma instead of colon', 25)) self.assertEqual( scanstring('["Colon instead of comma": false]', 2, None, True), (u'Colon instead of comma', 25)) self.assertEqual( scanstring('["Bad value", truth]', 2, None, True), (u'Bad value', 12)) for c in map(chr, range(0x00, 0x1f)): self.assertEqual( scanstring(c + '"', 0, None, False), (c, 2)) self.assertRaises( ValueError, scanstring, c + '"', 0, None, True) self.assertRaises(ValueError, scanstring, '', 0, None, True) self.assertRaises(ValueError, scanstring, 'a', 0, None, True) self.assertRaises(ValueError, scanstring, '\\', 0, None, True) self.assertRaises(ValueError, scanstring, '\\u', 0, None, True) self.assertRaises(ValueError, scanstring, '\\u0', 0, None, True) self.assertRaises(ValueError, scanstring, '\\u01', 0, None, True) self.assertRaises(ValueError, scanstring, '\\u012', 0, None, True) self.assertRaises(ValueError, scanstring, '\\u0123', 0, None, True) if sys.maxunicode > 65535: self.assertRaises(ValueError, scanstring, '\\ud834\\u"', 0, None, True) self.assertRaises(ValueError, scanstring, '\\ud834\\x0123"', 0, None, True) def test_issue3623(self): self.assertRaises(ValueError, json.decoder.scanstring, "xxx", 1, "xxx") self.assertRaises(UnicodeDecodeError, json.encoder.encode_basestring_ascii, b("xx\xff")) def test_overflow(self): # Python 2.5 does not have maxsize, Python 3 does not have maxint maxsize = getattr(sys, 'maxsize', getattr(sys, 'maxint', None)) assert maxsize is not None self.assertRaises(OverflowError, json.decoder.scanstring, "xxx", maxsize + 1) def test_surrogates(self): scanstring = json.decoder.scanstring def assertScan(given, expect, test_utf8=True): givens = [given] if not PY3 and test_utf8: givens.append(given.encode('utf8')) for given in givens: (res, count) = scanstring(given, 1, None, True) self.assertEqual(len(given), count) self.assertEqual(res, expect) assertScan( u'"z\\ud834\\u0079x"', u'z\ud834yx') assertScan( u'"z\\ud834\\udd20x"', u'z\U0001d120x') assertScan( u'"z\\ud834\\ud834\\udd20x"', u'z\ud834\U0001d120x') assertScan( u'"z\\ud834x"', u'z\ud834x') assertScan( u'"z\\udd20x"', u'z\udd20x') assertScan( u'"z\ud834x"', u'z\ud834x') # It may look strange to join strings together, but Python is drunk. # https://gist.github.com/etrepum/5538443 assertScan( u'"z\\ud834\udd20x12345"', u''.join([u'z\ud834', u'\udd20x12345'])) assertScan( u'"z\ud834\\udd20x"', u''.join([u'z\ud834', u'\udd20x'])) # these have different behavior given UTF8 input, because the surrogate # pair may be joined (in maxunicode > 65535 builds) assertScan( u''.join([u'"z\ud834', u'\udd20x"']), u''.join([u'z\ud834', u'\udd20x']), test_utf8=False) self.assertRaises(ValueError, scanstring, u'"z\\ud83x"', 1, None, True) self.assertRaises(ValueError, scanstring, u'"z\\ud834\\udd2x"', 1, None, True) ================================================ FILE: simplejson/tests/test_separators.py ================================================ import textwrap from unittest import TestCase import simplejson as json class TestSeparators(TestCase): def test_separators(self): h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 'i-vhbjkhnth', {'nifty': 87}, {'field': 'yes', 'morefield': False} ] expect = textwrap.dedent("""\ [ [ "blorpie" ] , [ "whoops" ] , [] , "d-shtaeou" , "d-nthiouh" , "i-vhbjkhnth" , { "nifty" : 87 } , { "field" : "yes" , "morefield" : false } ]""") d1 = json.dumps(h) d2 = json.dumps(h, indent=' ', sort_keys=True, separators=(' ,', ' : ')) h1 = json.loads(d1) h2 = json.loads(d2) self.assertEqual(h1, h) self.assertEqual(h2, h) self.assertEqual(d2, expect) ================================================ FILE: simplejson/tests/test_speedups.py ================================================ from __future__ import with_statement import sys import unittest from unittest import TestCase import simplejson from simplejson import encoder, decoder, scanner from simplejson.compat import PY3, long_type, b def has_speedups(): return encoder.c_make_encoder is not None def skip_if_speedups_missing(func): def wrapper(*args, **kwargs): if not has_speedups(): if hasattr(unittest, 'SkipTest'): raise unittest.SkipTest("C Extension not available") else: sys.stdout.write("C Extension not available") return return func(*args, **kwargs) return wrapper class BadBool: def __bool__(self): 1/0 __nonzero__ = __bool__ class TestDecode(TestCase): @skip_if_speedups_missing def test_make_scanner(self): self.assertRaises(AttributeError, scanner.c_make_scanner, 1) @skip_if_speedups_missing def test_bad_bool_args(self): def test(value): decoder.JSONDecoder(strict=BadBool()).decode(value) self.assertRaises(ZeroDivisionError, test, '""') self.assertRaises(ZeroDivisionError, test, '{}') if not PY3: self.assertRaises(ZeroDivisionError, test, u'""') self.assertRaises(ZeroDivisionError, test, u'{}') class TestEncode(TestCase): @skip_if_speedups_missing def test_make_encoder(self): self.assertRaises( TypeError, encoder.c_make_encoder, None, ("\xCD\x7D\x3D\x4E\x12\x4C\xF9\x79\xD7" "\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75"), None ) @skip_if_speedups_missing def test_bad_str_encoder(self): # Issue #31505: There shouldn't be an assertion failure in case # c_make_encoder() receives a bad encoder() argument. import decimal def bad_encoder1(*args): return None enc = encoder.c_make_encoder( None, lambda obj: str(obj), bad_encoder1, None, ': ', ', ', False, False, False, {}, False, False, False, None, None, 'utf-8', False, False, decimal.Decimal, False) self.assertRaises(TypeError, enc, 'spam', 4) self.assertRaises(TypeError, enc, {'spam': 42}, 4) def bad_encoder2(*args): 1/0 enc = encoder.c_make_encoder( None, lambda obj: str(obj), bad_encoder2, None, ': ', ', ', False, False, False, {}, False, False, False, None, None, 'utf-8', False, False, decimal.Decimal, False) self.assertRaises(ZeroDivisionError, enc, 'spam', 4) @skip_if_speedups_missing def test_bad_bool_args(self): def test(name): encoder.JSONEncoder(**{name: BadBool()}).encode({}) self.assertRaises(ZeroDivisionError, test, 'skipkeys') self.assertRaises(ZeroDivisionError, test, 'ensure_ascii') self.assertRaises(ZeroDivisionError, test, 'check_circular') self.assertRaises(ZeroDivisionError, test, 'allow_nan') self.assertRaises(ZeroDivisionError, test, 'sort_keys') self.assertRaises(ZeroDivisionError, test, 'use_decimal') self.assertRaises(ZeroDivisionError, test, 'namedtuple_as_object') self.assertRaises(ZeroDivisionError, test, 'tuple_as_array') self.assertRaises(ZeroDivisionError, test, 'bigint_as_string') self.assertRaises(ZeroDivisionError, test, 'for_json') self.assertRaises(ZeroDivisionError, test, 'ignore_nan') self.assertRaises(ZeroDivisionError, test, 'iterable_as_array') @skip_if_speedups_missing def test_int_as_string_bitcount_overflow(self): long_count = long_type(2)**32+31 def test(): encoder.JSONEncoder(int_as_string_bitcount=long_count).encode(0) self.assertRaises((TypeError, OverflowError), test) if PY3: @skip_if_speedups_missing def test_bad_encoding(self): with self.assertRaises(UnicodeEncodeError): encoder.JSONEncoder(encoding='\udcff').encode({b('key'): 123}) ================================================ FILE: simplejson/tests/test_str_subclass.py ================================================ from unittest import TestCase import simplejson from simplejson.compat import text_type # Tests for issue demonstrated in https://github.com/simplejson/simplejson/issues/144 class WonkyTextSubclass(text_type): def __getslice__(self, start, end): return self.__class__('not what you wanted!') class TestStrSubclass(TestCase): def test_dump_load(self): for s in ['', '"hello"', 'text', u'\u005c']: self.assertEqual( s, simplejson.loads(simplejson.dumps(WonkyTextSubclass(s)))) self.assertEqual( s, simplejson.loads(simplejson.dumps(WonkyTextSubclass(s), ensure_ascii=False))) ================================================ FILE: simplejson/tests/test_subclass.py ================================================ from unittest import TestCase import simplejson as json from decimal import Decimal class AlternateInt(int): def __repr__(self): return 'invalid json' __str__ = __repr__ class AlternateFloat(float): def __repr__(self): return 'invalid json' __str__ = __repr__ # class AlternateDecimal(Decimal): # def __repr__(self): # return 'invalid json' class TestSubclass(TestCase): def test_int(self): self.assertEqual(json.dumps(AlternateInt(1)), '1') self.assertEqual(json.dumps(AlternateInt(-1)), '-1') self.assertEqual(json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1}) def test_float(self): self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0') self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0') self.assertEqual(json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1}) # NOTE: Decimal subclasses are not supported as-is # def test_decimal(self): # self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0') # self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0') ================================================ FILE: simplejson/tests/test_tool.py ================================================ from __future__ import with_statement import os import sys import textwrap import unittest import subprocess import tempfile try: # Python 3.x from test.support import strip_python_stderr except ImportError: # Python 2.6+ try: from test.test_support import strip_python_stderr except ImportError: # Python 2.5 import re def strip_python_stderr(stderr): return re.sub( r"\[\d+ refs\]\r?\n?$".encode(), "".encode(), stderr).strip() def open_temp_file(): if sys.version_info >= (2, 6): file = tempfile.NamedTemporaryFile(delete=False) filename = file.name else: fd, filename = tempfile.mkstemp() file = os.fdopen(fd, 'w+b') return file, filename class TestTool(unittest.TestCase): data = """ [["blorpie"],[ "whoops" ] , [ ],\t"d-shtaeou",\r"d-nthiouh", "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field" :"yes"} ] """ expect = textwrap.dedent("""\ [ [ "blorpie" ], [ "whoops" ], [], "d-shtaeou", "d-nthiouh", "i-vhbjkhnth", { "nifty": 87 }, { "field": "yes", "morefield": false } ] """) def runTool(self, args=None, data=None): argv = [sys.executable, '-m', 'simplejson.tool'] if args: argv.extend(args) proc = subprocess.Popen(argv, stdin=subprocess.PIPE, stderr=subprocess.PIPE, stdout=subprocess.PIPE) out, err = proc.communicate(data) self.assertEqual(strip_python_stderr(err), ''.encode()) self.assertEqual(proc.returncode, 0) return out.decode('utf8').splitlines() def test_stdin_stdout(self): self.assertEqual( self.runTool(data=self.data.encode()), self.expect.splitlines()) def test_infile_stdout(self): infile, infile_name = open_temp_file() try: infile.write(self.data.encode()) infile.close() self.assertEqual( self.runTool(args=[infile_name]), self.expect.splitlines()) finally: os.unlink(infile_name) def test_infile_outfile(self): infile, infile_name = open_temp_file() try: infile.write(self.data.encode()) infile.close() # outfile will get overwritten by tool, so the delete # may not work on some platforms. Do it manually. outfile, outfile_name = open_temp_file() try: outfile.close() self.assertEqual( self.runTool(args=[infile_name, outfile_name]), []) with open(outfile_name, 'rb') as f: self.assertEqual( f.read().decode('utf8').splitlines(), self.expect.splitlines() ) finally: os.unlink(outfile_name) finally: os.unlink(infile_name) ================================================ FILE: simplejson/tests/test_tuple.py ================================================ import unittest from simplejson.compat import StringIO import simplejson as json class TestTuples(unittest.TestCase): def test_tuple_array_dumps(self): t = (1, 2, 3) expect = json.dumps(list(t)) # Default is True self.assertEqual(expect, json.dumps(t)) self.assertEqual(expect, json.dumps(t, tuple_as_array=True)) self.assertRaises(TypeError, json.dumps, t, tuple_as_array=False) # Ensure that the "default" does not get called self.assertEqual(expect, json.dumps(t, default=repr)) self.assertEqual(expect, json.dumps(t, tuple_as_array=True, default=repr)) # Ensure that the "default" gets called self.assertEqual( json.dumps(repr(t)), json.dumps(t, tuple_as_array=False, default=repr)) def test_tuple_array_dump(self): t = (1, 2, 3) expect = json.dumps(list(t)) # Default is True sio = StringIO() json.dump(t, sio) self.assertEqual(expect, sio.getvalue()) sio = StringIO() json.dump(t, sio, tuple_as_array=True) self.assertEqual(expect, sio.getvalue()) self.assertRaises(TypeError, json.dump, t, StringIO(), tuple_as_array=False) # Ensure that the "default" does not get called sio = StringIO() json.dump(t, sio, default=repr) self.assertEqual(expect, sio.getvalue()) sio = StringIO() json.dump(t, sio, tuple_as_array=True, default=repr) self.assertEqual(expect, sio.getvalue()) # Ensure that the "default" gets called sio = StringIO() json.dump(t, sio, tuple_as_array=False, default=repr) self.assertEqual( json.dumps(repr(t)), sio.getvalue()) ================================================ FILE: simplejson/tests/test_unicode.py ================================================ import sys import codecs from unittest import TestCase import simplejson as json from simplejson.compat import unichr, text_type, b, BytesIO class TestUnicode(TestCase): def test_encoding1(self): encoder = json.JSONEncoder(encoding='utf-8') u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' s = u.encode('utf-8') ju = encoder.encode(u) js = encoder.encode(s) self.assertEqual(ju, js) def test_encoding2(self): u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' s = u.encode('utf-8') ju = json.dumps(u, encoding='utf-8') js = json.dumps(s, encoding='utf-8') self.assertEqual(ju, js) def test_encoding3(self): u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' j = json.dumps(u) self.assertEqual(j, '"\\u03b1\\u03a9"') def test_encoding4(self): u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' j = json.dumps([u]) self.assertEqual(j, '["\\u03b1\\u03a9"]') def test_encoding5(self): u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' j = json.dumps(u, ensure_ascii=False) self.assertEqual(j, u'"' + u + u'"') def test_encoding6(self): u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' j = json.dumps([u], ensure_ascii=False) self.assertEqual(j, u'["' + u + u'"]') def test_big_unicode_encode(self): u = u'\U0001d120' self.assertEqual(json.dumps(u), '"\\ud834\\udd20"') self.assertEqual(json.dumps(u, ensure_ascii=False), u'"\U0001d120"') def test_big_unicode_decode(self): u = u'z\U0001d120x' self.assertEqual(json.loads('"' + u + '"'), u) self.assertEqual(json.loads('"z\\ud834\\udd20x"'), u) def test_unicode_decode(self): for i in range(0, 0xd7ff): u = unichr(i) #s = '"\\u{0:04x}"'.format(i) s = '"\\u%04x"' % (i,) self.assertEqual(json.loads(s), u) def test_object_pairs_hook_with_unicode(self): s = u'{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}' p = [(u"xkd", 1), (u"kcw", 2), (u"art", 3), (u"hxm", 4), (u"qrt", 5), (u"pad", 6), (u"hoy", 7)] self.assertEqual(json.loads(s), eval(s)) self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p) od = json.loads(s, object_pairs_hook=json.OrderedDict) self.assertEqual(od, json.OrderedDict(p)) self.assertEqual(type(od), json.OrderedDict) # the object_pairs_hook takes priority over the object_hook self.assertEqual(json.loads(s, object_pairs_hook=json.OrderedDict, object_hook=lambda x: None), json.OrderedDict(p)) def test_default_encoding(self): self.assertEqual(json.loads(u'{"a": "\xe9"}'.encode('utf-8')), {'a': u'\xe9'}) def test_unicode_preservation(self): self.assertEqual(type(json.loads(u'""')), text_type) self.assertEqual(type(json.loads(u'"a"')), text_type) self.assertEqual(type(json.loads(u'["a"]')[0]), text_type) def test_ensure_ascii_false_returns_unicode(self): # http://code.google.com/p/simplejson/issues/detail?id=48 self.assertEqual(type(json.dumps([], ensure_ascii=False)), text_type) self.assertEqual(type(json.dumps(0, ensure_ascii=False)), text_type) self.assertEqual(type(json.dumps({}, ensure_ascii=False)), text_type) self.assertEqual(type(json.dumps("", ensure_ascii=False)), text_type) def test_ensure_ascii_false_bytestring_encoding(self): # http://code.google.com/p/simplejson/issues/detail?id=48 doc1 = {u'quux': b('Arr\xc3\xaat sur images')} doc2 = {u'quux': u'Arr\xeat sur images'} doc_ascii = '{"quux": "Arr\\u00eat sur images"}' doc_unicode = u'{"quux": "Arr\xeat sur images"}' self.assertEqual(json.dumps(doc1), doc_ascii) self.assertEqual(json.dumps(doc2), doc_ascii) self.assertEqual(json.dumps(doc1, ensure_ascii=False), doc_unicode) self.assertEqual(json.dumps(doc2, ensure_ascii=False), doc_unicode) def test_ensure_ascii_linebreak_encoding(self): # http://timelessrepo.com/json-isnt-a-javascript-subset s1 = u'\u2029\u2028' s2 = s1.encode('utf8') expect = '"\\u2029\\u2028"' expect_non_ascii = u'"\u2029\u2028"' self.assertEqual(json.dumps(s1), expect) self.assertEqual(json.dumps(s2), expect) self.assertEqual(json.dumps(s1, ensure_ascii=False), expect_non_ascii) self.assertEqual(json.dumps(s2, ensure_ascii=False), expect_non_ascii) def test_invalid_escape_sequences(self): # incomplete escape sequence self.assertRaises(json.JSONDecodeError, json.loads, '"\\u') self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1') self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12') self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123') self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1234') # invalid escape sequence self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123x"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12x4"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1x34"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\ux234"') if sys.maxunicode > 65535: # invalid escape sequence for low surrogate self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000x"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00x0"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0x00"') self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\ux000"') def test_ensure_ascii_still_works(self): # in the ascii range, ensure that everything is the same for c in map(unichr, range(0, 127)): self.assertEqual( json.dumps(c, ensure_ascii=False), json.dumps(c)) snowman = u'\N{SNOWMAN}' self.assertEqual( json.dumps(c, ensure_ascii=False), '"' + c + '"') def test_strip_bom(self): content = u"\u3053\u3093\u306b\u3061\u308f" json_doc = codecs.BOM_UTF8 + b(json.dumps(content)) self.assertEqual(json.load(BytesIO(json_doc)), content) for doc in json_doc, json_doc.decode('utf8'): self.assertEqual(json.loads(doc), content) ================================================ FILE: simplejson/tool.py ================================================ r"""Command-line tool to validate and pretty-print JSON Usage:: $ echo '{"json":"obj"}' | python -m simplejson.tool { "json": "obj" } $ echo '{ 1.2:3.4}' | python -m simplejson.tool Expecting property name: line 1 column 2 (char 2) """ from __future__ import with_statement import sys import simplejson as json def main(): if len(sys.argv) == 1: infile = sys.stdin outfile = sys.stdout elif len(sys.argv) == 2: infile = open(sys.argv[1], 'r') outfile = sys.stdout elif len(sys.argv) == 3: infile = open(sys.argv[1], 'r') outfile = open(sys.argv[2], 'w') else: raise SystemExit(sys.argv[0] + " [infile [outfile]]") with infile: try: obj = json.load(infile, object_pairs_hook=json.OrderedDict, use_decimal=True) except ValueError: raise SystemExit(sys.exc_info()[1]) with outfile: json.dump(obj, outfile, sort_keys=True, indent=' ', use_decimal=True) outfile.write('\n') if __name__ == '__main__': main() ================================================ FILE: train.py ================================================ import time import torchvision.transforms as transforms import torchvision.datasets as datasets import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') from utils import * from network import * from configs import * import math import argparse import models.resnet as resnet import models.densenet as densenet from models import create_model parser = argparse.ArgumentParser(description='Training code for GFNet') parser.add_argument('--data_url', default='./data', type=str, help='path to the dataset (ImageNet)') parser.add_argument('--work_dirs', default='./output', type=str, help='path to save log and checkpoints') parser.add_argument('--train_stage', default=-1, type=int, help='select training stage, see our paper for details \ stage-1 : warm-up \ stage-2 : learn to select patches with RL \ stage-3 : finetune CNNs') parser.add_argument('--model_arch', default='', type=str, help='architecture of the model to be trained \ resnet50 / resnet101 / \ densenet121 / densenet169 / densenet201 / \ regnety_600m / regnety_800m / regnety_1.6g / \ mobilenetv3_large_100 / mobilenetv3_large_125 / \ efficientnet_b2 / efficientnet_b3') parser.add_argument('--patch_size', default=96, type=int, help='size of local patches (we recommend 96 / 128 / 144)') parser.add_argument('--T', default=4, type=int, help='maximum length of the sequence of Glance + Focus') parser.add_argument('--print_freq', default=100, type=int, help='the frequency of printing log') parser.add_argument('--model_prime_path', default='', type=str, help='path to the pre-trained model of Global Encoder (for training stage-1)') parser.add_argument('--model_path', default='', type=str, help='path to the pre-trained model of Local Encoder (for training stage-1)') parser.add_argument('--checkpoint_path', default='', type=str, help='path to the stage-2/3 checkpoint (for training stage-2/3)') parser.add_argument('--resume', default='', type=str, help='path to the checkpoint for resuming') args = parser.parse_args() def main(): if not os.path.isdir(args.work_dirs): mkdir_p(args.work_dirs) record_path = args.work_dirs + '/GF-' + str(args.model_arch) \ + '_patch-size-' + str(args.patch_size) \ + '_T' + str(args.T) \ + '_train-stage' + str(args.train_stage) if not os.path.isdir(record_path): mkdir_p(record_path) record_file = record_path + '/record.txt' # *create model* # model_configuration = model_configurations[args.model_arch] if 'resnet' in args.model_arch: model_arch = 'resnet' model = resnet.resnet50(pretrained=False) model_prime = resnet.resnet50(pretrained=False) elif 'densenet' in args.model_arch: model_arch = 'densenet' model = eval('densenet.' + args.model_arch)(pretrained=False) model_prime = eval('densenet.' + args.model_arch)(pretrained=False) elif 'efficientnet' in args.model_arch: model_arch = 'efficientnet' model = create_model(args.model_arch, pretrained=False, num_classes=1000, drop_rate=0.3, drop_connect_rate=0.2) model_prime = create_model(args.model_arch, pretrained=False, num_classes=1000, drop_rate=0.3, drop_connect_rate=0.2) elif 'mobilenetv3' in args.model_arch: model_arch = 'mobilenetv3' model = create_model(args.model_arch, pretrained=False, num_classes=1000, drop_rate=0.2, drop_connect_rate=0.2) model_prime = create_model(args.model_arch, pretrained=False, num_classes=1000, drop_rate=0.2, drop_connect_rate=0.2) elif 'regnet' in args.model_arch: model_arch = 'regnet' import pycls.core.model_builder as model_builder from pycls.core.config import cfg cfg.merge_from_file(model_configuration['cfg_file']) cfg.freeze() model = model_builder.build_model() model_prime = model_builder.build_model() fc = Full_layer(model_configuration['feature_num'], model_configuration['fc_hidden_dim'], model_configuration['fc_rnn']) if args.train_stage == 1: model.load_state_dict(torch.load(args.model_path)) model_prime.load_state_dict(torch.load(args.model_prime_path)) else: checkpoint = torch.load(args.checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) model_prime.load_state_dict(checkpoint['model_prime_state_dict']) fc.load_state_dict(checkpoint['fc']) train_configuration = train_configurations[model_arch] if args.train_stage != 2: if train_configuration['train_model_prime']: optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': model_prime.parameters()}, {'params': fc.parameters()}], lr=0, # specify in adjust_learning_rate() momentum=train_configuration['momentum'], nesterov=train_configuration['Nesterov'], weight_decay=train_configuration['weight_decay']) else: optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': fc.parameters()}], lr=0, # specify in adjust_learning_rate() momentum=train_configuration['momentum'], nesterov=train_configuration['Nesterov'], weight_decay=train_configuration['weight_decay']) training_epoch_num = train_configuration['epoch_num'] else: optimizer = None training_epoch_num = 15 criterion = nn.CrossEntropyLoss().cuda() model = nn.DataParallel(model.cuda()) model_prime = nn.DataParallel(model_prime.cuda()) fc = fc.cuda() traindir = args.data_url + 'train/' valdir = args.data_url + 'val/' normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_set = datasets.ImageFolder(traindir, transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ])) train_set_index = torch.randperm(len(train_set)) train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, num_workers=32, pin_memory=False, sampler=torch.utils.data.sampler.SubsetRandomSampler( train_set_index[:])) val_loader = torch.utils.data.DataLoader( datasets.ImageFolder(valdir, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ])), batch_size=train_configuration['batch_size'], shuffle=False, num_workers=32, pin_memory=False) if args.train_stage != 1: state_dim = model_configuration['feature_map_channels'] * math.ceil(args.patch_size / 32) * math.ceil(args.patch_size / 32) ppo = PPO(model_configuration['feature_map_channels'], state_dim, model_configuration['policy_hidden_dim'], model_configuration['policy_conv']) if args.train_stage == 3: ppo.policy.load_state_dict(checkpoint['policy']) ppo.policy_old.load_state_dict(checkpoint['policy']) else: ppo = None memory = Memory() if args.resume: resume_ckp = torch.load(args.resume) start_epoch = resume_ckp['epoch'] print('resume from epoch: {}'.format(start_epoch)) model.module.load_state_dict(resume_ckp['model_state_dict']) model_prime.module.load_state_dict(resume_ckp['model_prime_state_dict']) fc.load_state_dict(resume_ckp['fc']) if optimizer: optimizer.load_state_dict(resume_ckp['optimizer']) if ppo: ppo.policy.load_state_dict(resume_ckp['policy']) ppo.policy_old.load_state_dict(resume_ckp['policy']) ppo.optimizer.load_state_dict(resume_ckp['ppo_optimizer']) best_acc = resume_ckp['best_acc'] else: start_epoch = 0 best_acc = 0 for epoch in range(start_epoch, training_epoch_num): if args.train_stage != 2: print('Training Stage: {}, lr:'.format(args.train_stage)) adjust_learning_rate(optimizer, train_configuration, epoch, training_epoch_num, args) else: print('Training Stage: {}, train ppo only'.format(args.train_stage)) train(model_prime, model, fc, memory, ppo, optimizer, train_loader, criterion, args.print_freq, epoch, train_configuration['batch_size'], record_file, train_configuration, args) acc = validate(model_prime, model, fc, memory, ppo, optimizer, val_loader, criterion, args.print_freq, epoch, train_configuration['batch_size'], record_file, train_configuration, args) if acc > best_acc: best_acc = acc is_best = True else: is_best = False save_checkpoint({ 'epoch': epoch + 1, 'model_state_dict': model.module.state_dict(), 'model_prime_state_dict': model_prime.module.state_dict(), 'fc': fc.state_dict(), 'acc': acc, 'best_acc': best_acc, 'optimizer': optimizer.state_dict() if optimizer else None, 'ppo_optimizer': ppo.optimizer.state_dict() if ppo else None, 'policy': ppo.policy.state_dict() if ppo else None, }, is_best, checkpoint=record_path) def train(model_prime, model, fc, memory, ppo, optimizer, train_loader, criterion, print_freq, epoch, batch_size, record_file, train_configuration, args): batch_time = AverageMeter() losses = [AverageMeter() for _ in range(args.T)] top1 = [AverageMeter() for _ in range(args.T)] reward_list = [AverageMeter() for _ in range(args.T - 1)] train_batches_num = len(train_loader) if args.train_stage == 2: model_prime.eval() model.eval() fc.eval() else: if train_configuration['train_model_prime']: model_prime.train() else: model_prime.eval() model.train() fc.train() if 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch: dsn_fc_prime = model_prime.module.fc dsn_fc = model.module.fc else: dsn_fc_prime = model_prime.module.classifier dsn_fc = model.module.classifier fd = open(record_file, 'a+') end = time.time() for i, (x, target) in enumerate(train_loader): loss_cla = [] loss_list_dsn = [] target_var = target.cuda() input_var = x.cuda() input_prime = get_prime(input_var, args.patch_size) if train_configuration['train_model_prime'] and args.train_stage != 2: output, state = model_prime(input_prime) assert 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch output_dsn = dsn_fc_prime(output) output = fc(output, restart=True) else: with torch.no_grad(): output, state = model_prime(input_prime) if 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch: output_dsn = dsn_fc_prime(output) output = fc(output, restart=True) else: _ = fc(output, restart=True) output = model_prime.module.classifier(output) output_dsn = output loss_prime = criterion(output, target_var) loss_cla.append(loss_prime) loss_dsn = criterion(output_dsn, target_var) loss_list_dsn.append(loss_dsn) losses[0].update(loss_prime.data.item(), x.size(0)) acc = accuracy(output, target_var, topk=(1,)) top1[0].update(acc.sum(0).mul_(100.0 / batch_size).data.item(), x.size(0)) confidence_last = torch.gather(F.softmax(output.detach(), 1), dim=1, index=target_var.view(-1, 1)).view(1, -1) for patch_step in range(1, args.T): if args.train_stage == 1: action = torch.rand(x.size(0), 2).cuda() else: if patch_step == 1: action = ppo.select_action(state.to(0), memory, restart_batch=True) else: action = ppo.select_action(state.to(0), memory) patches = get_patch(input_var, action, args.patch_size) if args.train_stage != 2: output, state = model(patches) output_dsn = dsn_fc(output) output = fc(output, restart=False) else: with torch.no_grad(): output, state = model(patches) output_dsn = dsn_fc(output) output = fc(output, restart=False) loss = criterion(output, target_var) loss_cla.append(loss) losses[patch_step].update(loss.data.item(), x.size(0)) loss_dsn = criterion(output_dsn, target_var) loss_list_dsn.append(loss_dsn) acc = accuracy(output, target_var, topk=(1,)) top1[patch_step].update(acc.sum(0).mul_(100.0 / batch_size).data.item(), x.size(0)) confidence = torch.gather(F.softmax(output.detach(), 1), dim=1, index=target_var.view(-1, 1)).view(1, -1) reward = confidence - confidence_last confidence_last = confidence reward_list[patch_step - 1].update(reward.data.mean(), x.size(0)) memory.rewards.append(reward) loss = (sum(loss_cla) + train_configuration['dsn_ratio'] * sum(loss_list_dsn)) / args.T if args.train_stage != 2: optimizer.zero_grad() loss.backward() optimizer.step() else: ppo.update(memory) memory.clear_memory() batch_time.update(time.time() - end) end = time.time() if (i + 1) % print_freq == 0 or i == train_batches_num - 1: string = ('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.value:.3f} ({batch_time.ave:.3f})\t' 'Loss {loss.value:.4f} ({loss.ave:.4f})\t'.format( epoch, i + 1, train_batches_num, batch_time=batch_time, loss=losses[-1])) print(string) fd.write(string + '\n') _acc = [acc.ave for acc in top1] print('accuracy of each step:') print(_acc) fd.write('accuracy of each step:\n') fd.write(str(_acc) + '\n') _reward = [reward.ave for reward in reward_list] print('reward of each step:') print(_reward) fd.write('reward of each step:\n') fd.write(str(_reward) + '\n') fd.close() def validate(model_prime, model, fc, memory, ppo, _, val_loader, criterion, print_freq, epoch, batch_size, record_file, __, args): batch_time = AverageMeter() losses = [AverageMeter() for _ in range(args.T)] top1 = [AverageMeter() for _ in range(args.T)] reward_list = [AverageMeter() for _ in range(args.T - 1)] train_batches_num = len(val_loader) model_prime.eval() model.eval() fc.eval() if 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch: dsn_fc_prime = model_prime.module.fc dsn_fc = model.module.fc else: dsn_fc_prime = model_prime.module.classifier dsn_fc = model.module.classifier fd = open(record_file, 'a+') end = time.time() with torch.no_grad(): for i, (x, target) in enumerate(val_loader): loss_cla = [] loss_list_dsn = [] target_var = target.cuda() input_var = x.cuda() input_prime = get_prime(input_var, args.patch_size) output, state = model_prime(input_prime) if 'resnet' in args.model_arch or 'densenet' in args.model_arch or 'regnet' in args.model_arch: output_dsn = dsn_fc_prime(output) output = fc(output, restart=True) else: _ = fc(output, restart=True) output = model_prime.module.classifier(output) output_dsn = output loss_prime = criterion(output, target_var) loss_cla.append(loss_prime) loss_dsn = criterion(output_dsn, target_var) loss_list_dsn.append(loss_dsn) losses[0].update(loss_prime.data.item(), x.size(0)) acc = accuracy(output, target_var, topk=(1,)) top1[0].update(acc.sum(0).mul_(100.0 / batch_size).data.item(), x.size(0)) confidence_last = torch.gather(F.softmax(output.detach(), 1), dim=1, index=target_var.view(-1, 1)).view(1, -1) for patch_step in range(1, args.T): if args.train_stage == 1: action = torch.rand(x.size(0), 2).cuda() else: if patch_step == 1: action = ppo.select_action(state.to(0), memory, restart_batch=True, training=False) else: action = ppo.select_action(state.to(0), memory, training=False) patches = get_patch(input_var, action, args.patch_size) output, state = model(patches) output_dsn = dsn_fc(output) output = fc(output, restart=False) loss = criterion(output, target_var) loss_cla.append(loss) losses[patch_step].update(loss.data.item(), x.size(0)) loss_dsn = criterion(output_dsn, target_var) loss_list_dsn.append(loss_dsn) acc = accuracy(output, target_var, topk=(1,)) top1[patch_step].update(acc.sum(0).mul_(100.0 / batch_size).data.item(), x.size(0)) confidence = torch.gather(F.softmax(output.detach(), 1), dim=1, index=target_var.view(-1, 1)).view(1, -1) reward = confidence - confidence_last confidence_last = confidence reward_list[patch_step - 1].update(reward.data.mean(), x.size(0)) memory.rewards.append(reward) memory.clear_memory() batch_time.update(time.time() - end) end = time.time() if (i + 1) % print_freq == 0 or i == train_batches_num - 1: string = ('Val: [{0}][{1}/{2}]\t' 'Time {batch_time.value:.3f} ({batch_time.ave:.3f})\t' 'Loss {loss.value:.4f} ({loss.ave:.4f})\t'.format( epoch, i + 1, train_batches_num, batch_time=batch_time, loss=losses[-1])) print(string) fd.write(string + '\n') _acc = [acc.ave for acc in top1] print('accuracy of each step:') print(_acc) fd.write('accuracy of each step:\n') fd.write(str(_acc) + '\n') _reward = [reward.ave for reward in reward_list] print('reward of each step:') print(_reward) fd.write('reward of each step:\n') fd.write(str(_reward) + '\n') fd.close() return top1[args.T - 1].ave if __name__ == '__main__': main() ================================================ FILE: utils.py ================================================ import os import errno import math import shutil import torch import torch.nn as nn import torch.nn.functional as F def mkdir_p(path): '''make dir if not exist''' try: os.mkdir(path) except OSError as exc: # Python >2.5 if exc.errno == errno.EEXIST and os.path.isdir(path): pass else: raise class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.value = 0 self.ave = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.value = val self.sum += val * n self.count += n self.ave = self.sum / self.count def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" maxk = max(topk) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) correct_k = correct[:1].view(-1).float() return correct_k def get_prime(images, patch_size, interpolation='bicubic'): """Get down-sampled original image""" prime = F.interpolate(images, size=[patch_size, patch_size], mode=interpolation, align_corners=True) return prime def get_patch(images, action_sequence, patch_size): """Get small patch of the original image""" batch_size = images.size(0) image_size = images.size(2) patch_coordinate = torch.floor(action_sequence * (image_size - patch_size)).int() patches = [] for i in range(batch_size): per_patch = images[i, :, (patch_coordinate[i, 0].item()): ((patch_coordinate[i, 0] + patch_size).item()), (patch_coordinate[i, 1].item()): ((patch_coordinate[i, 1] + patch_size).item())] patches.append(per_patch.view(1, per_patch.size(0), per_patch.size(1), per_patch.size(2))) return torch.cat(patches, 0) def adjust_learning_rate(optimizer, train_configuration, epoch, training_epoch_num, args): """Sets the learning rate""" backbone_lr = 0.5 * train_configuration['backbone_lr'] * \ (1 + math.cos(math.pi * epoch / training_epoch_num)) if args.train_stage == 1: fc_lr = 0.5 * train_configuration['fc_stage_1_lr'] * \ (1 + math.cos(math.pi * epoch / training_epoch_num)) elif args.train_stage == 3: fc_lr = 0.5 * train_configuration['fc_stage_3_lr'] * \ (1 + math.cos(math.pi * epoch / training_epoch_num)) if train_configuration['train_model_prime']: optimizer.param_groups[0]['lr'] = backbone_lr optimizer.param_groups[1]['lr'] = backbone_lr optimizer.param_groups[2]['lr'] = fc_lr else: optimizer.param_groups[0]['lr'] = backbone_lr optimizer.param_groups[1]['lr'] = fc_lr for param_group in optimizer.param_groups: print(param_group['lr']) def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): filepath = checkpoint + '/' + filename torch.save(state, filepath) if is_best: shutil.copyfile(filepath, checkpoint + '/model_best.pth.tar') ================================================ FILE: yacs/__init__.py ================================================ ================================================ FILE: yacs/config.py ================================================ # Copyright (c) 2018-present, Facebook, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ############################################################################## """YACS -- Yet Another Configuration System is designed to be a simple configuration management system for academic and industrial research projects. See README.md for usage and examples. """ import copy import io import logging import os import sys from ast import literal_eval import yaml # Flag for py2 and py3 compatibility to use when separate code paths are necessary # When _PY2 is False, we assume Python 3 is in use _PY2 = sys.version_info.major == 2 # Filename extensions for loading configs from files _YAML_EXTS = {"", ".yaml", ".yml"} _PY_EXTS = {".py"} # py2 and py3 compatibility for checking file object type # We simply use this to infer py2 vs py3 if _PY2: _FILE_TYPES = (file, io.IOBase) else: _FILE_TYPES = (io.IOBase,) # CfgNodes can only contain a limited set of valid types _VALID_TYPES = {tuple, list, str, int, float, bool} # py2 allow for str and unicode if _PY2: _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 # Utilities for importing modules from file paths if _PY2: # imp is available in both py2 and py3 for now, but is deprecated in py3 import imp else: import importlib.util logger = logging.getLogger(__name__) class CfgNode(dict): """ CfgNode represents an internal node in the configuration tree. It's a simple dict-like container that allows for attribute-based access to keys. """ IMMUTABLE = "__immutable__" DEPRECATED_KEYS = "__deprecated_keys__" RENAMED_KEYS = "__renamed_keys__" NEW_ALLOWED = "__new_allowed__" def __init__(self, init_dict=None, key_list=None, new_allowed=False): """ Args: init_dict (dict): the possibly-nested dictionary to initailize the CfgNode. key_list (list[str]): a list of names which index this CfgNode from the root. Currently only used for logging purposes. new_allowed (bool): whether adding new key is allowed when merging with other configs. """ # Recursively convert nested dictionaries in init_dict into CfgNodes init_dict = {} if init_dict is None else init_dict key_list = [] if key_list is None else key_list init_dict = self._create_config_tree_from_dict(init_dict, key_list) super(CfgNode, self).__init__(init_dict) # Manage if the CfgNode is frozen or not self.__dict__[CfgNode.IMMUTABLE] = False # Deprecated options # If an option is removed from the code and you don't want to break existing # yaml configs, you can add the full config key as a string to the set below. self.__dict__[CfgNode.DEPRECATED_KEYS] = set() # Renamed options # If you rename a config option, record the mapping from the old name to the new # name in the dictionary below. Optionally, if the type also changed, you can # make the value a tuple that specifies first the renamed key and then # instructions for how to edit the config file. self.__dict__[CfgNode.RENAMED_KEYS] = { # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow # 'EXAMPLE.NEW.KEY', # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or " # + "'foo:bar' -> ('foo', 'bar')" # ), } # Allow new attributes after initialisation self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed @classmethod def _create_config_tree_from_dict(cls, dic, key_list): """ Create a configuration tree using the given dict. Any dict-like objects inside dict will be treated as a new CfgNode. Args: dic (dict): key_list (list[str]): a list of names which index this CfgNode from the root. Currently only used for logging purposes. """ dic = copy.deepcopy(dic) for k, v in dic.items(): if isinstance(v, dict): # Convert dict to CfgNode dic[k] = cls(v, key_list=key_list + [k]) else: # Check for valid leaf type or nested CfgNode _assert_with_logging( _valid_type(v, allow_cfg_node=False), "Key {} with value {} is not a valid type; valid types: {}".format( ".".join(key_list + [k]), type(v), _VALID_TYPES ), ) return dic def __getattr__(self, name): if name in self: return self[name] else: raise AttributeError(name) def __setattr__(self, name, value): if self.is_frozen(): raise AttributeError( "Attempted to set {} to {}, but CfgNode is immutable".format( name, value ) ) _assert_with_logging( name not in self.__dict__, "Invalid attempt to modify internal CfgNode state: {}".format(name), ) _assert_with_logging( _valid_type(value, allow_cfg_node=True), "Invalid type {} for key {}; valid types = {}".format( type(value), name, _VALID_TYPES ), ) self[name] = value def __str__(self): def _indent(s_, num_spaces): s = s_.split("\n") if len(s) == 1: return s_ first = s.pop(0) s = [(num_spaces * " ") + line for line in s] s = "\n".join(s) s = first + "\n" + s return s r = "" s = [] for k, v in sorted(self.items()): seperator = "\n" if isinstance(v, CfgNode) else " " attr_str = "{}:{}{}".format(str(k), seperator, str(v)) attr_str = _indent(attr_str, 2) s.append(attr_str) r += "\n".join(s) return r def __repr__(self): return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) def dump(self, **kwargs): """Dump to a string.""" def convert_to_dict(cfg_node, key_list): if not isinstance(cfg_node, CfgNode): _assert_with_logging( _valid_type(cfg_node), "Key {} with value {} is not a valid type; valid types: {}".format( ".".join(key_list), type(cfg_node), _VALID_TYPES ), ) return cfg_node else: cfg_dict = dict(cfg_node) for k, v in cfg_dict.items(): cfg_dict[k] = convert_to_dict(v, key_list + [k]) return cfg_dict self_as_dict = convert_to_dict(self, []) return yaml.safe_dump(self_as_dict, **kwargs) def merge_from_file(self, cfg_filename): """Load a yaml config file and merge it this CfgNode.""" with open(cfg_filename, "r") as f: cfg = self.load_cfg(f) self.merge_from_other_cfg(cfg) def merge_from_other_cfg(self, cfg_other): """Merge `cfg_other` into this CfgNode.""" _merge_a_into_b(cfg_other, self, self, []) def merge_from_list(self, cfg_list): """Merge config (keys, values) in a list (e.g., from command line) into this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`. """ _assert_with_logging( len(cfg_list) % 2 == 0, "Override list has odd length: {}; it must be a list of pairs".format( cfg_list ), ) root = self for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): if root.key_is_deprecated(full_key): continue if root.key_is_renamed(full_key): root.raise_key_rename_error(full_key) key_list = full_key.split(".") d = self for subkey in key_list[:-1]: _assert_with_logging( subkey in d, "Non-existent key: {}".format(full_key) ) d = d[subkey] subkey = key_list[-1] _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key)) value = self._decode_cfg_value(v) value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key) d[subkey] = value def freeze(self): """Make this CfgNode and all of its children immutable.""" self._immutable(True) def defrost(self): """Make this CfgNode and all of its children mutable.""" self._immutable(False) def is_frozen(self): """Return mutability.""" return self.__dict__[CfgNode.IMMUTABLE] def _immutable(self, is_immutable): """Set immutability to is_immutable and recursively apply the setting to all nested CfgNodes. """ self.__dict__[CfgNode.IMMUTABLE] = is_immutable # Recursively set immutable state for v in self.__dict__.values(): if isinstance(v, CfgNode): v._immutable(is_immutable) for v in self.values(): if isinstance(v, CfgNode): v._immutable(is_immutable) def clone(self): """Recursively copy this CfgNode.""" return copy.deepcopy(self) def register_deprecated_key(self, key): """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated keys a warning is generated and the key is ignored. """ _assert_with_logging( key not in self.__dict__[CfgNode.DEPRECATED_KEYS], "key {} is already registered as a deprecated key".format(key), ) self.__dict__[CfgNode.DEPRECATED_KEYS].add(key) def register_renamed_key(self, old_name, new_name, message=None): """Register a key as having been renamed from `old_name` to `new_name`. When merging a renamed key, an exception is thrown alerting to user to the fact that the key has been renamed. """ _assert_with_logging( old_name not in self.__dict__[CfgNode.RENAMED_KEYS], "key {} is already registered as a renamed cfg key".format(old_name), ) value = new_name if message: value = (new_name, message) self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value def key_is_deprecated(self, full_key): """Test if a key is deprecated.""" if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]: logger.warning("Deprecated config key (ignoring): {}".format(full_key)) return True return False def key_is_renamed(self, full_key): """Test if a key is renamed.""" return full_key in self.__dict__[CfgNode.RENAMED_KEYS] def raise_key_rename_error(self, full_key): new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key] if isinstance(new_key, tuple): msg = " Note: " + new_key[1] new_key = new_key[0] else: msg = "" raise KeyError( "Key {} was renamed to {}; please update your config.{}".format( full_key, new_key, msg ) ) def is_new_allowed(self): return self.__dict__[CfgNode.NEW_ALLOWED] @classmethod def load_cfg(cls, cfg_file_obj_or_str): """ Load a cfg. Args: cfg_file_obj_or_str (str or file): Supports loading from: - A file object backed by a YAML file - A file object backed by a Python source file that exports an attribute "cfg" that is either a dict or a CfgNode - A string that can be parsed as valid YAML """ _assert_with_logging( isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)), "Expected first argument to be of type {} or {}, but it was {}".format( _FILE_TYPES, str, type(cfg_file_obj_or_str) ), ) if isinstance(cfg_file_obj_or_str, str): return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str) elif isinstance(cfg_file_obj_or_str, _FILE_TYPES): return cls._load_cfg_from_file(cfg_file_obj_or_str) else: raise NotImplementedError("Impossible to reach here (unless there's a bug)") @classmethod def _load_cfg_from_file(cls, file_obj): """Load a config from a YAML file or a Python source file.""" _, file_extension = os.path.splitext(file_obj.name) if file_extension in _YAML_EXTS: return cls._load_cfg_from_yaml_str(file_obj.read()) elif file_extension in _PY_EXTS: return cls._load_cfg_py_source(file_obj.name) else: raise Exception( "Attempt to load from an unsupported file type {}; " "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS)) ) @classmethod def _load_cfg_from_yaml_str(cls, str_obj): """Load a config from a YAML string encoding.""" cfg_as_dict = yaml.safe_load(str_obj) return cls(cfg_as_dict) @classmethod def _load_cfg_py_source(cls, filename): """Load a config from a Python source file.""" module = _load_module_from_file("yacs.config.override", filename) _assert_with_logging( hasattr(module, "cfg"), "Python module from file {} must have 'cfg' attr".format(filename), ) VALID_ATTR_TYPES = {dict, CfgNode} _assert_with_logging( type(module.cfg) in VALID_ATTR_TYPES, "Imported module 'cfg' attr must be in {} but is {} instead".format( VALID_ATTR_TYPES, type(module.cfg) ), ) return cls(module.cfg) @classmethod def _decode_cfg_value(cls, value): """ Decodes a raw config value (e.g., from a yaml config files or command line argument) into a Python object. If the value is a dict, it will be interpreted as a new CfgNode. If the value is a str, it will be evaluated as literals. Otherwise it is returned as-is. """ # Configs parsed from raw yaml will contain dictionary keys that need to be # converted to CfgNode objects if isinstance(value, dict): return cls(value) # All remaining processing is only applied to strings if not isinstance(value, str): return value # Try to interpret `value` as a: # string, number, tuple, list, dict, boolean, or None try: value = literal_eval(value) # The following two excepts allow v to pass through when it represents a # string. # # Longer explanation: # The type of v is always a string (before calling literal_eval), but # sometimes it *represents* a string and other times a data structure, like # a list. In the case that v represents a string, what we got back from the # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is # ok with '"foo"', but will raise a ValueError if given 'foo'. In other # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval # will raise a SyntaxError. except ValueError: pass except SyntaxError: pass return value load_cfg = ( CfgNode.load_cfg ) # keep this function in global scope for backward compatibility def _valid_type(value, allow_cfg_node=False): return (type(value) in _VALID_TYPES) or ( allow_cfg_node and isinstance(value, CfgNode) ) def _merge_a_into_b(a, b, root, key_list): """Merge config dictionary a into config dictionary b, clobbering the options in b whenever they are also specified in a. """ _assert_with_logging( isinstance(a, CfgNode), "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode), ) _assert_with_logging( isinstance(b, CfgNode), "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode), ) for k, v_ in a.items(): full_key = ".".join(key_list + [k]) v = copy.deepcopy(v_) v = b._decode_cfg_value(v) if k in b: v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) # Recursively merge dicts if isinstance(v, CfgNode): try: _merge_a_into_b(v, b[k], root, key_list + [k]) except BaseException: raise else: b[k] = v elif b.is_new_allowed(): b[k] = v else: if root.key_is_deprecated(full_key): continue elif root.key_is_renamed(full_key): root.raise_key_rename_error(full_key) else: raise KeyError("Non-existent config key: {}".format(full_key)) def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): """Checks that `replacement`, which is intended to replace `original` is of the right type. The type is correct if it matches exactly or is one of a few cases in which the type can be easily coerced. """ original_type = type(original) replacement_type = type(replacement) # The types must match (with some exceptions) if replacement_type == original_type: return replacement # Cast replacement from from_type to to_type if the replacement and original # types match from_type and to_type def conditional_cast(from_type, to_type): if replacement_type == from_type and original_type == to_type: return True, to_type(replacement) else: return False, None # Conditionally casts # list <-> tuple casts = [(tuple, list), (list, tuple)] # For py2: allow converting from str (bytes) to a unicode string try: casts.append((str, unicode)) # noqa: F821 except Exception: pass for (from_type, to_type) in casts: converted, converted_value = conditional_cast(from_type, to_type) if converted: return converted_value raise ValueError( "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " "key: {}".format( original_type, replacement_type, original, replacement, full_key ) ) def _assert_with_logging(cond, msg): if not cond: logger.debug(msg) assert cond, msg def _load_module_from_file(name, filename): if _PY2: module = imp.load_source(name, filename) else: spec = importlib.util.spec_from_file_location(name, filename) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module ================================================ FILE: yacs/tests.py ================================================ import logging import tempfile import unittest import yacs.config from yacs.config import CfgNode as CN try: _ignore = unicode # noqa: F821 PY2 = True except Exception as _ignore: PY2 = False class SubCN(CN): pass def get_cfg(cls=CN): cfg = cls() cfg.NUM_GPUS = 8 cfg.TRAIN = cls() cfg.TRAIN.HYPERPARAMETER_1 = 0.1 cfg.TRAIN.SCALES = (2, 4, 8, 16) cfg.MODEL = cls() cfg.MODEL.TYPE = "a_foo_model" # Some extra stuff to test CfgNode.__str__ cfg.STR = cls() cfg.STR.KEY1 = 1 cfg.STR.KEY2 = 2 cfg.STR.FOO = cls() cfg.STR.FOO.KEY1 = 1 cfg.STR.FOO.KEY2 = 2 cfg.STR.FOO.BAR = cls() cfg.STR.FOO.BAR.KEY1 = 1 cfg.STR.FOO.BAR.KEY2 = 2 cfg.register_deprecated_key("FINAL_MSG") cfg.register_deprecated_key("MODEL.DILATION") cfg.register_renamed_key( "EXAMPLE.OLD.KEY", "EXAMPLE.NEW.KEY", message="Please update your config fil config file.", ) cfg.KWARGS = cls(new_allowed=True) cfg.KWARGS.z = 0 cfg.KWARGS.Y = cls() cfg.KWARGS.Y.X = 1 return cfg class TestCfgNode(unittest.TestCase): def test_immutability(self): # Top level immutable a = CN() a.foo = 0 a.freeze() with self.assertRaises(AttributeError): a.foo = 1 a.bar = 1 assert a.is_frozen() assert a.foo == 0 a.defrost() assert not a.is_frozen() a.foo = 1 assert a.foo == 1 # Recursively immutable a.level1 = CN() a.level1.foo = 0 a.level1.level2 = CN() a.level1.level2.foo = 0 a.freeze() assert a.is_frozen() with self.assertRaises(AttributeError): a.level1.level2.foo = 1 a.level1.bar = 1 assert a.level1.level2.foo == 0 class TestCfg(unittest.TestCase): def test_copy_cfg(self): cfg = get_cfg() cfg2 = cfg.clone() s = cfg.MODEL.TYPE cfg2.MODEL.TYPE = "dummy" assert cfg.MODEL.TYPE == s def test_merge_cfg_from_cfg(self): # Test: merge from clone cfg = get_cfg() s = "dummy0" cfg2 = cfg.clone() cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s # Test: merge from yaml s = "dummy1" cfg2 = CN.load_cfg(cfg.dump()) cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s # Test: merge with a valid key s = "dummy2" cfg2 = CN() cfg2.MODEL = CN() cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s # Test: merge with an invalid key s = "dummy3" cfg2 = CN() cfg2.FOO = CN() cfg2.FOO.BAR = s with self.assertRaises(KeyError): cfg.merge_from_other_cfg(cfg2) # Test: merge with converted type cfg2 = CN() cfg2.TRAIN = CN() cfg2.TRAIN.SCALES = [1] cfg.merge_from_other_cfg(cfg2) assert type(cfg.TRAIN.SCALES) is tuple assert cfg.TRAIN.SCALES[0] == 1 # Test str (bytes) <-> unicode conversion for py2 if PY2: cfg.A_UNICODE_KEY = u"foo" cfg2 = CN() cfg2.A_UNICODE_KEY = b"bar" cfg.merge_from_other_cfg(cfg2) assert type(cfg.A_UNICODE_KEY) == unicode # noqa: F821 assert cfg.A_UNICODE_KEY == u"bar" # Test: merge with invalid type cfg2 = CN() cfg2.TRAIN = CN() cfg2.TRAIN.SCALES = 1 with self.assertRaises(ValueError): cfg.merge_from_other_cfg(cfg2) def test_merge_cfg_from_file(self): with tempfile.NamedTemporaryFile(mode="wt") as f: cfg = get_cfg() f.write(cfg.dump()) f.flush() s = cfg.MODEL.TYPE cfg.MODEL.TYPE = "dummy" assert cfg.MODEL.TYPE != s cfg.merge_from_file(f.name) assert cfg.MODEL.TYPE == s def test_merge_cfg_from_list(self): cfg = get_cfg() opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2] assert len(cfg.TRAIN.SCALES) > 0 assert cfg.TRAIN.SCALES[0] != 100 assert cfg.MODEL.TYPE != "foobar" assert cfg.NUM_GPUS != 2 cfg.merge_from_list(opts) assert type(cfg.TRAIN.SCALES) is tuple assert len(cfg.TRAIN.SCALES) == 1 assert cfg.TRAIN.SCALES[0] == 100 assert cfg.MODEL.TYPE == "foobar" assert cfg.NUM_GPUS == 2 def test_deprecated_key_from_list(self): # You should see logger messages like: # "Deprecated config key (ignoring): MODEL.DILATION" cfg = get_cfg() opts = ["FINAL_MSG", "foobar", "MODEL.DILATION", 2] with self.assertRaises(AttributeError): _ = cfg.FINAL_MSG # noqa with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa cfg.merge_from_list(opts) with self.assertRaises(AttributeError): _ = cfg.FINAL_MSG # noqa with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa def test_nonexistant_key_from_list(self): cfg = get_cfg() opts = ["MODEL.DOES_NOT_EXIST", "IGNORE"] with self.assertRaises(AssertionError): cfg.merge_from_list(opts) def test_load_cfg_invalid_type(self): # FOO.BAR.QUUX will have type None, which is not allowed cfg_string = "FOO:\n BAR:\n QUUX:" with self.assertRaises(AssertionError): yacs.config.load_cfg(cfg_string) def test_deprecated_key_from_file(self): # You should see logger messages like: # "Deprecated config key (ignoring): MODEL.DILATION" cfg = get_cfg() with tempfile.NamedTemporaryFile("wt") as f: cfg2 = cfg.clone() cfg2.MODEL.DILATION = 2 f.write(cfg2.dump()) f.flush() with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa cfg.merge_from_file(f.name) with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa def test_renamed_key_from_list(self): cfg = get_cfg() opts = ["EXAMPLE.OLD.KEY", "foobar"] with self.assertRaises(AttributeError): _ = cfg.EXAMPLE.OLD.KEY # noqa with self.assertRaises(KeyError): cfg.merge_from_list(opts) def test_renamed_key_from_file(self): cfg = get_cfg() with tempfile.NamedTemporaryFile("wt") as f: cfg2 = cfg.clone() cfg2.EXAMPLE = CN() cfg2.EXAMPLE.RENAMED = CN() cfg2.EXAMPLE.RENAMED.KEY = "foobar" f.write(cfg2.dump()) f.flush() with self.assertRaises(AttributeError): _ = cfg.EXAMPLE.RENAMED.KEY # noqa with self.assertRaises(KeyError): cfg.merge_from_file(f.name) def test_load_cfg_from_file(self): cfg = get_cfg() with tempfile.NamedTemporaryFile("wt") as f: f.write(cfg.dump()) f.flush() with open(f.name, "rt") as f_read: yacs.config.load_cfg(f_read) def test_load_from_python_file(self): # Case 1: exports CfgNode cfg = get_cfg() cfg.merge_from_file("example/config_override.py") assert cfg.TRAIN.HYPERPARAMETER_1 == 0.9 # Case 2: exports dict cfg = get_cfg() cfg.merge_from_file("example/config_override_from_dict.py") assert cfg.TRAIN.HYPERPARAMETER_1 == 0.9 def test_invalid_type(self): cfg = get_cfg() with self.assertRaises(AssertionError): cfg.INVALID_KEY_TYPE = object() def test__str__(self): expected_str = """ KWARGS: Y: X: 1 z: 0 MODEL: TYPE: a_foo_model NUM_GPUS: 8 STR: FOO: BAR: KEY1: 1 KEY2: 2 KEY1: 1 KEY2: 2 KEY1: 1 KEY2: 2 TRAIN: HYPERPARAMETER_1: 0.1 SCALES: (2, 4, 8, 16) """.strip() cfg = get_cfg() assert str(cfg) == expected_str def test_new_allowed(self): cfg = get_cfg() cfg.merge_from_file("example/config_new_allowed.yaml") assert cfg.KWARGS.a == 1 assert cfg.KWARGS.B.c == 2 assert cfg.KWARGS.B.D.e == "3" def test_new_allowed_bad(self): cfg = get_cfg() with self.assertRaises(KeyError): cfg.merge_from_file("example/config_new_allowed_bad.yaml") class TestCfgNodeSubclass(unittest.TestCase): def test_merge_cfg_from_file(self): with tempfile.NamedTemporaryFile(mode="wt") as f: cfg = get_cfg(SubCN) f.write(cfg.dump()) f.flush() s = cfg.MODEL.TYPE cfg.MODEL.TYPE = "dummy" assert cfg.MODEL.TYPE != s cfg.merge_from_file(f.name) assert cfg.MODEL.TYPE == s def test_merge_cfg_from_list(self): cfg = get_cfg(SubCN) opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2] assert len(cfg.TRAIN.SCALES) > 0 assert cfg.TRAIN.SCALES[0] != 100 assert cfg.MODEL.TYPE != "foobar" assert cfg.NUM_GPUS != 2 cfg.merge_from_list(opts) assert type(cfg.TRAIN.SCALES) is tuple assert len(cfg.TRAIN.SCALES) == 1 assert cfg.TRAIN.SCALES[0] == 100 assert cfg.MODEL.TYPE == "foobar" assert cfg.NUM_GPUS == 2 def test_merge_cfg_from_cfg(self): cfg = get_cfg(SubCN) cfg2 = get_cfg(SubCN) s = "dummy0" cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s # Test: merge from yaml s = "dummy1" cfg2 = SubCN.load_cfg(cfg.dump()) cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s if __name__ == "__main__": logging.basicConfig() yacs_logger = logging.getLogger("yacs.config") yacs_logger.setLevel(logging.DEBUG) unittest.main()