Repository: D-X-Y/NATS-Bench
Branch: main
Commit: 1d4a304ad190
Files: 34
Total size: 199.1 KB

Directory structure:
gitextract_fmzpgwc0/
├── .github/
│   ├── CODE-OF-CONDUCT.md
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug-report.md
│   │   └── question.md
│   └── workflows/
│       └── ci.yml
├── .gitignore
├── LICENSE.md
├── README.md
├── fake_torch_dir/
│   ├── NATS-sss-v1_0-50262-simple/
│   │   ├── 000000.pickle.pbz2
│   │   ├── 000011.pickle.pbz2
│   │   ├── 000284.pickle.pbz2
│   │   └── meta.pickle.pbz2
│   └── NATS-tss-v1_0-3ffb9-simple/
│       ├── 000000.pickle.pbz2
│       ├── 000011.pickle.pbz2
│       ├── 000284.pickle.pbz2
│       └── meta.pickle.pbz2
├── nats_bench/
│   ├── __init__.py
│   ├── api_size.py
│   ├── api_topology.py
│   ├── api_utils.py
│   └── genotype_utils.py
├── notebooks/
│   ├── README.md
│   ├── create-query-sss.ipynb
│   ├── find-largest.ipynb
│   ├── issue-11.ipynb
│   ├── issue-12.ipynb
│   ├── issue-21.ipynb
│   ├── issue-27.ipynb
│   ├── issue-30.ipynb
│   ├── issue-33.ipynb
│   ├── issue-36.ipynb
│   ├── issue-7.ipynb
│   └── random-search.ipynb
├── setup.py
└── tests/
    └── api_test.py


================================================
FILE CONTENTS
================================================

================================================
FILE: .github/CODE-OF-CONDUCT.md
================================================
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at dongxuanyi888@gmail.com.
All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq


================================================
FILE: .github/ISSUE_TEMPLATE/bug-report.md
================================================
---
name: Bug Report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Please provide a small script to reproduce the behavior:
```
code to reproduce the bug
```
Please let me know your OS, Python version, and PyTorch version.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.


================================================
FILE: .github/ISSUE_TEMPLATE/question.md
================================================
---
name: Questions about NATS-Bench
about: Ask questions about or discuss on NATS-Bench
title: ''
labels: ''
assignees: ''

---

**Describe the Question**
A clear and concise description of the question.

- Is it about the topology search space in NATS-Bench?

- Is it about the size search space in NATS-Bench?
- Which figure or table are you referring to in the paper?


================================================
FILE: .github/workflows/ci.yml
================================================
name: Run Python Tests
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  build:
    strategy:
      matrix:
        os: [ubuntu-18.04, ubuntu-20.04, macos-latest]
        python-version: [3.6, 3.7, 3.8, 3.9]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Lint with Black
        run: |
          cd ..
          if [ "$RUNNER_OS" == "Windows" ]; then
            python.exe -m pip install black
            python.exe -m black NATS-Bench/nats_bench -l 88 --check --diff
            python.exe -m black NATS-Bench/tests -l 88 --check --diff
          else
            python -m pip install black
            python --version
            python -m black --version
            echo $PWD
            ls
            python -m black NATS-Bench/nats_bench -l 88 --check --diff --verbose
            python -m black NATS-Bench/tests -l 88 --check --diff --verbose
          fi
        shell: bash

      - name: Install nats_bench from source
        run: |
          pip install .

      - name: Run tests with pytest
        run: |
          export FAKE_TORCH_HOME="fake_torch_dir"
          python -m pip install pytest
          python -m pytest . --durations=0

      - name: Install nats_bench from pip with tests
        run: |
          pip uninstall -y nats_bench
          python -m pip install nats_bench
          export FAKE_TORCH_HOME="fake_torch_dir"
          python -m pip install pytest
          python -m pytest . --durations=0


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# Pycharm project
.idea
snapshots
*.pytorch
*.tar.bz
data
.*.swp
*.sh
main_main.py
dist
build
*.egg-info
.DS_Store


================================================
FILE: LICENSE.md
================================================
MIT License

Copyright (c) since 2020 Xuanyi Dong (GitHub: https://github.com/D-X-Y)

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
# [NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size](https://arxiv.org/abs/2009.00437)

Xuanyi Dong, Lu Liu, Katarzyna Musial, Bogdan Gabrys
in IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI), 2021

**Abstract**: Neural architecture search (NAS) has attracted a lot of attention and has been illustrated to bring tangible benefits in a large number of applications in the past few years. Network topology and network size have been regarded as two of the most important aspects for the performance of deep learning models and the community has spawned lots of searching algorithms for both of those aspects of the neural architectures. However, the performance gain from these searching algorithms is achieved under different search spaces and training setups. This makes the overall performance of the algorithms incomparable and the improvement from a sub-module of the searching model unclear. In this paper, we propose NATS-Bench, a unified benchmark on searching for both topology and size, for (almost) any up-to-date NAS algorithm. NATS-Bench includes the search space of 15,625 neural cell candidates for architecture topology and 32,768 for architecture size on three datasets. We analyze the validity of our benchmark in terms of various criteria and performance comparison of all candidates in the search space. We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided. This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.
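
The two candidate counts in the abstract follow from the search-space definitions returned by `search_space_info` in `nats_bench/__init__.py`: five candidate operations on each of the six edges of a 4-node cell (topology), and eight channel choices for each of five layers (size). The short check below only redoes that arithmetic with the values copied from that function; the constant and helper names are illustrative and are not part of the `nats_bench` API.

```python
# Values copied from search_space_info() in nats_bench/__init__.py.
TSS_OP_NAMES = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
TSS_NUM_NODES = 4
SSS_CANDIDATES = [8, 16, 24, 32, 40, 48, 56, 64]
SSS_NUM_LAYERS = 5


def num_tss_candidates() -> int:
    # Each pair of nodes (i < j) in the cell is joined by one edge,
    # and every edge independently picks one candidate operation.
    num_edges = TSS_NUM_NODES * (TSS_NUM_NODES - 1) // 2  # 6 edges for 4 nodes
    return len(TSS_OP_NAMES) ** num_edges


def num_sss_candidates() -> int:
    # Every layer independently picks one of the channel counts.
    return len(SSS_CANDIDATES) ** SSS_NUM_LAYERS


print(num_tss_candidates())  # 15625 = 5 ** 6
print(num_sss_candidates())  # 32768 = 8 ** 5
```

The integer division keeps the edge count exact and avoids `math.comb`, which is unavailable on the Python 3.6/3.7 versions covered by the CI matrix.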
**You can use `pip install nats_bench` to install the library of NATS-Bench or install it from source with `pip install .`.**

If you are seeking how to re-create NATS-Bench from scratch or reproduce benchmarked results, please use [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) and see these [instructions](https://github.com/D-X-Y/NATS-Bench#how-to-re-create-nats-bench-from-scratch).

If you have questions, please ask [here](https://github.com/D-X-Y/NATS-Bench/issues) or [email me](mailto:dongxuanyi888@gmail.com) :)

[Figure: the main differences between `NATS-Bench`, `NAS-Bench-101`, and `NAS-Bench-201`.]

The topology search space ($\mathcal{S}_t$) in `NATS-Bench` is the same as in `NAS-Bench-201`, while we upgrade it with the results of more runs for each architecture candidate, and the benchmarked NAS algorithms use better hyperparameters.

## Preparation and Download

**Step-1: download raw vision datasets.** (You can skip this step if you do not use weight-sharing NAS or re-create NATS-Bench.)

In NATS-Bench, we (create and) use three image datasets -- CIFAR-10, CIFAR-100, and ImageNet16-120.
For more details, please see Sec-3.2 in [the NATS-Bench paper](https://arxiv.org/pdf/2009.00437.pdf).
To download these three datasets, please find them at [Google Drive](https://drive.google.com/drive/folders/1T3UIyZXUhMmIuJLOBMIYKAsJknAtrrO4?usp=sharing).
To create the `ImageNet16-120` PyTorch dataset, please call [AutoDL-Projects/lib/datasets/ImageNet16](https://github.com/D-X-Y/AutoDL-Projects/blob/main/xautodl/datasets/get_dataset_with_transform.py#L222-L225), by using:
```
# ImageNet16 is defined in AutoDL-Projects (xautodl);
# root, train_transform, and test_transform are provided by the user.
from xautodl.datasets.DownsampledImageNet import ImageNet16

train_data = ImageNet16(root, True, train_transform, 120)
test_data = ImageNet16(root, False, test_transform, 120)
```

**Step-2: download benchmark files of NATS-Bench.**
The **latest** benchmark file of NATS-Bench can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zjB6wMANiKwB2A1yil2hQ8H_qyeSe2yt?usp=sharing).
After downloading `NATS-[tss/sss]-[version]-[md5sum]-simple.tar`, please uncompress it by using `tar xvf [file_name]`.
We highly recommend putting the downloaded benchmark file (`NATS-sss-v1_0-50262.pickle.pbz2` / `NATS-tss-v1_0-3ffb9.pickle.pbz2`) or uncompressed archive (`NATS-sss-v1_0-50262-simple` / `NATS-tss-v1_0-3ffb9-simple`) into `$TORCH_HOME`.
In this way, our API will automatically find the path to these benchmark files, which is convenient for users. Otherwise, you need to manually specify the file when creating the benchmark instance.

The history of benchmark files is as follows, where `tss` indicates the topology search space and `sss` indicates the size search space.
The benchmark file is used when creating the NATS-Bench instance with `fast_mode=False`.
The archive is used when `fast_mode=True`, where the archive is a directory containing 15,625 files for tss or 32,768 files for sss.
Each file contains all the information for a specific architecture candidate.
The `full archive` is similar to `archive`, while each file in `full archive` contains **the trained weights**.
Since the full archive is too large, we use `split -b 30G file_name file_name` to split it into multiple 30G chunks.
To merge the chunks into the original full archive, you can use `cat file_name* > file_name`.

| Date | benchmark file (tss) | (unpacked benchmark file) archive (tss) | full archive (tss) | benchmark file (sss) | (unpacked benchmark file) archive (sss) | full archive (sss) |
|:-----------|:---------------------:|:-------------:|:------------------:|:-------------------------------:|:--------------------------:|:------------------:|
| 2020.08.31 | [NATS-tss-v1_0-3ffb9.pickle.pbz2](https://drive.google.com/file/d/1vzyK0UVH2D3fTpa1_dSWnp1gvGpAxRul/view?usp=sharing) | [NATS-tss-v1_0-3ffb9-simple.tar](https://drive.google.com/file/d/17_saCsj_krKjlCBLOJEpNtzPXArMCqxU/view?usp=sharing) | [NATS-tss-v1_0-3ffb9-full](https://drive.google.com/drive/folders/17S2Xg_rVkUul4KuJdq0WaWoUuDbo8ZKB?usp=sharing) | [NATS-sss-v1_0-50262.pickle.pbz2](https://drive.google.com/file/d/1IabIvzWeDdDAWICBzFtTCMXxYWPIOIOX/view?usp=sharing) | [NATS-sss-v1_0-50262-simple.tar](https://drive.google.com/file/d/1scOMTUwcQhAMa_IMedp9lTzwmgqHLGgA/view?usp=sharing) | [NATS-sss-v1_0-50262-full](https://drive.google.com/drive/folders/1xutPQJ4bHoUV1EMArsPD0c1bUqvtMuYY?usp=sharing) |
| 2021.04.22 (Baidu-Pan) | [NATS-tss-v1_0-3ffb9.pickle.pbz2](https://pan.baidu.com/s/10z20F5s2RRPzGwRO40fLTw) (code: 8duj) | [NATS-tss-v1_0-3ffb9-simple.tar](https://pan.baidu.com/s/1vOnrHLxCB4y8cxUDrHUYAg) (code: tu1e) | [NATS-tss-v1_0-3ffb9-full](https://pan.baidu.com/s/1qbPNlI8Y1I29qMdxTo_ycA) (code: ssub) | [NATS-sss-v1_0-50262.pickle.pbz2](https://pan.baidu.com/s/1M1UaXr6y1D_RqEYg95YJcA) (code: za2h) | [NATS-sss-v1_0-50262-simple.tar](https://pan.baidu.com/s/1ek-b89Pw2qdm9MP6KKkErA) (code: e4t9) | [NATS-sss-v1_0-50262-full](https://pan.baidu.com/s/1bIruQd9pPeyArej_wttg_A) (code: htif) |

These benchmark files (without pretrained weights) can also be downloaded from [Dropbox](https://www.dropbox.com/sh/ceeo70u1buow681/AAC2M-SbKOxiIqpB0UCgXNxja?dl=0), [OneDrive](https://1drv.ms/u/s!Aqkc27lrowWDf6SvuIkSXx0UQaI?e=nfvM5r) or [Baidu-Pan (extract code: h6pm)](https://pan.baidu.com/s/144VC2BDm6iXbAVzMUpqO7A).

For the full checkpoints in `NATS-*ss-*-full`, we split the file into multiple parts (`NATS-*ss-*-full.tara*`) since they are too large to upload. Each file is about `30GB`.
For Baidu Pan, since it restricts the maximum size of each file, we further split `NATS-*ss-*-full.tara*` into `NATS-*ss-*-full.tara*-aa` and `NATS-*ss-*-full.tara*-ab`.
All splits are created by the command `split`.

**Note:** if you encounter `quota exceeded` errors when downloading from Google Drive, please try to (1) log in to your personal Google account, (2) right-click-copy the files to your personal Google Drive, and (3) download from your personal Google Drive.

## Usage

See more examples at [notebooks](notebooks).
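
The Preparation section recommends placing the benchmark files in `$TORCH_HOME` so that `create(None, ...)` can find them automatically. The sketch below shows one way such a lookup could work. It is not the library's own implementation: `find_benchmark_path` is a hypothetical helper, and falling back to `~/.torch` when `TORCH_HOME` is unset is an assumption. Following the Preparation section, it expects the `-simple` archive directory for `fast_mode=True` and the single `.pickle.pbz2` file for `fast_mode=False`.

```python
import os
from typing import Optional


def find_benchmark_path(base_name: str, fast_mode: bool, torch_home: Optional[str] = None) -> str:
    """Hypothetical helper: locate a NATS-Bench benchmark file or archive.

    base_name is a name such as "NATS-tss-v1_0-3ffb9" or "NATS-sss-v1_0-50262".
    """
    if torch_home is None:
        # Assumption: fall back to ~/.torch when $TORCH_HOME is not set.
        default_home = os.path.join(os.path.expanduser("~"), ".torch")
        torch_home = os.environ.get("TORCH_HOME", default_home)
    if fast_mode:
        # fast_mode=True reads per-architecture files from the "-simple" directory.
        path = os.path.join(torch_home, base_name + "-simple")
        found = os.path.isdir(path)
    else:
        # fast_mode=False reads the single compressed benchmark file.
        path = os.path.join(torch_home, base_name + ".pickle.pbz2")
        found = os.path.isfile(path)
    if not found:
        raise FileNotFoundError("cannot find {}".format(path))
    return path
```

Because `create` also accepts an explicit file or directory path as its first argument, the returned path could be passed directly, e.g. `create(find_benchmark_path("NATS-tss-v1_0-3ffb9", True), "tss", fast_mode=True)`.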
#### 1, create the benchmark instance:
```
from nats_bench import create

# Create the API instance for the size search space in NATS
api = create(None, 'sss', fast_mode=True, verbose=True)

# Create the API instance for the topology search space in NATS
api = create(None, 'tss', fast_mode=True, verbose=True)
```

#### 2, query the performance:
```
# Show the architecture topology string of the 12-th architecture
# For the topology search space, the string is interpreted as
# arch = '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(
#         edge_node_0_to_node_1,
#         edge_node_0_to_node_2,
#         edge_node_1_to_node_2,
#         edge_node_0_to_node_3,
#         edge_node_1_to_node_3,
#         edge_node_2_to_node_3,
#         )
# For the size search space, the string is interpreted as
# arch = '{}:{}:{}:{}:{}'.format(out_channel_of_1st_conv_layer,
#                                out_channel_of_1st_cell_stage,
#                                out_channel_of_1st_residual_block,
#                                out_channel_of_2nd_cell_stage,
#                                out_channel_of_2nd_residual_block,
#                                )
architecture_str = api.arch(12)
print(architecture_str)

# Query the loss / accuracy / time for the 1234-th candidate architecture on CIFAR-10
# info is a dict, where you can easily figure out the meaning by its keys
info = api.get_more_info(1234, 'cifar10')

# Query the flops, params, latency. info is a dict.
info = api.get_cost_info(12, 'cifar10')

# Simulate the training of the 1224-th candidate:
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
    1224, dataset='cifar10', hp='12'
)
```

#### 3, create the instance of an architecture candidate in `NATS-Bench`:
```
# Create the instance of the 12-th candidate for CIFAR-10.
# To keep the NATS-Bench repo concise, we did not include any model-related code here because it relies on PyTorch.
# The package of [models] is defined at https://github.com/D-X-Y/AutoDL-Projects,
# so one needs to import this package first.
import xautodl
from xautodl.models import get_cell_based_tiny_net

config = api.get_net_config(12, 'cifar10')
network = get_cell_based_tiny_net(config)

# Load the pre-trained weights: params is a dict, where the key is the seed and the value is the weights.
params = api.get_net_param(12, 'cifar10', None)
network.load_state_dict(next(iter(params.values())))
```

#### 4, others:
```
# Clear the parameters of the 12-th candidate.
api.clear_params(12)

# Reload all information of the 12-th candidate.
api.reload(index=12)
```

Please see [`api_test.py`](https://github.com/D-X-Y/NATS-Bench/blob/main/tests/api_test.py) for more examples.
```
from nats_bench import api_test
api_test.test_nats_bench_tss('NATS-tss-v1_0-3ffb9-simple')
api_test.test_nats_bench_tss('NATS-sss-v1_0-50262-simple')
```

## How to Re-create NATS-Bench from Scratch

**You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to re-create NATS-Bench from scratch.**

### The Size Search Space

The following command will train all architecture candidates in the size search space with 90 epochs and use the random seed of `777`.
If you want to use a different number of training epochs, please replace `90` with it, such as `01` or `12`.
```
bash ./scripts/NATS-Bench/train-shapes.sh 00000-32767 90 777
```
The checkpoints of all candidates are located at `output/NATS-Bench-size` by default.

After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file.
```
python exps/NATS-Bench/sss-collect.py
```

### The Topology Search Space

The following command will train all architecture candidates in the topology search space with 200 epochs and use the random seeds `777`/`888`/`999`.
If you want to use a different number of training epochs, please replace `200` with it, such as `12`.
```
bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999'
```
The checkpoints of all candidates are located at `output/NATS-Bench-topology` by default.

After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file.
```
python exps/NATS-Bench/tss-collect.py
```

## To Reproduce 13 Baseline NAS Algorithms in NATS-Bench

**You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to run 13 baseline NAS methods.**

Here is a brief introduction on how to run each algorithm ([NATS-algos](https://github.com/D-X-Y/AutoDL-Projects/tree/main/exps/NATS-algos)).

### Reproduce NAS methods on the topology search space

Please use the following commands to run different NAS methods on the topology search space:
```
Four multi-trial based methods:
python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01
python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss
python ./exps/NATS-algos/bohb.py --dataset cifar100 --search_space tss --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3

DARTS (first order):
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1

DARTS (second order):
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2

GDAS:
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas

SETN:
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn

Random Search with Weight Sharing:
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random

ENAS:
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
```

### Reproduce NAS methods on the size search space

Please use the following commands to run different NAS methods on the size search space:
```
Four multi-trial based methods:
python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space sss
python ./exps/NATS-algos/bohb.py --dataset cifar100 --search_space sss --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3

Run Transformable Architecture Search (TAS), proposed in Network Pruning via Transformable Architecture Search, NeurIPS 2019:
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777

Run the channel search strategy in FBNet-V2 -- masking + Gumbel-Softmax:
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777

Run the channel search strategy in TuNAS -- masking + sampling:
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
```

### Final Discovered Architectures for Each Algorithm

The architecture index can be found by using `api.query_index_by_arch(architecture_string)`.
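
The architecture strings listed below follow the topology format described in the Usage section: the `+`-separated groups hold the incoming edges of nodes 1, 2, and 3, and a token `op~j` inside the group for node i is the operation on the edge from node j to node i. The sketch below parses that format into edges, assuming nothing beyond it; `parse_tss_arch` and `count_ops` are hypothetical helpers, not part of `nats_bench` (the package's own `topology_str2structure` utility, added in v1.4, is not shown in this section).

```python
from collections import Counter
from typing import List, Tuple

# One edge of a cell: (operation name, source node, target node).
Edge = Tuple[str, int, int]


def parse_tss_arch(arch: str) -> List[Edge]:
    """Split a topology string like '|op~0|+|op~0|op~1|+|op~0|op~1|op~2|' into edges."""
    edges = []
    # The '+'-separated groups list the incoming edges of nodes 1, 2, and 3.
    for target, group in enumerate(arch.split("+"), start=1):
        for token in group.strip("|").split("|"):
            op, source = token.split("~")  # "op~j" is an edge from node j
            edges.append((op, int(source), target))
    return edges


def count_ops(arch: str) -> Counter:
    # How many of the cell's edges use each operation.
    return Counter(op for op, _, _ in parse_tss_arch(arch))


# First GDAS architecture found on CIFAR-10, copied from the list below.
gdas = "|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|nor_conv_3x3~2|"
for edge in parse_tss_arch(gdas):
    print(edge)
print(count_ops(gdas))
```

Applied to the DARTS results on CIFAR-10, `count_ops` reports all six edges as `skip_connect`, which makes that collapse easy to spot programmatically.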
The final discovered architectures on CIFAR-10:
```
DARTS (first order):
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|

DARTS (second order):
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|

GDAS:
|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|nor_conv_3x3~2|
|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|
|avg_pool_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|nor_conv_3x3~0|nor_conv_1x1~1|nor_conv_1x1~2|
```

The final discovered architectures on CIFAR-100:
```
DARTS (V1):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|none~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|none~2|
|skip_connect~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|nor_conv_3x3~2|

DARTS (V2):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|skip_connect~2|
|skip_connect~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|none~2|
|skip_connect~0|+|nor_conv_1x1~0|none~1|+|nor_conv_3x3~0|skip_connect~1|none~2|

GDAS:
|nor_conv_3x3~0|+|nor_conv_1x1~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|
|avg_pool_3x3~0|+|nor_conv_1x1~0|none~1|+|nor_conv_3x3~0|avg_pool_3x3~1|nor_conv_1x1~2|
|avg_pool_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_3x3~0|nor_conv_1x1~1|nor_conv_1x1~2|
```

The final discovered architectures on ImageNet-16-120:
```
DARTS (V1):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_1x1~2|

DARTS (V2):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|skip_connect~2|

GDAS:
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|
```

## Others

We use [`black`](https://github.com/psf/black) as the Python code formatter. Please use `black . -l 88`, which matches the line length checked in CI.

## Citation

If you find that NATS-Bench helps your research, please consider citing it:
```
@article{dong2021nats,
  title   = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
  author  = {Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
  doi     = {10.1109/TPAMI.2021.3054824},
  journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
  year    = {2021},
  note    = {\mbox{doi}:\url{10.1109/TPAMI.2021.3054824}}
}
@inproceedings{dong2020nasbench201,
  title     = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
  author    = {Dong, Xuanyi and Yang, Yi},
  booktitle = {International Conference on Learning Representations (ICLR)},
  url       = {https://openreview.net/forum?id=HJxyZkBKDr},
  year      = {2020}
}
```


================================================
FILE: nats_bench/__init__.py
================================================
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
"""The official Application Programming Interface (API) for NATS-Bench."""
from typing import Text, Optional

from nats_bench.api_size import NATSsize
from nats_bench.api_topology import NATStopology
from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import pickle_save
from nats_bench.api_utils import ResultsCount


NATS_BENCH_API_VERSIONs = [
    "v1.0",  # [2020.08.31] initialize
    "v1.1",  # [2020.12.20] add unit tests
    "v1.2",  # [2021.03.17] black re-formulate
    "v1.3",  # [2021.04.08] fix find_best issue for fast_mode=True
    "v1.4",  # [2021.04.30] add topology_str2structure
    "v1.5",  # [2021.12.09] make simulate_train_eval more robust
    "v1.6",  # [2022.01.19] fix the inconsistent flop/params which is caused by a legacy (weight migration) issue
    "v1.7",  # [2022.03.25] relax enforce_all kwargs and add a test
    "v1.8",  # [2022.10.06] fix bugs at issues/44
]
NATS_BENCH_SSS_NAMEs = ("sss", "size")
NATS_BENCH_TSS_NAMEs = ("tss", "topology")


def version():
    return NATS_BENCH_API_VERSIONs[-1]


def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
    """Create the instance for NATS API.

    Args:
      file_path_or_dict: None or a file path or a directory path.
      search_space: This is a string indicating the search space in NATS-Bench.
      fast_mode: If True, we will not load all the data at initialization,
        instead, the data for each candidate architecture will be loaded when
        querying it; If False, we will load all the data during initialization.
      verbose: This is a flag to indicate whether to log additional information.

    Raises:
      ValueError: If the matched search space description is not found.

    Returns:
      The created NATS-Bench API.
    """
    if search_space in NATS_BENCH_TSS_NAMEs:
        return NATStopology(file_path_or_dict, fast_mode, verbose)
    elif search_space in NATS_BENCH_SSS_NAMEs:
        return NATSsize(file_path_or_dict, fast_mode, verbose)
    else:
        raise ValueError("invalid search space : {:}".format(search_space))


def search_space_info(main_tag: Text, aux_tag: Optional[Text]):
    """Obtain the search space information."""
    nats_sss = dict(candidates=[8, 16, 24, 32, 40, 48, 56, 64], num_layers=5)
    nats_tss = dict(
        op_names=[
            "none",
            "skip_connect",
            "nor_conv_1x1",
            "nor_conv_3x3",
            "avg_pool_3x3",
        ],
        num_nodes=4,
    )
    if main_tag == "nats-bench":
        if aux_tag in NATS_BENCH_SSS_NAMEs:
            return nats_sss
        elif aux_tag in NATS_BENCH_TSS_NAMEs:
            return nats_tss
        else:
            raise ValueError("Unknown auxiliary tag: {:}".format(aux_tag))
    elif main_tag == "nas-bench-201":
        if aux_tag is not None:
            raise ValueError("For NAS-Bench-201, the auxiliary tag should be None.")
        return nats_tss
    else:
        raise ValueError("Unknown main tag: {:}".format(main_tag))


================================================
FILE: nats_bench/api_size.py
================================================
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# The history of benchmark files is as follows,                              #
# where the format is (the name is NATS-sss-[version]-[md5].pickle.pbz2)     #
#     [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2                           #
##############################################################################
# pylint: disable=line-too-long
"""The API for size search space in NATS-Bench."""
import collections
import copy
import os
import random
from typing import Dict, Optional, Text, Union, Any

from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import get_torch_home
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import time_string


ALL_BASE_NAMES = ["NATS-sss-v1_0-50262"]


def print_information(information, extra_info=None, show=False):
    """print out the information of a given ArchResults."""
    dataset_names = information.get_dataset_names()
    strings = [
        information.arch_str,
        "datasets : {:}, extra-info : {:}".format(dataset_names, extra_info),
    ]

    def metric2str(loss, acc):
        return "loss = {:.3f} & top1 = {:.2f}%".format(loss, acc)

    for dataset in dataset_names:
        metric = information.get_compute_costs(dataset)
        flop, param, latency = metric["flops"], metric["params"], metric["latency"]
        str1 = "{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.".format(
            dataset,
            flop,
            param,
            "{:.2f}".format(latency * 1000)
            if latency is not None and latency > 0
            else None,
        )
        train_info = information.get_metrics(dataset, "train")
        if dataset == "cifar10-valid":
            valid_info = information.get_metrics(dataset, "x-valid")
            test__info = information.get_metrics(dataset, "ori-test")
            str2 = "{:14s} train : [{:}], valid : [{:}], test : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
                metric2str(valid_info["loss"], valid_info["accuracy"]),
                metric2str(test__info["loss"], test__info["accuracy"]),
            )
        elif dataset == "cifar10":
            test__info = information.get_metrics(dataset, "ori-test")
            str2 = "{:14s} train : [{:}], test : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
                metric2str(test__info["loss"], test__info["accuracy"]),
            )
        else:
            valid_info = information.get_metrics(dataset, "x-valid")
            test__info = information.get_metrics(dataset, "x-test")
            str2 = "{:14s} train : [{:}], valid : [{:}], test : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
metric2str(valid_info["loss"], valid_info["accuracy"]), metric2str(test__info["loss"], test__info["accuracy"]), ) strings += [str1, str2] if show: print("\n".join(strings)) return strings class NATSsize(NASBenchMetaAPI): """This is the class for the API of size search space in NATS-Bench.""" def __init__( self, file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None, fast_mode: bool = False, verbose: bool = True, ): """The initialization function that takes the dataset file path (or a dict loaded from that path) as input.""" self._all_base_names = ALL_BASE_NAMES self.filename = None self._search_space_name = "size" self._fast_mode = fast_mode self._archive_dir = None self._full_train_epochs = 90 self.reset_time() if file_path_or_dict is None: if self._fast_mode: self._archive_dir = os.path.join( get_torch_home(), "{:}-simple".format(ALL_BASE_NAMES[-1]) ) else: file_path_or_dict = os.path.join( get_torch_home(), "{:}.{:}".format(ALL_BASE_NAMES[-1], PICKLE_EXT) ) print( "{:} Try to use the default NATS-Bench (size) path from " "fast_mode={:} and path={:}.".format( time_string(), self._fast_mode, file_path_or_dict ) ) if isinstance(file_path_or_dict, str): file_path_or_dict = str(file_path_or_dict) if verbose: print( "{:} Try to create the NATS-Bench (size) api " "from {:} with fast_mode={:}".format( time_string(), file_path_or_dict, fast_mode ) ) if not nats_is_file(file_path_or_dict) and not nats_is_dir( file_path_or_dict ): raise ValueError( "{:} is neither a file or a dir.".format(file_path_or_dict) ) self.filename = os.path.basename(file_path_or_dict) if fast_mode: if nats_is_file(file_path_or_dict): raise ValueError( "fast_mode={:} must feed the path for directory " ": {:}".format(fast_mode, file_path_or_dict) ) else: self._archive_dir = file_path_or_dict else: if nats_is_dir(file_path_or_dict): raise ValueError( "fast_mode={:} must feed the path for file " ": {:}".format(fast_mode, file_path_or_dict) ) else: file_path_or_dict = 
pickle_load(file_path_or_dict) elif isinstance(file_path_or_dict, dict): file_path_or_dict = copy.deepcopy(file_path_or_dict) self.verbose = verbose if isinstance(file_path_or_dict, dict): keys = ("meta_archs", "arch2infos", "evaluated_indexes") for key in keys: if key not in file_path_or_dict: raise ValueError("Can not find key[{:}] in the dict".format(key)) self.meta_archs = copy.deepcopy(file_path_or_dict["meta_archs"]) # NOTE(xuanyidong): This is a dict mapping each architecture to a dict, # where the key is #epochs and the value is ArchResults self.arch2infos_dict = collections.OrderedDict() self._avaliable_hps = set() for xkey in sorted(list(file_path_or_dict["arch2infos"].keys())): all_infos = file_path_or_dict["arch2infos"][xkey] hp2archres = collections.OrderedDict() for hp_key, results in all_infos.items(): hp2archres[hp_key] = ArchResults.create_from_state_dict(results) self._avaliable_hps.add( hp_key ) # save the avaliable hyper-parameter self.arch2infos_dict[xkey] = hp2archres self.evaluated_indexes = set(file_path_or_dict["evaluated_indexes"]) elif self.archive_dir is not None: benchmark_meta = pickle_load( "{:}/meta.{:}".format(self.archive_dir, PICKLE_EXT) ) self.meta_archs = copy.deepcopy(benchmark_meta["meta_archs"]) self.arch2infos_dict = collections.OrderedDict() self._avaliable_hps = set() self.evaluated_indexes = set() else: raise ValueError( "file_path_or_dict [{:}] must be a dict or archive_dir " "must be set".format(type(file_path_or_dict)) ) self.archstr2index = {} for idx, arch in enumerate(self.meta_archs): if arch in self.archstr2index: raise ValueError( "This [{:}]-th arch {:} already in the " "dict ({:}).".format(idx, arch, self.archstr2index[arch]) ) self.archstr2index[arch] = idx if self.verbose: print( "{:} Create NATS-Bench (size) done with {:}/{:} architectures " "avaliable.".format( time_string(), len(self.evaluated_indexes), len(self.meta_archs) ) ) @property def is_size(self): return True @property def is_topology(self): 
return False @property def full_epochs_in_paper(self): return 90 def query_info_str_by_arch(self, arch, hp: Text = "12"): """Query the information of a specific architecture. Args: arch: it can be an architecture index or an architecture string. hp: the hyper-parameter indicator, which can be 01, 12, or 90. These three configurations differ in the number of training epochs. Returns: ArchResults instance """ if self.verbose: print( "{:} Call query_info_str_by_arch with arch={:} " "and hp={:}".format(time_string(), arch, hp) ) return self._query_info_str_by_arch(arch, hp, print_information) def get_more_info( self, index, dataset, iepoch=None, hp: Text = "12", is_random: bool = True ): """Return the metric for the `index`-th architecture. Args: index: the architecture index. dataset: 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set 'cifar100' : using the proposed train set of CIFAR-100 as the training set 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set iepoch: the index of the training epoch, from 0 to int(hp) - 1 (e.g., 0 to 11 for hp=12). When iepoch=None, it returns the metric for the last training epoch. When iepoch=11, it returns the metric for the 11-th training epoch (counting from 0). hp: indicates different hyper-parameters for training. When hp=01, the network is trained for 1 epoch and the LR decays from 0.1 to 0 within that epoch. When hp=12, the network is trained for 12 epochs and the LR decays from 0.1 to 0 within 12 epochs. When hp=90, the network is trained for 90 epochs and the LR decays from 0.1 to 0 within 90 epochs. is_random: When is_random=True, the performance of one randomly selected trial is returned. When is_random=False, the performance of all trials is averaged. Returns: a dict, where key is the metric name and value is its value.
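Example (a hedged sketch with made-up accuracies, not the ArchResults implementation): select_trial_metric is a hypothetical helper over a plain {seed: value} dict that illustrates the is_random options. True draws one trial seed at random, False averages all trials, and an integer selects the trial with that seed, mirroring how this method forwards the randomly chosen seed to get_metrics. The bool check must come first, because isinstance(True, int) is also True in Python.

```python
import random


def select_trial_metric(seed2acc, is_random=True, rng=random):
    # Hypothetical helper: seed2acc maps a trial seed to one metric value.
    # Check bool first, because True/False are also instances of int.
    if isinstance(is_random, bool):
        if is_random:
            seed = rng.choice(sorted(seed2acc))  # one randomly chosen trial
            return seed2acc[seed]
        return sum(seed2acc.values()) / len(seed2acc)  # average over all trials
    return seed2acc[is_random]  # an explicit seed selects that trial


if __name__ == '__main__':
    trials = {777: 91.0, 888: 93.0}
    print(select_trial_metric(trials, is_random=False))  # mean of both trials
    print(select_trial_metric(trials, is_random=888))
    print(select_trial_metric(trials, is_random=True, rng=random.Random(0)))
```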
""" if self.verbose: print( "{:} Call the get_more_info function with index={:}, dataset={:}, " "iepoch={:}, hp={:}, and is_random={:}.".format( time_string(), index, dataset, iepoch, hp, is_random ) ) index = self.query_index_by_arch( index ) # To avoid the input is a string or an instance of a arch object self._prepare_info(index) if index not in self.arch2infos_dict: raise ValueError("Did not find {:} from arch2infos_dict.".format(index)) archresult = self.arch2infos_dict[index][str(hp)] # if randomly select one trial, select the seed at first if isinstance(is_random, bool) and is_random: seeds = archresult.get_dataset_seeds(dataset) is_random = random.choice(seeds) # collect the training information train_info = archresult.get_metrics( dataset, "train", iepoch=iepoch, is_random=is_random ) total = train_info["iepoch"] + 1 xinfo = { "train-loss": train_info["loss"], "train-accuracy": train_info["accuracy"], "train-per-time": train_info["all_time"] / total, "train-all-time": train_info["all_time"], } # collect the evaluation information if dataset == "cifar10-valid": valid_info = archresult.get_metrics( dataset, "x-valid", iepoch=iepoch, is_random=is_random ) try: test_info = archresult.get_metrics( dataset, "ori-test", iepoch=iepoch, is_random=is_random ) except Exception as unused_e: # pylint: disable=broad-except test_info = None valtest_info = None xinfo[ "comment" ] = "In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.".format( hp ) else: if dataset == "cifar10": xinfo[ "comment" ] = "In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. 
The per-time and total-time indicate the per epoch and total time costs, respectively.".format( hp ) try: # collect results on the proposed test set if dataset == "cifar10": test_info = archresult.get_metrics( dataset, "ori-test", iepoch=iepoch, is_random=is_random ) else: test_info = archresult.get_metrics( dataset, "x-test", iepoch=iepoch, is_random=is_random ) except Exception as unused_e: # pylint: disable=broad-except test_info = None try: # collect results on the proposed validation set valid_info = archresult.get_metrics( dataset, "x-valid", iepoch=iepoch, is_random=is_random ) except Exception as unused_e: # pylint: disable=broad-except valid_info = None try: if dataset != "cifar10": valtest_info = archresult.get_metrics( dataset, "ori-test", iepoch=iepoch, is_random=is_random ) else: valtest_info = None except Exception as unused_e: # pylint: disable=broad-except valtest_info = None if valid_info is not None: xinfo["valid-loss"] = valid_info["loss"] xinfo["valid-accuracy"] = valid_info["accuracy"] xinfo["valid-per-time"] = valid_info["all_time"] / total xinfo["valid-all-time"] = valid_info["all_time"] if test_info is not None: xinfo["test-loss"] = test_info["loss"] xinfo["test-accuracy"] = test_info["accuracy"] xinfo["test-per-time"] = test_info["all_time"] / total xinfo["test-all-time"] = test_info["all_time"] if valtest_info is not None: xinfo["valtest-loss"] = valtest_info["loss"] xinfo["valtest-accuracy"] = valtest_info["accuracy"] xinfo["valtest-per-time"] = valtest_info["all_time"] / total xinfo["valtest-all-time"] = valtest_info["all_time"] return xinfo def show(self, index: int = -1) -> None: """Print the information of a specific (or all) architecture(s).""" self._show(index, print_information) ================================================ FILE: nats_bench/api_topology.py ================================================ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 # 
############################################################################## # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # ############################################################################## # The history of benchmark files are as follows, # # where the format is (the name is NATS-tss-[version]-[md5].pickle.pbz2) # # [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 # ############################################################################## # pylint: disable=line-too-long """The API for topology search space in NATS-Bench.""" import collections import copy import os import random from typing import Any, Dict, List, Optional, Text, Union from nats_bench.api_utils import ArchResults from nats_bench.api_utils import NASBenchMetaAPI from nats_bench.api_utils import get_torch_home from nats_bench.api_utils import nats_is_dir from nats_bench.api_utils import nats_is_file from nats_bench.api_utils import PICKLE_EXT from nats_bench.api_utils import pickle_load from nats_bench.api_utils import time_string from nats_bench.genotype_utils import topology_str2structure ALL_BASE_NAMES = ["NATS-tss-v1_0-3ffb9"] def print_information(information, extra_info=None, show=False): """print out the information of a given ArchResults.""" dataset_names = information.get_dataset_names() strings = [ information.arch_str, "datasets : {:}, extra-info : {:}".format(dataset_names, extra_info), ] def metric2str(loss, acc): return "loss = {:.3f} & top1 = {:.2f}%".format(loss, acc) for dataset in dataset_names: metric = information.get_compute_costs(dataset) flop, param, latency = metric["flops"], metric["params"], metric["latency"] str1 = "{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.".format( dataset, flop, param, "{:.2f}".format(latency * 1000) if latency is not None and latency > 0 else None, ) train_info = information.get_metrics(dataset, "train") if dataset == "cifar10-valid": valid_info = information.get_metrics(dataset, "x-valid") 
str2 = "{:14s} train : [{:}], valid : [{:}]".format( dataset, metric2str(train_info["loss"], train_info["accuracy"]), metric2str(valid_info["loss"], valid_info["accuracy"]), ) elif dataset == "cifar10": test__info = information.get_metrics(dataset, "ori-test") str2 = "{:14s} train : [{:}], test : [{:}]".format( dataset, metric2str(train_info["loss"], train_info["accuracy"]), metric2str(test__info["loss"], test__info["accuracy"]), ) else: valid_info = information.get_metrics(dataset, "x-valid") test__info = information.get_metrics(dataset, "x-test") str2 = "{:14s} train : [{:}], valid : [{:}], test : [{:}]".format( dataset, metric2str(train_info["loss"], train_info["accuracy"]), metric2str(valid_info["loss"], valid_info["accuracy"]), metric2str(test__info["loss"], test__info["accuracy"]), ) strings += [str1, str2] if show: print("\n".join(strings)) return strings class NATStopology(NASBenchMetaAPI): """This is the class for the API of topology search space in NATS-Bench.""" def __init__( self, file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None, fast_mode: bool = False, verbose: bool = True, ): """The initialization function that takes the dataset file path (or a dict loaded from that path) as input.""" self._all_base_names = ALL_BASE_NAMES self.filename = None self._search_space_name = "topology" self._fast_mode = fast_mode self._archive_dir = None self._full_train_epochs = 200 self.reset_time() if file_path_or_dict is None: if self._fast_mode: self._archive_dir = os.path.join( get_torch_home(), "{:}-simple".format(ALL_BASE_NAMES[-1]) ) else: file_path_or_dict = os.path.join( get_torch_home(), "{:}.{:}".format(ALL_BASE_NAMES[-1], PICKLE_EXT) ) print( "{:} Try to use the default NATS-Bench (topology) path from " "fast_mode={:} and path={:}.".format( time_string(), self._fast_mode, file_path_or_dict ) ) if isinstance(file_path_or_dict, str): file_path_or_dict = str(file_path_or_dict) if verbose: print( "{:} Try to create the NATS-Bench (topology) api " 
"from {:} with fast_mode={:}".format( time_string(), file_path_or_dict, fast_mode ) ) if not nats_is_file(file_path_or_dict) and not nats_is_dir( file_path_or_dict ): raise ValueError( "{:} is neither a file or a dir.".format(file_path_or_dict) ) self.filename = os.path.basename(file_path_or_dict) if fast_mode: if nats_is_file(file_path_or_dict): raise ValueError( "fast_mode={:} must feed the path for directory " ": {:}".format(fast_mode, file_path_or_dict) ) else: self._archive_dir = file_path_or_dict else: if nats_is_dir(file_path_or_dict): raise ValueError( "fast_mode={:} must feed the path for file " ": {:}".format(fast_mode, file_path_or_dict) ) else: file_path_or_dict = pickle_load(file_path_or_dict) elif isinstance(file_path_or_dict, dict): file_path_or_dict = copy.deepcopy(file_path_or_dict) self.verbose = verbose if isinstance(file_path_or_dict, dict): keys = ("meta_archs", "arch2infos", "evaluated_indexes") for key in keys: if key not in file_path_or_dict: raise ValueError("Can not find key[{:}] in the dict".format(key)) self.meta_archs = copy.deepcopy(file_path_or_dict["meta_archs"]) # NOTE(xuanyidong): This is a dict mapping each architecture to a dict, # where the key is #epochs and the value is ArchResults self.arch2infos_dict = collections.OrderedDict() self._avaliable_hps = set() for xkey in sorted(list(file_path_or_dict["arch2infos"].keys())): all_infos = file_path_or_dict["arch2infos"][xkey] hp2archres = collections.OrderedDict() for hp_key, results in all_infos.items(): hp2archres[hp_key] = ArchResults.create_from_state_dict(results) self._avaliable_hps.add( hp_key ) # save the avaliable hyper-parameter self.arch2infos_dict[xkey] = hp2archres self.evaluated_indexes = set(file_path_or_dict["evaluated_indexes"]) elif self.archive_dir is not None: benchmark_meta = pickle_load( "{:}/meta.{:}".format(self.archive_dir, PICKLE_EXT) ) self.meta_archs = copy.deepcopy(benchmark_meta["meta_archs"]) self.arch2infos_dict = collections.OrderedDict() 
self._avaliable_hps = set() self.evaluated_indexes = set() else: raise ValueError( "file_path_or_dict [{:}] must be a dict or archive_dir " "must be set".format(type(file_path_or_dict)) ) self.archstr2index = {} for idx, arch in enumerate(self.meta_archs): if arch in self.archstr2index: raise ValueError( "This [{:}]-th arch {:} already in the " "dict ({:}).".format(idx, arch, self.archstr2index[arch]) ) self.archstr2index[arch] = idx if self.verbose: print( "{:} Create NATS-Bench (topology) done with {:}/{:} architectures " "avaliable.".format( time_string(), len(self.evaluated_indexes), len(self.meta_archs) ) ) @property def is_size(self): return False @property def is_topology(self): return True @property def full_epochs_in_paper(self): return 200 def get_unique_str(self, arch): """Return a unique string shared by all isomorphic architectures. Args: arch: it can be an architecture index or an architecture string. Returns: the unique string. """ index = self.query_index_by_arch( arch ) # arch may also be an architecture string or an arch object arch_str = self.meta_archs[index] structure = topology_str2structure(arch_str) return structure.to_unique_str(consider_zero=True) def query_info_str_by_arch(self, arch, hp: Text = "12"): """Query the information of a specific architecture. Args: arch: it can be an architecture index or an architecture string. hp: the hyper-parameter indicator, which can be 12 or 200. The two configurations differ in the number of training epochs.
Returns: ArchResults instance """ if self.verbose: print( "{:} Call query_info_str_by_arch with arch={:} " "and hp={:}".format(time_string(), arch, hp) ) return self._query_info_str_by_arch(arch, hp, print_information) def get_more_info( self, index, dataset, iepoch=None, hp: Text = "12", is_random: bool = True ): """Return the metric for the `index`-th architecture.""" if self.verbose: print( "{:} Call the get_more_info function with index={:}, dataset={:}, " "iepoch={:}, hp={:}, and is_random={:}.".format( time_string(), index, dataset, iepoch, hp, is_random ) ) index = self.query_index_by_arch( index ) # To avoid the input is a string or an instance of a arch object self._prepare_info(index) if index not in self.arch2infos_dict: raise ValueError("Did not find {:} from arch2infos_dict.".format(index)) archresult = self.arch2infos_dict[index][str(hp)] # if randomly select one trial, select the seed at first if isinstance(is_random, bool) and is_random: seeds = archresult.get_dataset_seeds(dataset) is_random = random.choice(seeds) # collect the training information train_info = archresult.get_metrics( dataset, "train", iepoch=iepoch, is_random=is_random ) total = train_info["iepoch"] + 1 xinfo = { "train-loss": train_info["loss"], "train-accuracy": train_info["accuracy"], "train-per-time": train_info["all_time"] / total if train_info["all_time"] is not None else None, "train-all-time": train_info["all_time"], } # collect the evaluation information if dataset == "cifar10-valid": valid_info = archresult.get_metrics( dataset, "x-valid", iepoch=iepoch, is_random=is_random ) try: test_info = archresult.get_metrics( dataset, "ori-test", iepoch=iepoch, is_random=is_random ) except Exception as unused_e: # pylint: disable=broad-except test_info = None valtest_info = None xinfo[ "comment" ] = "In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10.
The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.".format( hp ) else: if dataset == "cifar10": xinfo[ "comment" ] = "In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.".format( hp ) try: # collect results on the proposed test set if dataset == "cifar10": test_info = archresult.get_metrics( dataset, "ori-test", iepoch=iepoch, is_random=is_random ) else: test_info = archresult.get_metrics( dataset, "x-test", iepoch=iepoch, is_random=is_random ) except Exception as unused_e: # pylint: disable=broad-except test_info = None try: # collect results on the proposed validation set valid_info = archresult.get_metrics( dataset, "x-valid", iepoch=iepoch, is_random=is_random ) except Exception as unused_e: # pylint: disable=broad-except valid_info = None try: if dataset != "cifar10": valtest_info = archresult.get_metrics( dataset, "ori-test", iepoch=iepoch, is_random=is_random ) else: valtest_info = None except Exception as unused_e: # pylint: disable=broad-except valtest_info = None if valid_info is not None: xinfo["valid-loss"] = valid_info["loss"] xinfo["valid-accuracy"] = valid_info["accuracy"] xinfo["valid-per-time"] = ( valid_info["all_time"] / total if valid_info["all_time"] is not None else None ) xinfo["valid-all-time"] = valid_info["all_time"] if test_info is not None: xinfo["test-loss"] = test_info["loss"] xinfo["test-accuracy"] = test_info["accuracy"] xinfo["test-per-time"] = ( test_info["all_time"] / total if test_info["all_time"] is not None else None ) xinfo["test-all-time"] = test_info["all_time"] if valtest_info is not None: xinfo["valtest-loss"] = 
valtest_info["loss"] xinfo["valtest-accuracy"] = valtest_info["accuracy"] xinfo["valtest-per-time"] = ( valtest_info["all_time"] / total if valtest_info["all_time"] is not None else None ) xinfo["valtest-all-time"] = valtest_info["all_time"] return xinfo def show(self, index: int = -1) -> None: """Print the information of a specific (or all) architecture(s).""" self._show(index, print_information) @staticmethod def str2lists(arch_str: Text) -> List[Any]: """Parse the string-based architecture encoding into a list of nodes. Args: arch_str: a string indicating the architecture topology, such as |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2| Returns: a list of tuples, one per node, where each tuple contains the (op, input_node_index) pairs of that node's input edges. [USAGE] It is the same as the `str2structure` func in AutoDL-Projects: `github.com/D-X-Y/AutoDL-Projects/lib/models/cell_searchs/genotypes.py` ``` arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) print('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list for i, node in enumerate(arch): print('the {:}-th node sums {:} input edges (op, input_node_index): {:}'.format(i+1, len(node), node)) ``` """ node_strs = arch_str.split("+") genotypes = [] for unused_i, node_str in enumerate(node_strs): inputs = list( filter(lambda x: x != "", node_str.split("|")) ) # pylint: disable=g-explicit-bool-comparison for xinput in inputs: assert len(xinput.split("~")) == 2, "invalid input length : {:}".format( xinput ) inputs = (xi.split("~") for xi in inputs) input_infos = tuple((op, int(idx)) for (op, idx) in inputs) genotypes.append(input_infos) return genotypes @staticmethod def str2matrix( arch_str: Text, search_space: List[Text] = ( "none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3", ), ): """Convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
Args: arch_str: a string indicating the architecture topology, such as |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2| search_space: a list of operation strings; the default list is the topology search space of NATS-Bench and should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_operations.py#L24 Returns: the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology [USAGE] matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) This is a 4-by-4 matrix representing a cell with 4 nodes (only the lower-left triangle is used).
[ [0, 0, 0, 0], # the first row represents the input (0-th) node
  [2, 0, 0, 0], # the second row represents the 1-st node, computed as 2-th-op( 0-th-node )
  [0, 0, 0, 0], # the third row represents the 2-nd node, computed as 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
  [0, 0, 1, 0] ] # the fourth row represents the 3-rd node, computed as 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
In the topology search space of NATS-Bench, 0-th-op is 'none', 1-th-op is 'skip_connect', 2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'. [NOTE] If a node has two input edges from the same node, this function does not work: one edge overwrites the other.
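Example (a hedged, numpy-free sketch; str2adjacency is a hypothetical helper, not part of this API): it builds the same lower-triangular encoding as a list of lists, but raises ValueError when a node has two input edges from the same node instead of letting one edge overwrite the other as described in [NOTE].

```python
OPS = ('none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3')


def str2adjacency(arch_str, search_space=OPS):
    # Each '+'-separated group describes one non-input node; node 0 is the cell input.
    node_strs = arch_str.split('+')
    num_nodes = len(node_strs) + 1
    matrix = [[0] * num_nodes for _ in range(num_nodes)]
    for i, node_str in enumerate(node_strs):
        seen_inputs = set()
        for edge in filter(None, node_str.split('|')):  # drop empty pieces
            op, idx = edge.split('~')
            if op not in search_space:
                raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
            node_idx = int(idx)
            if node_idx in seen_inputs:
                raise ValueError('duplicate edge into node {:} from node {:}'.format(i + 1, node_idx))
            seen_inputs.add(node_idx)
            # Row i + 1 is the (i + 1)-th node; column node_idx is its input node.
            matrix[i + 1][node_idx] = search_space.index(op)
    return matrix


if __name__ == '__main__':
    arch = '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|'
    for row in str2adjacency(arch):
        print(row)
```

For the [USAGE] string above, the rows are [0, 0, 0, 0], [2, 0, 0, 0], [0, 0, 0, 0], and [0, 0, 1, 0].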
""" import numpy as np node_strs = arch_str.split("+") num_nodes = len(node_strs) + 1 matrix = np.zeros((num_nodes, num_nodes)) for i, node_str in enumerate(node_strs): inputs = list( filter(lambda x: x != "", node_str.split("|")) ) # pylint: disable=g-explicit-bool-comparison for xinput in inputs: assert len(xinput.split("~")) == 2, "invalid input length : {:}".format( xinput ) for xi in inputs: op, idx = xi.split("~") if op not in search_space: raise ValueError( "this op ({:}) is not in {:}".format(op, search_space) ) op_idx, node_idx = search_space.index(op), int(idx) matrix[i + 1, node_idx] = op_idx return matrix ================================================ FILE: nats_bench/api_utils.py ================================================ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 # ############################################################################## # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # ############################################################################## """In this file, we define NASBenchMetaAPI, ArchResults, and ResultsCount. NASBenchMetaAPI is the abstract class for benchmark APIs. We also define the class ArchResults, which contains all information of a single architecture trained by one kind of hyper-parameters on three datasets. We also define the class ResultsCount, which contains all information of a single trial for a single architecture. 
""" import abc import bz2 import collections import copy import os import pickle import random import time from typing import Any, Dict, Optional, Text, Union import warnings _FILE_SYSTEM = "default" PICKLE_EXT = "pickle.pbz2" def mean(xlist): return sum(xlist) / len(xlist) def time_string(): iso_time_format = "%Y-%m-%d %X" string = "[{:}]".format(time.strftime(iso_time_format, time.gmtime(time.time()))) return string def reset_file_system(lib: Text = "default"): global _FILE_SYSTEM _FILE_SYSTEM = lib def get_file_system(): return _FILE_SYSTEM def get_torch_home(): if "TORCH_HOME" in os.environ: return os.environ["TORCH_HOME"] elif "HOME" in os.environ: return os.path.join(os.environ["HOME"], ".torch") else: raise ValueError( "Did not find HOME in os.environ. " "Please at least setup the path of HOME or TORCH_HOME " "in the environment." ) def nats_is_dir(file_path): if _FILE_SYSTEM == "default": return os.path.isdir(file_path) elif _FILE_SYSTEM == "google": import tensorflow as tf # pylint: disable=g-import-not-at-top return tf.io.gfile.isdir(file_path) else: raise ValueError("Unknown file system lib: {:}".format(_FILE_SYSTEM)) def nats_is_file(file_path): if _FILE_SYSTEM == "default": return os.path.isfile(file_path) elif _FILE_SYSTEM == "google": import tensorflow as tf # pylint: disable=g-import-not-at-top return tf.io.gfile.exists(file_path) and not tf.io.gfile.isdir(file_path) else: raise ValueError("Unknown file system lib: {:}".format(_FILE_SYSTEM)) def pickle_save(obj, file_path, ext=".pbz2", protocol=4): """Use pickle to save data (obj) into file_path. Args: obj: The object to be saved into a path. file_path: The target saving path. ext: The extension of file name. protocol: The pickle protocol. According to this documentation (https://docs.python.org/3/library/pickle.html#data-stream-format), the protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. 
It is the default protocol starting with Python 3.8. """ # with open(file_path, 'wb') as cfile: if _FILE_SYSTEM == "default": with bz2.BZ2File(str(file_path) + ext, "wb") as cfile: pickle.dump( obj, cfile, protocol=protocol ) # pytype: disable=wrong-arg-types else: raise ValueError("Unknown file system lib: {:}".format(_FILE_SYSTEM)) def pickle_load(file_path, ext=".pbz2"): """Use pickle to load a bz2-compressed file on different file systems.""" # return pickle.load(open(file_path, "rb")) if nats_is_file(str(file_path)): xfile_path = str(file_path) else: xfile_path = str(file_path) + ext if _FILE_SYSTEM == "default": with bz2.BZ2File(xfile_path, "rb") as cfile: return pickle.load(cfile) # pytype: disable=wrong-arg-types elif _FILE_SYSTEM == "google": import tensorflow as tf # pylint: disable=g-import-not-at-top file_content = tf.io.gfile.GFile(xfile_path, mode="rb").read() byte_content = bz2.decompress(file_content) return pickle.loads(byte_content) else: raise ValueError("Unknown file system lib: {:}".format(_FILE_SYSTEM)) def remap_dataset_set_names(dataset, metric_on_set, verbose=False): """Re-map (dataset, metric_on_set) to the internal keys.""" if verbose: print( "Call internal function remap_dataset_set_names with dataset={:} " "and metric_on_set={:}".format(dataset, metric_on_set) ) if dataset == "cifar10" and metric_on_set == "valid": dataset, metric_on_set = "cifar10-valid", "x-valid" elif dataset == "cifar10" and metric_on_set == "test": dataset, metric_on_set = "cifar10", "ori-test" elif dataset == "cifar10" and metric_on_set == "train": dataset, metric_on_set = "cifar10", "train" elif ( dataset == "cifar100" or dataset == "ImageNet16-120" ) and metric_on_set == "valid": metric_on_set = "x-valid" elif ( dataset == "cifar100" or dataset == "ImageNet16-120" ) and metric_on_set == "test": metric_on_set = "x-test" if verbose: print( " return dataset={:} and metric_on_set={:}".format(dataset, metric_on_set) ) return dataset, metric_on_set class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""The abstract class for NATS Bench API.""" @abc.abstractmethod def __init__( self, file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None, fast_mode: bool = False, verbose: bool = True, ): """The initialization function that takes the dataset file path (or a dict loaded from that path) as input.""" # NOTE(xuanyidong): the following attributes must be initialized in subclasses self.meta_archs = None self.verbose = None self.evaluated_indexes = None self.arch2infos_dict = None self.filename = None self._fast_mode = None self._archive_dir = None self._avaliable_hps = None self.archstr2index = None def __getitem__(self, index: int): return copy.deepcopy(self.meta_archs[index]) def arch(self, index: int): """Return the topology structure of the `index`-th architecture.""" if self.verbose: print("Call the arch function with index={:}".format(index)) if index < 0 or index >= len(self.meta_archs): raise ValueError( "invalid index : {:} vs. {:}.".format(index, len(self.meta_archs)) ) return copy.deepcopy(self.meta_archs[index]) def __len__(self): return len(self.meta_archs) def __repr__(self): return ( "{name}({num}/{total} architectures, fast_mode={fast_mode}, " "file={filename})".format( name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), fast_mode=self.fast_mode, filename=self.filename, ) ) @property def avaliable_hps(self): return list(copy.deepcopy(self._avaliable_hps)) @property def used_time(self): return self._used_time @property def search_space_name(self): return self._search_space_name @property def fast_mode(self): return self._fast_mode @property def archive_dir(self): return self._archive_dir @property def full_train_epochs(self): return self._full_train_epochs def reset_archive_dir(self, archive_dir): self._archive_dir = archive_dir def reset_fast_mode(self, fast_mode): self._fast_mode = fast_mode def reset_time(self): self._used_time = 0 @abc.abstractmethod def get_more_info( self, index, dataset, iepoch=None,
hp: Text = "12", is_random: bool = True ): """Return the metric for the `index`-th architecture.""" def simulate_train_eval( self, arch, dataset, iepoch=None, hp="12", account_time=True ): """This function is used to simulate training and evaluating an arch.""" index = self.query_index_by_arch(arch) all_names = ("cifar10", "cifar100", "ImageNet16-120") if dataset not in all_names: raise ValueError( "Invalid dataset name : {:} vs {:}".format(dataset, all_names) ) if dataset == "cifar10": info = self.get_more_info( index, "cifar10-valid", iepoch=iepoch, hp=hp, is_random=True ) else: info = self.get_more_info( index, dataset, iepoch=iepoch, hp=hp, is_random=True ) if "valid-accuracy" in info: valid_acc, time_cost = ( info["valid-accuracy"], info["train-all-time"] + info["valid-per-time"], ) else: valid_acc = info["valtest-accuracy"] temp_info = self.get_more_info( index, dataset, iepoch=None, hp=hp, is_random=True ) time_cost = info["train-all-time"] + temp_info["valid-per-time"] latency = self.get_latency(index, dataset, hp=hp) if account_time: self._used_time += time_cost return valid_acc, latency, time_cost, self._used_time def random(self): """Return a random index of all architectures.""" return random.randint(0, len(self.meta_archs) - 1) def reload(self, archive_root: Text = None, index: int = None): """Overwrite all information of the 'index'-th architecture in search space. Args: archive_root: If archive_root is None, it will try to load from the default path os.environ['TORCH_HOME'] / 'BASE_NAME'-full. index: If index is None, overwrite all ckps. 
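For each index, reload first looks for a zero-padded file such as 000011.pickle.pbz2 under archive_root and falls back to the un-padded name (11.pickle.pbz2). Each file is a bz2-compressed pickle of a dict keyed by hp. The following is a minimal local-filesystem sketch of that lookup and load. `load_arch_file` is a hypothetical helper: it ignores the 'google' file system and raises FileNotFoundError and TypeError where the real code uses asserts:

```python
import bz2
import os
import pickle

PICKLE_EXT = 'pickle.pbz2'


def load_arch_file(archive_root, idx):
    # Prefer the zero-padded name, e.g. 000011.pickle.pbz2.
    path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
    if not os.path.isfile(path):
        # Fall back to the un-padded name, e.g. 11.pickle.pbz2.
        path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
    if not os.path.isfile(path):
        raise FileNotFoundError('invalid data path : {:}'.format(path))
    # Each file is a bz2-compressed pickle whose top-level keys are hp strings.
    with bz2.BZ2File(path, 'rb') as cfile:
        data = pickle.load(cfile)
    if not isinstance(data, dict):
        raise TypeError('invalid format of data in {:}'.format(path))
    return data
```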
""" if self.verbose: print( "{:} Call reload with archive_root={:} and index={:}".format( time_string(), archive_root, index ) ) if archive_root is None: archive_root = os.path.join( os.environ["TORCH_HOME"], "{:}-full".format(self._all_base_names[-1]) ) if not nats_is_dir(archive_root): warnings.warn( "The input archive_root is None and the default " "archive_root path ({:}) does not exist; falling back to " "self.archive_dir.".format(archive_root) ) archive_root = self.archive_dir if archive_root is None or not nats_is_dir(archive_root): raise ValueError("Invalid archive_root : {:}".format(archive_root)) if index is None: indexes = list(range(len(self))) else: indexes = [index] for idx in indexes: if not ( 0 <= idx < len(self.meta_archs) ): # pylint: disable=superfluous-parens raise ValueError("invalid index of {:}".format(idx)) xfile_path = os.path.join( archive_root, "{:06d}.{:}".format(idx, PICKLE_EXT) ) if not nats_is_file(xfile_path): xfile_path = os.path.join( archive_root, "{:d}.{:}".format(idx, PICKLE_EXT) ) assert nats_is_file(xfile_path), "invalid data path : {:}".format( xfile_path ) xdata = pickle_load(xfile_path) assert isinstance(xdata, dict), "invalid format of data in {:}".format( xfile_path ) self.evaluated_indexes.add(idx) hp2archres = collections.OrderedDict() for hp_key, results in xdata.items(): hp2archres[hp_key] = ArchResults.create_from_state_dict(results) self._avaliable_hps.add(hp_key) self.arch2infos_dict[idx] = hp2archres def query_index_by_arch(self, arch): """Query the index of an architecture in the search space. Args: arch: For the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'; # pylint: disable=line-too-long or an instance that has a 'tostr' method that generates the architecture string; or an architecture index, in which case only its range is checked.
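The three accepted input forms can be summarized by a stand-alone sketch. `lookup_index` is a hypothetical helper: archstr2index is a plain dict from architecture string to index, and num_archs plays the role of len(self):

```python
def lookup_index(arch, archstr2index, num_archs):
    if isinstance(arch, int):
        # Integer inputs are only range-checked.
        if not 0 <= arch < num_archs:
            raise ValueError('invalid index {:}'.format(arch))
        return arch
    if isinstance(arch, str):
        key = arch
    elif hasattr(arch, 'tostr'):
        # Genotype-like objects are converted to their string form first.
        key = arch.tostr()
    else:
        return -1
    # Unknown architectures map to -1 instead of raising.
    return archstr2index.get(key, -1)
```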
This function returns the index. A return value of -1 means that this architecture is not in the search space; otherwise, the index is an integer in [0, the-number-of-candidates-in-the-search-space). Raises: ValueError: If `arch` is an integer index that is out of range. Returns: The index of the architecture in this benchmark, or -1 if it is not found. """ if self.verbose: print( "{:} Call query_index_by_arch with arch={:}".format(time_string(), arch) ) if isinstance(arch, int): if 0 <= arch < len(self): return arch else: raise ValueError( "Invalid architecture index {:} vs [{:}, {:}).".format( arch, 0, len(self) ) ) elif isinstance(arch, str): if arch in self.archstr2index: arch_index = self.archstr2index[arch] else: arch_index = -1 elif hasattr(arch, "tostr"): if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[arch.tostr()] else: arch_index = -1 else: arch_index = -1 return arch_index def query_by_arch(self, arch, hp): """Keep compatibility with the old NAS-Bench-201 API.""" return self.query_info_str_by_arch(arch, hp) def _prepare_info(self, index): """Load the data of the `index`-th architecture from disk when using the fast mode.""" if index not in self.arch2infos_dict: if self.fast_mode and self.archive_dir is not None: self.reload(self.archive_dir, index) elif not self.fast_mode: if self.verbose: print( "{:} Call _prepare_info with index={:} skip because it is not " "the fast mode.".format(time_string(), index) ) else: raise ValueError( "Invalid status: fast_mode={:} and " "archive_dir={:}".format(self.fast_mode, self.archive_dir) ) else: if index not in self.evaluated_indexes: raise ValueError( "The index of {:} is not in self.evaluated_indexes, " "there must be something wrong.".format(index) ) if self.verbose: print( "{:} Call _prepare_info with index={:} skip because it is in " "arch2infos_dict".format(time_string(), index) ) def clear_params(self, index: int, hp: Optional[Text] = None): """Remove the architecture's weights to save memory.
Args: index: the index of the target architecture. hp: a flag to control how to clear the parameters. -- None: clear the weights under all hyper-parameter settings of this architecture (e.g., '01', '12', '90'; each indicates the number of training epochs). -- '01', '12', or '90': clear only the weights in arch2infos_dict[index][hp]. """ if self.verbose: print( "{:} Call clear_params with index={:} and hp={:}".format( time_string(), index, hp ) ) if index not in self.arch2infos_dict: warnings.warn( "The {:}-th architecture is not in the benchmark data yet, " "no need to clear params.".format(index) ) elif hp is None: for key, result in self.arch2infos_dict[index].items(): result.clear_params() else: if str(hp) not in self.arch2infos_dict[index]: raise ValueError( "The {:}-th architecture only has hyper-parameters " "of {:} instead of {:}.".format( index, list(self.arch2infos_dict[index].keys()), hp ) ) self.arch2infos_dict[index][str(hp)].clear_params() @abc.abstractmethod def query_info_str_by_arch(self, arch, hp: Text = "12"): """Query the information of a specific architecture as a string.""" def _query_info_str_by_arch(self, arch, hp: Text = "12", print_information=None): """Internal function to query the information of `arch` when using `hp`.""" arch_index = self.query_index_by_arch(arch) self._prepare_info(arch_index) if arch_index in self.arch2infos_dict: if hp not in self.arch2infos_dict[arch_index]: raise ValueError( "The {:}-th architecture only has hyper-parameters of " "{:} instead of {:}.".format( arch_index, list(self.arch2infos_dict[arch_index].keys()), hp ) ) info = self.arch2infos_dict[arch_index][hp] strings = print_information(info, "arch-index={:}".format(arch_index)) return "\n".join(strings) else: warnings.warn( "Found arch-index {:}, but this arch has not been " "evaluated.".format(arch_index) ) return None def query_meta_info_by_index(self, arch_index, hp: Text = "12"): """Return ArchResults for the 'arch_index'-th architecture.""" if self.verbose: print( "Call 
query_meta_info_by_index
with arch_index={:}, hp={:}".format( arch_index, hp ) ) self._prepare_info(arch_index) if arch_index in self.arch2infos_dict: if str(hp) not in self.arch2infos_dict[arch_index]: raise ValueError( "The {:}-th architecture only has hyper-parameters of " "{:} instead of {:}.".format( arch_index, list(self.arch2infos_dict[arch_index].keys()), hp ) ) info = self.arch2infos_dict[arch_index][str(hp)] else: raise ValueError( "arch_index [{:}] is not in arch2infos".format(arch_index) ) return copy.deepcopy(info) def query_by_index( self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = "12" ): """Query the information of an architecture trained for 01/12/90/200 epochs. Args: arch_index: The architecture index in this benchmark. dataname: If dataname is None, return the ArchResults; otherwise, return a dict with all trials on that dataset (the key is the seed). Options are 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'. -- cifar10-valid : train the model on the CIFAR-10 training set. -- cifar10 : train the model on the CIFAR-10 training + validation set. -- cifar100 : train the model on the CIFAR-100 training set. -- ImageNet16-120 : train the model on the ImageNet16-120 training set. hp: The hyper-parameters. If hp=01, we train the model for 01 epochs. If hp=12, we train the model for 12 epochs. If hp=90, we train the model for 90 epochs. If hp=200, we train the model for 200 epochs. See github.com/D-X-Y/AutoDL-Projects/configs/nas-benchmark/hyper-opts for more details. Raises: ValueError: If `hp` or `dataname` is not available for this architecture. Returns: An instance of ArchResults if dataname is None; otherwise, a dict mapping each seed to its ResultsCount. """ if self.verbose: print( "{:} Call query_by_index with arch_index={:}, dataname={:}, " "hp={:}".format(time_string(), arch_index, dataname, hp) ) info = self.query_meta_info_by_index(arch_index, str(hp)) if dataname is None: return info else: if dataname not in info.get_dataset_names(): raise ValueError( "invalid dataset-name : {:} vs.
{:}".format( dataname, info.get_dataset_names() ) ) return info.query(dataname) def find_best( self, dataset, metric_on_set, flop_max=None, param_max=None, hp: Text = "12", enforce_all: Optional[bool] = None, ): """Find the architecture with the highest accuracy based on some constraints.""" # Please see how to set the `dataset` and `metric_on_set` (setname) at here: # https://github.com/D-X-Y/NATS-Bench/blob/main/nats_bench/api_utils.py#L702 if self.verbose: print( "{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} " "| with #FLOPs < {:} and #Params < {:}".format( time_string(), dataset, metric_on_set, hp, flop_max, param_max ) ) dataset, metric_on_set = remap_dataset_set_names( dataset, metric_on_set, self.verbose ) best_index, highest_accuracy = -1, None if enforce_all is None: enforce_all = True if self.fast_mode else False if enforce_all: # We set this arg `enforce_all` because in the fast mode, evaluated_indexes will be empty # `evaluated_indexes` is dynamically inserted with architectures along with the query assert ( self.fast_mode ), "enforce_all can only be set when fast_mode=True; if you are using non-fast-mode, please set it as False" evaluated_indexes = list(range(len(self))) else: evaluated_indexes = sorted(list(self.evaluated_indexes)) for arch_index in evaluated_indexes: self._prepare_info(arch_index) arch_info = self.arch2infos_dict[arch_index][hp] info = arch_info.get_compute_costs(dataset) # the information of costs flop, param, latency = info["flops"], info["params"], info["latency"] if flop_max is not None and flop > flop_max: continue if param_max is not None and param > param_max: continue xinfo = arch_info.get_metrics( dataset, metric_on_set ) # the information of loss and accuracy loss, accuracy = xinfo["loss"], xinfo["accuracy"] if best_index == -1: best_index, highest_accuracy = arch_index, accuracy elif highest_accuracy < accuracy: best_index, highest_accuracy = arch_index, accuracy del latency, loss if self.verbose: if not 
evaluated_indexes: print( "evaluated_indexes is empty; please fill it before calling find_best." ) else: print( " the best architecture : [{:}] {:} with accuracy={:.3f}%".format( best_index, self.arch(best_index), highest_accuracy ) ) return best_index, highest_accuracy def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = "12"): """Obtain the trained weights of the `index`-th arch on `dataset`. Args: index: The architecture index. dataset: The training dataset name. seed: -- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights. -- An integer : return the weights of the specific trial whose seed is this integer. hp: -- 01 : train the model for 01 epochs -- 12 : train the model for 12 epochs -- 90 : train the model for 90 epochs -- 200 : train the model for 200 epochs Returns: PyTorch weights (a dict keyed by seed when seed is None). """ if self.verbose: print( "{:} Call the get_net_param function with index={:}, dataset={:}, " "seed={:}, hp={:}".format(time_string(), index, dataset, seed, hp) ) info = self.query_meta_info_by_index(index, hp) return info.get_net_param(dataset, seed) def get_net_config(self, index: int, dataset: Text): """Obtain the configuration for the `index`-th architecture on `dataset`. Args: index: The architecture index. dataset: 4 possible options as follows, -- cifar10-valid : train the model on the CIFAR-10 training set. -- cifar10 : train the model on the CIFAR-10 training + validation set. -- cifar100 : train the model on the CIFAR-100 training set. -- ImageNet16-120 : train the model on the ImageNet16-120 training set. Returns: A dict.
Note: an example of using this function: config = api.get_net_config(128, 'cifar10') """ if self.verbose: print( "{:} Call the get_net_config function with index={:}, " "dataset={:}.".format(time_string(), index, dataset) ) self._prepare_info(index) if index in self.arch2infos_dict: info = self.arch2infos_dict[index] else: raise ValueError( "The arch_index={:} is not in arch2infos_dict.".format(index) ) info = next(iter(info.values())) results = info.query(dataset, None) results = next(iter(results.values())) return results.get_config(None) def get_cost_info( self, index: int, dataset: Text, hp: Text = "12" ) -> Dict[Text, float]: """Obtain the cost metrics of the `index`-th architecture on a dataset.""" if self.verbose: print( "{:} Call the get_cost_info function with index={:}, " "dataset={:}, and hp={:}.".format(time_string(), index, dataset, hp) ) self._prepare_info(index) info = self.query_meta_info_by_index(index, hp) return info.get_compute_costs(dataset) def get_latency(self, index: int, dataset: Text, hp: Text = "12") -> float: """Obtain the latency of the network. Note: by default it returns the latency with a batch size of 256. Args: index: the index of the target architecture. dataset: the dataset name (cifar10-valid, cifar10, cifar100, or ImageNet16-120). hp: the hyper-parameter indicator. Returns: a float value in seconds. """ if self.verbose: print( "{:} Call the get_latency function with index={:}, " "dataset={:}, and hp={:}.".format(time_string(), index, dataset, hp) ) cost_dict = self.get_cost_info(index, dataset, hp) return cost_dict["latency"] @abc.abstractmethod def show(self, index=-1): """Print the information of a specific (or all) architecture(s).""" def _show(self, index=-1, print_information=None) -> None: """Print the information of a specific (or all) architecture(s). Args: index: If index < 0, it will loop over all evaluated architectures and print their information one by one.
Else: it will print the information of the 'index'-th architecture. print_information: A function that formats an ArchResults instance into a list of strings. Returns: None """ if index < 0: # show all architectures print(self) evaluated_indexes = sorted(list(self.evaluated_indexes)) for i, idx in enumerate(evaluated_indexes): print( "\n" + "-" * 10 + " The ({:5d}/{:5d}) {:06d}-th " "architecture! ".format(i, len(evaluated_indexes), idx) + "-" * 10 ) print("arch : {:}".format(self.meta_archs[idx])) for unused_key, result in self.arch2infos_dict[idx].items(): strings = print_information(result) print( ">" * 40 + " {:03d} epochs ".format(result.get_total_epoch()) + ">" * 40 ) print("\n".join(strings)) print("<" * 40 + "------------" + "<" * 40) else: if 0 <= index < len(self.meta_archs): if index not in self.evaluated_indexes: self._prepare_info(index) if index not in self.evaluated_indexes: print( "The {:}-th architecture has not been evaluated " "or has not been saved.".format(index) ) else: # arch_info = self.arch2infos_dict[index] for unused_key, result in self.arch2infos_dict[index].items(): strings = print_information(result) print( ">" * 40 + " {:03d} epochs ".format(result.get_total_epoch()) + ">" * 40 ) print("\n".join(strings)) print("<" * 40 + "------------" + "<" * 40) else: print( "This index ({:}) is out of range (0~{:}).".format( index, len(self.meta_archs) ) ) def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]: """Count the number of trials on `dataset`: each key is a number of trials, and its value is how many evaluated architectures have that many trials.""" if self.verbose: print( "Call the statistics function with dataset={:} and hp={:}.".format( dataset, hp ) ) valid_datasets = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"] if dataset not in valid_datasets: raise ValueError("{:} not in {:}".format(dataset, valid_datasets)) nums, hp = collections.defaultdict(lambda: 0), str(hp) # for index in range(len(self)): for index in 
self.evaluated_indexes: arch_info = self.arch2infos_dict[index][hp] dataset_seed = arch_info.dataset_seed if dataset not in
dataset_seed: nums[0] += 1 else: nums[len(dataset_seed[dataset])] += 1 return dict(nums) class ArchResults(object): """A class to maintain the results of an architecture under different settings.""" def __init__(self, arch_index, arch_str): self.arch_index = int(arch_index) self.arch_str = copy.deepcopy(arch_str) self.all_results = dict() self.dataset_seed = dict() self.clear_net_done = False def get_compute_costs(self, dataset): """Return the computation costs (FLOPs, params, latency, and times) averaged over all trials on the input dataset.""" x_seeds = self.dataset_seed[dataset] results = [self.all_results[(dataset, seed)] for seed in x_seeds] flops = [result.flop for result in results] params = [result.params for result in results] # NOTE(xuanyidong): # Due to a legacy issue, flops and params may be inconsistent across seeds. # As a quick fix, the values of seed 888 are used in that case. def fix_legacy_issue(raw_list): xlist = [f"{x:.5f}" for x in raw_list] if len(xlist) > 1 and len(set(xlist)) > 1: # inconsistent return [raw_list[x_seeds.index(888)]] else: return raw_list flops = fix_legacy_issue(flops) params = fix_legacy_issue(params) latencies = [result.get_latency() for result in results] latencies = [x for x in latencies if x > 0] mean_latency = mean(latencies) if len(latencies) else None time_infos = collections.defaultdict(list) for result in results: time_info = result.get_times() for key, value in time_info.items(): time_infos[key].append(value) info = { "flops": mean(flops), "params": mean(params), "latency": mean_latency, } for key, value in time_infos.items(): if len(value) and value[0] is not None: info[key] = mean(value) else: info[key] = None return info def get_metrics(self, dataset, setname, iepoch=None, is_random=False): """Obtain the loss, accuracy, etc. on a specific dataset. If not specified, each set refers to the split proposed in NAS-Bench-201. If a metric is None or raises an error, then it is not available.
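User-facing helpers such as find_best accept (dataset, 'train'/'valid'/'test') and translate it into the internal names listed below through remap_dataset_set_names. The following is a table-driven sketch of the same mapping. `remap` is a hypothetical name, and it omits the verbose logging:

```python
# (public dataset, public set) -> (internal dataset, internal setname)
_REMAP = {
    ('cifar10', 'valid'): ('cifar10-valid', 'x-valid'),
    ('cifar10', 'test'): ('cifar10', 'ori-test'),
    ('cifar10', 'train'): ('cifar10', 'train'),
    ('cifar100', 'valid'): ('cifar100', 'x-valid'),
    ('cifar100', 'test'): ('cifar100', 'x-test'),
    ('ImageNet16-120', 'valid'): ('ImageNet16-120', 'x-valid'),
    ('ImageNet16-120', 'test'): ('ImageNet16-120', 'x-test'),
}


def remap(dataset, metric_on_set):
    # Pairs not in the table are passed through unchanged, as in the original.
    return _REMAP.get((dataset, metric_on_set), (dataset, metric_on_set))
```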
======================================== Args: dataset: 4 possible options as follows -- cifar10-valid : train the model on the CIFAR-10 training set. -- cifar10 : train the model on the CIFAR-10 training + validation set. -- cifar100 : train the model on the CIFAR-100 training set. -- ImageNet16-120 : train the model on the ImageNet16-120 training set. setname: each dataset has different setnames -- When dataset = cifar10-valid, you can use 'train', 'x-valid', and 'ori-test' ------ 'train' : the metric on the training set. ------ 'x-valid' : the metric on the validation set. ------ 'ori-test' : the metric on the test set. -- When dataset = cifar10, you can use 'train' and 'ori-test'. ------ 'train' : the metric on the training + validation set. ------ 'ori-test' : the metric on the test set. -- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', and 'x-test' ------ 'train' : the metric on the training set. ------ 'x-valid' : the metric on the validation set. ------ 'x-test' : the metric on the test set. ------ 'ori-test' : the metric on the validation + test set. iepoch: None or an integer in [0, the-number-of-total-training-epochs). ------ None : return the metric after the last training epoch. ------ an integer i : return the metric after the i-th training epoch. is_random: ------ True : return the metric of a randomly selected trial. ------ False : return the averaged metric of all available trials. ------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random'). Returns: A dict of all metrics (e.g., loss and accuracy) for the input setting.
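The three is_random modes can be illustrated on a hypothetical {seed: accuracy} dict instead of ResultsCount objects. `aggregate` is an illustrative helper, not part of this module:

```python
import random


def aggregate(per_seed, is_random=False):
    seeds = sorted(per_seed)
    values = [per_seed[s] for s in seeds]
    # bool must be checked before int, because True/False are ints in Python.
    if isinstance(is_random, bool):
        if is_random:
            return random.choice(values)  # one randomly selected trial
        return sum(values) / len(values)  # mean over all trials
    if isinstance(is_random, int):
        if is_random not in per_seed:
            raise ValueError('unknown seed {:}'.format(is_random))
        return per_seed[is_random]  # the trial with exactly this seed
    raise ValueError('invalid is_random: {:}'.format(is_random))
```

The bool-before-int ordering matches the method body, which tests `isinstance(is_random, bool)` before `isinstance(is_random, int)`.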
""" x_seeds = self.dataset_seed[dataset] results = [self.all_results[(dataset, seed)] for seed in x_seeds] infos = collections.defaultdict(list) for result in results: if setname == "train": info = result.get_train(iepoch) else: info = result.get_eval(setname, iepoch) for key, value in info.items(): infos[key].append(value) return_info = dict() if isinstance(is_random, bool) and is_random: # randomly select one index = random.randint(0, len(results) - 1) for key, value in infos.items(): return_info[key] = value[index] elif isinstance(is_random, bool) and not is_random: # average for key, value in infos.items(): if len(value) and value[0] is not None: return_info[key] = mean(value) else: return_info[key] = None elif isinstance(is_random, int): # specify the seed if is_random not in x_seeds: raise ValueError( "cannot find random seed ({:}) in {:}".format(is_random, x_seeds) ) index = x_seeds.index(is_random) for key, value in infos.items(): return_info[key] = value[index] else: raise ValueError("invalid value for is_random: {:}".format(is_random)) return return_info # def show(self, is_print=False): # return print_information(self, None, is_print) def get_dataset_names(self): return list(self.dataset_seed.keys()) def get_dataset_seeds(self, dataset): return copy.deepcopy(self.dataset_seed[dataset]) def get_net_param(self, dataset: Text, seed: Union[None, int] = None): """Return the trained network's weights on the 'dataset'. Args: dataset: 'cifar10-valid', 'cifar10', 'cifar100', or 'ImageNet16-120'. seed: an integer indicating the seed value, or None to return the weights of all trials. Returns: The trained weights (parameters); a dict keyed by seed when seed is None.
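The following is a minimal sketch of the two return shapes. It uses a hypothetical all_results dict that maps (dataset, seed) directly to weights rather than to ResultsCount objects; `get_weights` is an illustrative name:

```python
def get_weights(all_results, dataset_seed, dataset, seed=None):
    if seed is None:
        # All trials on this dataset: {seed: weights}.
        return {s: all_results[(dataset, s)] for s in dataset_seed[dataset]}
    key = (dataset, seed)
    if key not in all_results:
        raise ValueError('key={:} not in {:}'.format(key, list(all_results)))
    # A single trial: just its weights.
    return all_results[key]
```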
""" if seed is None: x_seeds = self.dataset_seed[dataset] return { seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds } else: xkey = (dataset, seed) if xkey in self.all_results: return self.all_results[xkey].get_net_param() else: raise ValueError( "key={:} not in {:}".format(xkey, list(self.all_results.keys())) ) def reset_latency( self, dataset: Text, seed: Union[None, Text], latency: float ) -> None: """This function is used to reset the latency in all corresponding ResultsCount(s).""" if seed is None: for seed in self.dataset_seed[dataset]: self.all_results[(dataset, seed)].update_latency([latency]) else: self.all_results[(dataset, seed)].update_latency([latency]) def reset_pseudo_train_times( self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float ) -> None: """This function is used to reset the train-times in all corresponding ResultsCount(s).""" if seed is None: for seed in self.dataset_seed[dataset]: self.all_results[(dataset, seed)].reset_pseudo_train_times( estimated_per_epoch_time ) else: self.all_results[(dataset, seed)].reset_pseudo_train_times( estimated_per_epoch_time ) def reset_pseudo_eval_times( self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float, ) -> None: """This function is used to reset the eval-times in all corresponding ResultsCount(s).""" if seed is None: for seed in self.dataset_seed[dataset]: self.all_results[(dataset, seed)].reset_pseudo_eval_times( eval_name, estimated_per_epoch_time ) else: self.all_results[(dataset, seed)].reset_pseudo_eval_times( eval_name, estimated_per_epoch_time ) def get_latency(self, dataset: Text) -> float: """Get the latency of a model on the target dataset.""" latencies = [] for seed in self.dataset_seed[dataset]: latency = self.all_results[(dataset, seed)].get_latency() if not isinstance(latency, float) or latency <= 0: raise ValueError( "invalid latency of {:} with seed={:} : {:}".format( dataset, seed, latency ) ) 
latencies.append(latency) return sum(latencies) / len(latencies) def get_total_epoch(self, dataset=None): """Return the total number of training epochs.""" if dataset is None: epochss = [] for xdata, x_seeds in self.dataset_seed.items(): epochss += [ self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds ] elif isinstance(dataset, str): x_seeds = self.dataset_seed[dataset] epochss = [ self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds ] else: raise ValueError("invalid dataset={:}".format(dataset)) if len(set(epochss)) > 1: raise ValueError( "Each trial must have the same number of training epochs : {:}".format( epochss ) ) return epochss[-1] def query(self, dataset, seed=None): """Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'; if seed is None, return a dict mapping each seed to its ResultsCount.""" if seed is None: x_seeds = self.dataset_seed[dataset] return {seed: self.all_results[(dataset, seed)] for seed in x_seeds} else: return self.all_results[(dataset, seed)] def arch_idx_str(self): return "{:06d}".format(self.arch_index) def update(self, dataset_name, seed, result): """Update the result for the given dataset and seed.""" if dataset_name not in self.dataset_seed: self.dataset_seed[dataset_name] = [] if seed in self.dataset_seed[dataset_name]: raise ValueError( "{:}-th arch already has this seed ({:}) on {:}".format( self.arch_index, seed, dataset_name ) ) self.dataset_seed[dataset_name].append(seed) self.dataset_seed[dataset_name] = sorted(self.dataset_seed[dataset_name]) assert (dataset_name, seed) not in self.all_results self.all_results[(dataset_name, seed)] = result self.clear_net_done = False def state_dict(self): """Return a dict that can be used to re-create this instance.""" state_dict = dict() for key, value in self.__dict__.items(): if key == "all_results": # contains ResultsCount instances xvalue = dict() if not isinstance(value, dict): raise ValueError( "invalid type of value for {:} : 
{:}".format(key, type(value)) ) for
cur_k, cur_v in value.items(): if not isinstance(cur_v, ResultsCount): raise ValueError( "invalid type of value for {:}/{:} : {:}".format( key, cur_k, type(cur_v) ) ) xvalue[cur_k] = cur_v.state_dict() else: xvalue = value state_dict[key] = xvalue return state_dict def load_state_dict(self, state_dict): """Update self based on the input dict.""" new_state_dict = dict() for key, value in state_dict.items(): if key == "all_results": # To convert to the class of ResultsCount xvalue = dict() if not isinstance(value, dict): raise ValueError( "invalid type of value for {:} : {:}".format(key, type(value)) ) for cur_k, cur_v in value.items(): xvalue[cur_k] = ResultsCount.create_from_state_dict(cur_v) else: xvalue = value new_state_dict[key] = xvalue self.__dict__.update(new_state_dict) @staticmethod def create_from_state_dict(state_dict_or_file): """Create the ArchResults instance from a dict or a file.""" x = ArchResults(-1, -1) if isinstance(state_dict_or_file, str): # a file path state_dict = pickle_load(state_dict_or_file) elif isinstance(state_dict_or_file, dict): state_dict = state_dict_or_file else: raise ValueError( "invalid type of state_dict_or_file : {:}".format( type(state_dict_or_file) ) ) x.load_state_dict(state_dict) return x def clear_params(self): """Clear the weights saved in each 'result'.""" # NOTE(xuanyidong): This can help reduce the memory footprint. 
for unused_key, result in self.all_results.items(): del result.net_state_dict result.net_state_dict = None self.clear_net_done = True def debug_test(self): """Help debug and test, which will call most methods.""" all_dataset = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"] for dataset in all_dataset: print("---->>>> {:}".format(dataset)) print( "The latency on {:} is {:} s".format(dataset, self.get_latency(dataset)) ) for seed in self.dataset_seed[dataset]: result = self.all_results[(dataset, seed)] print(" ==>> result = {:}".format(result)) print(" ==>> cost = {:}".format(result.get_times())) def __repr__(self): return ( "{name}(arch-index={index}, arch={arch}, " "{num} runs, clear={clear})".format( name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done, ) ) class ResultsCount(object): """ResultsCount is to save the information of one trial for a single architecture.""" def __init__( self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency, ): self.name = name self.net_state_dict = state_dict self.train_acc1es = copy.deepcopy(train_accs) self.train_acc5es = None self.train_losses = copy.deepcopy(train_losses) self.train_times = None self.arch_config = copy.deepcopy(arch_config) self.params = params self.flop = flop self.seed = seed self.epochs = epochs self.latency = latency # evaluation results self.reset_eval() def update_train_info( self, train_acc1es, train_acc5es, train_losses, train_times ) -> None: self.train_acc1es = train_acc1es self.train_acc5es = train_acc5es self.train_losses = train_losses self.train_times = train_times def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None: """Assign the training times.""" train_times = collections.OrderedDict() for i in range(self.epochs): train_times[i] = estimated_per_epoch_time self.train_times = train_times def reset_pseudo_eval_times( self, eval_name: Text, 
        estimated_per_epoch_time: float
    ) -> None:
        """Assign the evaluation times."""
        if eval_name not in self.eval_names:
            raise ValueError("invalid eval name : {:}".format(eval_name))
        for i in range(self.epochs):
            self.eval_times["{:}@{:}".format(eval_name, i)] = estimated_per_epoch_time

    def reset_eval(self):
        self.eval_names = []
        self.eval_acc1es = {}
        self.eval_times = {}
        self.eval_losses = {}

    def update_latency(self, latency):
        self.latency = copy.deepcopy(latency)

    def get_latency(self) -> float:
        """Return the latency value in seconds."""
        # NOTE(xuanyidong): -1 represents not available,
        # NOTE(xuanyidong): otherwise it should be a float value.
        if self.latency is None:
            return -1.0
        else:
            return sum(self.latency) / len(self.latency)

    def update_eval(self, accs, losses, times):
        """To update the evaluation results."""
        data_names = set([x.split("@")[0] for x in accs.keys()])
        for data_name in data_names:
            if data_name in self.eval_names:
                raise ValueError(
                    "{:} has already been added into "
                    "eval-names".format(data_name)
                )
            self.eval_names.append(data_name)
            for iepoch in range(self.epochs):
                xkey = "{:}@{:}".format(data_name, iepoch)
                self.eval_acc1es[xkey] = accs[xkey]
                self.eval_losses[xkey] = losses[xkey]
                self.eval_times[xkey] = times[xkey]

    def update_OLD_eval(self, name, accs, losses):  # pylint: disable=invalid-name
        """To update the evaluation results (old NAS-Bench-201 version)."""
        assert name not in self.eval_names, "{:} has already been added".format(name)
        self.eval_names.append(name)
        for iepoch in range(self.epochs):
            if iepoch in accs:
                self.eval_acc1es["{:}@{:}".format(name, iepoch)] = accs[iepoch]
                self.eval_losses["{:}@{:}".format(name, iepoch)] = losses[iepoch]

    def __repr__(self):
        num_eval = len(self.eval_names)
        set_name = "[" + ", ".join(self.eval_names) + "]"
        return (
            "{name}({xname}, arch={arch}, FLOP={flop:.2f}M, "
            "Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: "
            "{set_name})".format(
                name=self.__class__.__name__,
                xname=self.name,
                arch=self.arch_config["arch_str"],
                flop=self.flop,
                param=self.params,
                seed=self.seed,
                num_eval=num_eval,
                set_name=set_name,
            )
        )

    def get_total_epoch(self):
        return copy.deepcopy(self.epochs)

    def get_times(self):
        """Obtain the information regarding both training and evaluation time."""
        if self.train_times is not None and isinstance(self.train_times, dict):
            train_times = list(self.train_times.values())
            time_info = {
                "T-train@epoch": mean(train_times),
                "T-train@total": sum(train_times),
            }
        else:
            time_info = {"T-train@epoch": None, "T-train@total": None}
        for name in self.eval_names:
            try:
                xtimes = [
                    self.eval_times["{:}@{:}".format(name, i)]
                    for i in range(self.epochs)
                ]
                time_info["T-{:}@epoch".format(name)] = mean(xtimes)
                time_info["T-{:}@total".format(name)] = sum(xtimes)
            except Exception as unused_e:  # pylint: disable=broad-except
                time_info["T-{:}@epoch".format(name)] = None
                time_info["T-{:}@total".format(name)] = None
        return time_info

    def get_eval_set(self):
        return self.eval_names

    def judge_valid(self, iepoch: Optional[int]):
        if iepoch < 0 or iepoch >= self.epochs:
            raise ValueError("invalid iepoch={:} < {:}".format(iepoch, self.epochs))

    def get_train(self, iepoch: Optional[int] = None):
        """Get the training information."""
        if iepoch is None:
            iepoch = self.epochs - 1
        self.judge_valid(iepoch)
        if self.train_times is not None:
            xtime = self.train_times[iepoch]
            atime = sum([self.train_times[i] for i in range(iepoch + 1)])
        else:
            xtime, atime = None, None
        return {
            "iepoch": iepoch,
            "loss": self.train_losses[iepoch],
            "accuracy": self.train_acc1es[iepoch],
            "cur_time": xtime,
            "all_time": atime,
        }

    def get_eval(self, name, iepoch: Optional[int] = None):
        """Get the evaluation information; there could be multiple evaluation sets (identified by the 'name' argument)."""
        if iepoch is None:
            iepoch = self.epochs - 1
        self.judge_valid(iepoch)

        def _internal_query(xname):
            if isinstance(self.eval_times, dict) and len(self.eval_times):
                key = "{:}@{:}".format(xname, iepoch)
                if key in self.eval_times:
                    xtime = self.eval_times[key]
                else:
                    raise ValueError(
                        "{:} not in {:}".format(key, self.eval_times.keys())
                    )
                atime = sum(
                    [
                        self.eval_times["{:}@{:}".format(xname, i)]
                        for i in range(iepoch + 1)
                    ]
                )
            else:
                xtime, atime = None, None
            return {
                "iepoch": iepoch,
                "loss": self.eval_losses["{:}@{:}".format(xname, iepoch)],
                "accuracy": self.eval_acc1es["{:}@{:}".format(xname, iepoch)],
                "cur_time": xtime,
                "all_time": atime,
            }

        if name == "valid":
            return _internal_query("x-valid")
        else:
            return _internal_query(name)

    def get_net_param(self, clone: bool = False):
        if clone:
            return copy.deepcopy(self.net_state_dict)
        else:
            return self.net_state_dict

    def get_config(self, str2structure):
        """This function is used to obtain the config dict for this architecture."""
        if str2structure is None:
            # Archs in the size search space of NATS-BENCH are tagged "infer.shape.tiny".
            if (
                "name" in self.arch_config
                and self.arch_config["name"] == "infer.shape.tiny"
            ):
                return {
                    "name": "infer.shape.tiny",
                    "channels": self.arch_config["channels"],
                    "genotype": self.arch_config["genotype"],
                    "num_classes": self.arch_config["class_num"],
                }
            else:
                # This is an arch in NATS-BENCH's topology search space.
                return {
                    "name": "infer.tiny",
                    "C": self.arch_config["channel"],
                    "N": self.arch_config["num_cells"],
                    "arch_str": self.arch_config["arch_str"],
                    "num_classes": self.arch_config["class_num"],
                }
        else:
            # Archs in the size search space of NATS-BENCH are tagged "infer.shape.tiny".
            if (
                "name" in self.arch_config
                and self.arch_config["name"] == "infer.shape.tiny"
            ):
                return {
                    "name": "infer.shape.tiny",
                    "channels": self.arch_config["channels"],
                    "genotype": str2structure(self.arch_config["genotype"]),
                    "num_classes": self.arch_config["class_num"],
                }
            else:
                # This is an arch in the topology search space of NATS-BENCH.
                return {
                    "name": "infer.tiny",
                    "C": self.arch_config["channel"],
                    "N": self.arch_config["num_cells"],
                    "genotype": str2structure(self.arch_config["arch_str"]),
                    "num_classes": self.arch_config["class_num"],
                }

    def state_dict(self):
        collected_state_dict = {key: value for key, value in self.__dict__.items()}
        return collected_state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    @staticmethod
    def create_from_state_dict(state_dict):
        x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
        x.load_state_dict(state_dict)
        return x

================================================
FILE: nats_bench/genotype_utils.py
================================================
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.09 #
#####################################################
# Moved from https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_searchs/genotypes.py
#####################################################
from copy import deepcopy


def topology_str2structure(xstr):
    return TopologyStructure.str2structure(xstr)


class TopologyStructure:
    """A class to describe the topology, especially that used in NATS-Bench."""

    def __init__(self, genotype):
        assert isinstance(genotype, list) or isinstance(
            genotype, tuple
        ), "invalid class of genotype : {:}".format(type(genotype))
        self.node_num = len(genotype) + 1
        self.nodes = []
        self.node_N = []
        for idx, node_info in enumerate(genotype):
            assert isinstance(node_info, list) or isinstance(
                node_info, tuple
            ), "invalid class of node_info : {:}".format(type(node_info))
            assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info))
            for node_in in node_info:
                assert isinstance(node_in, list) or isinstance(
                    node_in, tuple
                ), "invalid class of in-node : {:}".format(type(node_in))
                assert (
                    len(node_in) == 2 and node_in[1] <= idx
                ), "invalid in-node : {:}".format(node_in)
            self.node_N.append(len(node_info))
            self.nodes.append(tuple(deepcopy(node_info)))

    def tolist(self, remove_str):
        # Convert this class to a list; if remove_str is 'none', remove the 'none' operation.
        # Note that we re-order the input nodes in this function.
        # Return (the-genotype-list, success); if unsuccessful, the cell is not connected.
        genotypes = []
        for node_info in self.nodes:
            node_info = list(node_info)
            node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
            node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
            if len(node_info) == 0:
                return None, False
            genotypes.append(node_info)
        return genotypes, True

    def node(self, index):
        assert index > 0 and index <= len(self), "invalid index={:} < {:}".format(
            index, len(self)
        )
        return self.nodes[index]

    def tostr(self):
        strings = []
        for node_info in self.nodes:
            string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info])
            string = "|{:}|".format(string)
            strings.append(string)
        return "+".join(strings)

    def check_valid(self):
        nodes = {0: True}
        for i, node_info in enumerate(self.nodes):
            sums = []
            for op, xin in node_info:
                if op == "none" or nodes[xin] is False:
                    x = False
                else:
                    x = True
                sums.append(x)
            nodes[i + 1] = sum(sums) > 0
        return nodes[len(self.nodes)]

    def to_unique_str(self, consider_zero=False):
        # This is used to identify isomorphic cells, which requires prior knowledge of the operations.
        # Two operations are special, i.e., none and skip_connect.
        nodes = {0: "0"}
        for i_node, node_info in enumerate(self.nodes):
            cur_node = []
            for op, xin in node_info:
                if consider_zero is None:
                    x = "(" + nodes[xin] + ")" + "@{:}".format(op)
                elif consider_zero:
                    if op == "none" or nodes[xin] == "#":
                        x = "#"  # zero
                    elif op == "skip_connect":
                        x = nodes[xin]
                    else:
                        x = "(" + nodes[xin] + ")" + "@{:}".format(op)
                else:
                    if op == "skip_connect":
                        x = nodes[xin]
                    else:
                        x = "(" + nodes[xin] + ")" + "@{:}".format(op)
                cur_node.append(x)
            nodes[i_node + 1] = "+".join(sorted(cur_node))
        return nodes[len(self.nodes)]

    def check_valid_op(self, op_names):
        for node_info in self.nodes:
            for inode_edge in node_info:
                # assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
                if inode_edge[0] not in op_names:
                    return False
        return True

    def __repr__(self):
        return "{name}({node_num} nodes with {node_info})".format(
            name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__
        )

    def __len__(self):
        return len(self.nodes) + 1

    def __getitem__(self, index):
        return self.nodes[index]

    @staticmethod
    def str2structure(xstr):
        if isinstance(xstr, TopologyStructure):
            return xstr
        assert isinstance(xstr, str), "must take string (not {:}) as input".format(
            type(xstr)
        )
        nodestrs = xstr.split("+")
        genotypes = []
        for i, node_str in enumerate(nodestrs):
            inputs = list(filter(lambda x: x != "", node_str.split("|")))
            for xinput in inputs:
                assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
                    xinput
                )
            inputs = (xi.split("~") for xi in inputs)
            input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
            genotypes.append(input_infos)
        return TopologyStructure(genotypes)

    @staticmethod
    def str2fullstructure(xstr, default_name="none"):
        assert isinstance(xstr, str), "must take string (not {:}) as input".format(
            type(xstr)
        )
        nodestrs = xstr.split("+")
        genotypes = []
        for i, node_str in enumerate(nodestrs):
            inputs = list(filter(lambda x: x != "", node_str.split("|")))
            for xinput in inputs:
                assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
                    xinput
                )
            inputs = (xi.split("~") for xi in inputs)
            input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
            all_in_nodes = list(x[1] for x in input_infos)
            for j in range(i):
                if j not in all_in_nodes:
                    input_infos.append((default_name, j))
            node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
            genotypes.append(tuple(node_info))
        return TopologyStructure(genotypes)

    @staticmethod
    def gen_all(search_space, num, return_ori):
        assert isinstance(search_space, list) or isinstance(
            search_space, tuple
        ), "invalid class of search-space : {:}".format(type(search_space))
        assert (
            num >= 2
        ), "There should be at least two nodes in a neural cell instead of {:}".format(
            num
        )
        all_archs = get_combination(search_space, 1)
        for i, arch in enumerate(all_archs):
            all_archs[i] = [tuple(arch)]

        for inode in range(2, num):
            cur_nodes = get_combination(search_space, inode)
            new_all_archs = []
            for previous_arch in all_archs:
                for cur_node in cur_nodes:
                    new_all_archs.append(previous_arch + [tuple(cur_node)])
            all_archs = new_all_archs
        if return_ori:
            return all_archs
        else:
            return [TopologyStructure(x) for x in all_archs]


def get_combination(space, num):
    """Enumerate every way to feed `num` input nodes into one node, one op from `space` per edge.

    TopologyStructure.gen_all relies on this helper; it is assumed to mirror the helper of
    the same name in the AutoDL-Projects genotypes.py that this module was moved from.
    """
    combs = []
    for i in range(num):
        if i == 0:
            for func in space:
                combs.append([(func, i)])
        else:
            new_combs = []
            for string in combs:
                for func in space:
                    new_combs.append(string + [(func, i)])
            combs = new_combs
    return combs

================================================
FILE: notebooks/README.md
================================================
# Notebooks for NATS-Bench

We provide some examples of how to use NATS-Bench in the following notebooks:

- [`create-query-sss.ipynb`](create-query-sss.ipynb) : create the API for the size search space and query the performance of some architectures.
- [`find-largest.ipynb`](find-largest.ipynb) : find the largest model and report its performance on each dataset.
- [`random-search.ipynb`](random-search.ipynb) : use random search on the topology search space of NATS-Bench.
- [`issue-7.ipynb`](issue-7.ipynb) : show how to use the `find_best` function.
- [`issue-11.ipynb`](issue-11.ipynb) : show how to use the `get_more_info` function.
- [`issue-12.ipynb`](issue-12.ipynb) : show the unique string of each candidate and remove duplicated candidates.
- [`issue-21.ipynb`](issue-21.ipynb) : create the PyTorch model instance of an architecture candidate.

FYI, to open and run them on your machine, execute `jupyter notebook` under the `NATS-Bench/notebooks` folder.
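
The notebooks identify candidates through two string encodings: a topology string such as `|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|` (one `+`-separated group per node, each incoming edge written as `op~input_index`) and a channel string such as `8:8:8:16:40`. The following is a minimal, self-contained sketch, not part of the `nats_bench` package, that parses a topology string the same way `TopologyStructure.str2structure` splits it and counts both search spaces. All names in it (`parse_topology`, `to_topology_str`, `all_topology_strs`, `all_channel_strs`, `TOPOLOGY_OPS`, `CHANNEL_CANDIDATES`) are hypothetical, and it assumes the five NAS-Bench-201 operations plus the channel candidates 8, 16, ..., 64 over five layers. Under those assumptions it reproduces the 15625 and 32768 totals printed in the notebook outputs.

```python
from itertools import product
from typing import List, Tuple

# Assumption: the five operations of the NATS-Bench topology search space.
TOPOLOGY_OPS = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
# Assumption: per-layer channel candidates of the size search space (five layers).
CHANNEL_CANDIDATES = [8, 16, 24, 32, 40, 48, 56, 64]
NUM_LAYERS = 5


def parse_topology(arch_str: str) -> List[Tuple[Tuple[str, int], ...]]:
    """Split '|op~i|+|op~i|op~j|+...' into one tuple of (op, input_index) per node."""
    nodes = []
    for node_str in arch_str.split("+"):
        edges = [e for e in node_str.split("|") if e != ""]
        nodes.append(tuple((op, int(idx)) for op, idx in (e.split("~") for e in edges)))
    return nodes


def to_topology_str(nodes) -> str:
    """Inverse of parse_topology, in the same format as TopologyStructure.tostr."""
    return "+".join(
        "|" + "|".join("{:}~{:}".format(op, idx) for op, idx in node) + "|"
        for node in nodes
    )


def all_topology_strs(num_nodes: int = 4) -> List[str]:
    """Enumerate every cell: node k (1..num_nodes-1) takes one op from each node j < k."""
    edges = [(k, j) for k in range(1, num_nodes) for j in range(k)]
    archs = []
    for ops in product(TOPOLOGY_OPS, repeat=len(edges)):
        nodes = [[] for _ in range(1, num_nodes)]
        for (k, j), op in zip(edges, ops):
            nodes[k - 1].append((op, j))
        archs.append(to_topology_str(nodes))
    return archs


def all_channel_strs() -> List[str]:
    """Enumerate every channel configuration, e.g. '8:8:8:16:40'."""
    return [
        ":".join(str(c) for c in chs)
        for chs in product(CHANNEL_CANDIDATES, repeat=NUM_LAYERS)
    ]


if __name__ == "__main__":
    arch = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|"
    nodes = parse_topology(arch)
    print(nodes[2])  # (('skip_connect', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))
    print(to_topology_str(nodes) == arch)  # True
    print(len(all_topology_strs()), len(all_channel_strs()))  # 15625 32768
```

Enumerating with `itertools.product` treats every edge independently, which is why the topology space has 5^6 = 15625 members even though many of them are isomorphic; the `issue-12.ipynb` output reports 6466 unique cells after deduplicating with `get_unique_str`.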
================================================ FILE: notebooks/create-query-sss.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-03-29 08:23:08] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n", "[2021-03-29 08:23:08] Create NATS-Bench (size) done with 0/32768 architectures avaliable.\n", "\n", "API create done: NATSsize(0/32768 architectures, fast_mode=True, file=None)\n", "\n", "[2021-03-29 08:23:08] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n", "[2021-03-29 08:23:08] Call query_index_by_arch with arch=1234\n", "[2021-03-29 08:23:08] Call clear_params with archive_root=/Users/xuanyidong/.torch/NATS-sss-v1_0-50262-simple and index=1234\n", "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n", " 'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n", " 'performance of the CIFAR-10 test set after training on the '\n", " 'train+valid sets by 12 epochs. 
The per-time and total-time '\n", " 'indicate the per epoch and total time costs, respectively.',\n", " 'test-accuracy': 83.87,\n", " 'test-all-time': 8.31445026397705,\n", " 'test-loss': 0.4872739363670349,\n", " 'test-per-time': 0.6928708553314209,\n", " 'train-accuracy': 85.74,\n", " 'train-all-time': 69.73253917694092,\n", " 'train-loss': 0.4183172229385376,\n", " 'train-per-time': 5.811044931411743}\n" ] } ], "source": [ "from nats_bench import create\n", "from pprint import pprint\n", "\n", "# Create the API instance for the size search space in NATS\n", "api = create(None, 'sss', fast_mode=True, verbose=True)\n", "print('\\nAPI create done: {:}\\n'.format(api))\n", "\n", "\n", "info = api.get_more_info(1234, 'cifar10')\n", "pprint(info)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-03-29 08:23:12] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=90, and is_random=True.\n", "[2021-03-29 08:23:12] Call query_index_by_arch with arch=1234\n", "[2021-03-29 08:23:12] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n", "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n", " 'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n", " 'performance of the CIFAR-10 test set after training on the '\n", " 'train+valid sets by 90 epochs. 
The per-time and total-time '\n", " 'indicate the per epoch and total time costs, respectively.',\n", " 'test-accuracy': 89.4,\n", " 'test-all-time': 62.35837697982788,\n", " 'test-loss': 0.3388326271057129,\n", " 'test-per-time': 0.6928708553314209,\n", " 'train-accuracy': 95.206,\n", " 'train-all-time': 522.9940438270569,\n", " 'train-loss': 0.14320597895622253,\n", " 'train-per-time': 5.811044931411743}\n" ] } ], "source": [ "info = api.get_more_info(1234, 'cifar10', hp='90')\n", "pprint(info)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-03-29 08:23:15] Call the get_cost_info function with index=12, dataset=cifar10, and hp=12.\n", "[2021-03-29 08:23:15] Call clear_params with archive_root=/Users/xuanyidong/.torch/NATS-sss-v1_0-50262-simple and index=12\n", "Call query_meta_info_by_index with arch_index=12, hp=12\n", "[2021-03-29 08:23:15] Call _prepare_info with index=12 skip because it is in arch2infos_dict\n", "{'T-ori-test@epoch': 0.6709375381469727,\n", " 'T-ori-test@total': 8.051250457763672,\n", " 'T-train@epoch': 5.539922475814819,\n", " 'T-train@total': 66.47906970977783,\n", " 'flops': 7.991706,\n", " 'latency': 0.014862352974560795,\n", " 'params': 0.067378}\n" ] } ], "source": [ "# Query the flops, params, latency. info is a dict.\n", "info = api.get_cost_info(12, 'cifar10')\n", "pprint(info)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-03-29 08:23:17] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n", "[2021-03-29 08:23:17] Call query_index_by_arch with arch=1234\n", "[2021-03-29 08:23:17] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n", "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n", " 'train+valid sets of CIFAR-10. 
The test-loss/accuracy/time is the '\n", " 'performance of the CIFAR-10 test set after training on the '\n", " 'train+valid sets by 12 epochs. The per-time and total-time '\n", " 'indicate the per epoch and total time costs, respectively.',\n", " 'test-accuracy': 84.28,\n", " 'test-all-time': 8.31445026397705,\n", " 'test-loss': 0.46498328766822816,\n", " 'test-per-time': 0.6928708553314209,\n", " 'train-accuracy': 86.004,\n", " 'train-all-time': 69.73253917694092,\n", " 'train-loss': 0.405061281375885,\n", " 'train-per-time': 5.811044931411743}\n" ] } ], "source": [ "info = api.get_more_info(1234, 'cifar10', hp='12', is_random=True)\n", "pprint(info)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-03-29 08:23:20] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n", "[2021-03-29 08:23:20] Call query_index_by_arch with arch=1234\n", "[2021-03-29 08:23:20] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n", "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n", " 'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n", " 'performance of the CIFAR-10 test set after training on the '\n", " 'train+valid sets by 12 epochs. 
The per-time and total-time '\n", " 'indicate the per epoch and total time costs, respectively.',\n", " 'test-accuracy': 83.87,\n", " 'test-all-time': 8.31445026397705,\n", " 'test-loss': 0.4872739363670349,\n", " 'test-per-time': 0.6928708553314209,\n", " 'train-accuracy': 85.74,\n", " 'train-all-time': 69.73253917694092,\n", " 'train-loss': 0.4183172229385376,\n", " 'train-per-time': 5.811044931411743}\n" ] } ], "source": [ "# The same code as above, but it returns a different performance because we set is_random=True\n", "info = api.get_more_info(1234, 'cifar10', hp='12', is_random=True)\n", "pprint(info)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: notebooks/find-largest.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-03-29 08:22:18] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n" ] } ], "source": [ "import random\n", "import numpy as np\n", "from nats_bench import create\n", "from pprint import pprint\n", "# Create the API for the topology search space\n", "api = create(None, 'tss', fast_mode=True, verbose=False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The architecture-index for the largest model is 1462\n", "Its performance on cifar10 with 12-epoch-training\n", "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n", " 'train+valid sets of CIFAR-10.
The test-loss/accuracy/time is the '\n", " 'performance of the CIFAR-10 test set after training on the '\n", " 'train+valid sets by 12 epochs. The per-time and total-time '\n", " 'indicate the per epoch and total time costs, respectively.',\n", " 'test-accuracy': 82.2,\n", " 'test-all-time': 25.68491765430996,\n", " 'test-loss': 0.5235109260559082,\n", " 'test-per-time': 2.14040980452583,\n", " 'train-accuracy': 83.78,\n", " 'train-all-time': 415.2997846603394,\n", " 'train-loss': 0.4719834935951233,\n", " 'train-per-time': 34.60831538836162}\n", "Its performance on cifar10 with 200-epoch-training\n", "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n", " 'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n", " 'performance of the CIFAR-10 test set after training on the '\n", " 'train+valid sets by 200 epochs. The per-time and total-time '\n", " 'indicate the per epoch and total time costs, respectively.',\n", " 'test-accuracy': 93.76,\n", " 'test-all-time': 428.08196090516384,\n", " 'test-loss': 0.29643801399866737,\n", " 'test-per-time': 2.1404098045258193,\n", " 'train-accuracy': 99.968,\n", " 'train-all-time': 6921.6630776723405,\n", " 'train-loss': 0.0021994023492435616,\n", " 'train-per-time': 34.6083153883617}\n", "Its performance on cifar100 with 12-epoch-training\n", "{'test-accuracy': 44.97999995727539,\n", " 'test-all-time': 12.84245882715498,\n", " 'test-loss': 2.069740362548828,\n", " 'test-per-time': 1.070204902262915,\n", " 'train-accuracy': 46.014,\n", " 'train-all-time': 415.2997846603394,\n", " 'train-loss': 1.9952968555450439,\n", " 'train-per-time': 34.60831538836162,\n", " 'valid-accuracy': 44.05999992675781,\n", " 'valid-all-time': 12.84245882715498,\n", " 'valid-loss': 2.077388186645508,\n", " 'valid-per-time': 1.070204902262915,\n", " 'valtest-accuracy': 44.52,\n", " 'valtest-all-time': 25.68491765430996,\n", " 'valtest-loss': 2.073564303588867,\n", " 'valtest-per-time': 2.14040980452583}\n", "Its 
performance on cifar100 with 200-epoch-training\n", "{'test-accuracy': 71.10666660563152,\n", " 'test-all-time': 214.04098045258192,\n", " 'test-loss': 1.3540414614995322,\n", " 'test-per-time': 1.0702049022629097,\n", " 'train-accuracy': 99.79133333333334,\n", " 'train-all-time': 6921.6630776723405,\n", " 'train-loss': 0.02413411712328593,\n", " 'train-per-time': 34.6083153883617,\n", " 'valid-accuracy': 70.70666666259766,\n", " 'valid-all-time': 214.04098045258192,\n", " 'valid-loss': 1.3654081104278564,\n", " 'valid-per-time': 1.0702049022629097,\n", " 'valtest-accuracy': 70.90666666666667,\n", " 'valtest-all-time': 428.08196090516384,\n", " 'valtest-loss': 1.3597248032251994,\n", " 'valtest-per-time': 2.1404098045258193}\n", "Its performance on ImageNet16-120 with 12-epoch-training\n", "{'test-accuracy': 22.39999992879232,\n", " 'test-all-time': 7.7054752962929856,\n", " 'test-loss': 3.1626377182006835,\n", " 'test-per-time': 0.6421229413577488,\n", " 'train-accuracy': 21.68885959195242,\n", " 'train-all-time': 1260.0195466594694,\n", " 'train-loss': 3.1863493608815463,\n", " 'train-per-time': 105.00162888828912,\n", " 'valid-accuracy': 23.266666631062826,\n", " 'valid-all-time': 7.7054752962929856,\n", " 'valid-loss': 3.1219845104217527,\n", " 'valid-per-time': 0.6421229413577488,\n", " 'valtest-accuracy': 22.833333323160808,\n", " 'valtest-all-time': 15.410950592585971,\n", " 'valtest-loss': 3.142311067581177,\n", " 'valtest-per-time': 1.2842458827154977}\n", "Its performance on ImageNet16-120 with 200-epoch-training\n", "{'test-accuracy': 41.44444444783529,\n", " 'test-all-time': 128.4245882715503,\n", " 'test-loss': 2.3114658287896046,\n", " 'test-per-time': 0.6421229413577515,\n", " 'train-accuracy': 50.604262800759415,\n", " 'train-all-time': 21000.325777657865,\n", " 'train-loss': 1.8626367051877495,\n", " 'train-per-time': 105.00162888828932,\n", " 'valid-accuracy': 40.777777659098305,\n", " 'valid-all-time': 128.4245882715503,\n", " 'valid-loss': 
2.3157107713487415,\n", " 'valid-per-time': 0.6421229413577515,\n", " 'valtest-accuracy': 41.11111109754774,\n", " 'valtest-all-time': 256.8491765431006,\n", " 'valtest-loss': 2.313588462617662,\n", " 'valtest-per-time': 1.284245882715503}\n" ] } ], "source": [ "# query the largest model's performance\n", "largest_candidate_tss = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'\n", "arch_index = api.query_index_by_arch(largest_candidate_tss)\n", "print('The architecture-index for the largest model is {:}'.format(arch_index))\n", "datasets = ('cifar10', 'cifar100', 'ImageNet16-120')\n", "for dataset in datasets:\n", " print('Its performance on {:} with 12-epoch-training'.format(dataset))\n", " info = api.get_more_info(arch_index, dataset, hp='12', is_random=False)\n", " pprint(info)\n", " print('Its performance on {:} with 200-epoch-training'.format(dataset))\n", " info = api.get_more_info(arch_index, dataset, hp='200', is_random=False)\n", " pprint(info)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: notebooks/issue-11.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-03-09 08:44:19] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n" ] } ], "source": [ "from nats_bench import create\n", "import numpy as np\n", "\n", "# Create the API for size search space\n", "api = create(None, 'sss', fast_mode=True, verbose=False)" ] }, { "cell_type": "code", 
"execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "There are 32768 architectures on the size search space\n" ] } ], "source": [ "print('There are {:} architectures on the size search space'.format(len(api)))\n", "\n", "c2acc = dict()\n", "for index in range(len(api)):\n", " info = api.get_more_info(index, 'cifar10', hp='90')\n", " config = api.get_net_config(index, 'cifar10')\n", " c2acc[config['channels']] = info['test-accuracy']" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "91.08546417236329\n" ] } ], "source": [ "print(np.mean(list(c2acc.values())))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: notebooks/issue-12.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-04-30 03:41:23] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n" ] } ], "source": [ "from nats_bench import create\n", "from nats_bench.api_utils import time_string\n", "import numpy as np\n", "\n", "# Create the API for topology search space\n", "api = create(None, 'tss', fast_mode=True, verbose=False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-04-30 03:41:23] There are 15625 architectures on the topology search space\n", "[2021-04-30 03:41:24] There are 6466 isomorphism architectures on the 
topology search space\n" ] } ], "source": [ "print('{:} There are {:} architectures on the topology search space'.format(time_string(), len(api)))\n", "\n", "unique_strs = []\n", "for index in range(len(api)):\n", " unique_str = api.get_unique_str(index)\n", " unique_strs.append(unique_str)\n", "print('{:} There are {:} isomorphism architectures on the topology search space'.format(time_string(), len(set(unique_strs))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: notebooks/issue-21.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-08-07 21:51:10] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n" ] } ], "source": [ "from nats_bench import create\n", "from nats_bench.api_utils import time_string\n", "import numpy as np\n", "\n", "# Create the API for the size search space\n", "api = create(None, 'sss', fast_mode=True, verbose=False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-08-07 21:51:10] There are 32768 architectures on the size search space\n", "{'name': 'infer.shape.tiny', 'channels': '8:8:8:16:40', 'genotype': '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', 'num_classes': 10}\n" ] } ], "source": [ "print('{:} There are {:} architectures on the size search space'.format(time_string(), len(api)))\n", "\n", "# Obtain the 12-th candidate's configuration
on CIFAR-10\n", "config = api.get_net_config(12, 'cifar10')\n", "print(config)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import xautodl # import this lib -- \"https://github.com/D-X-Y/AutoDL-Projects\", you can use pip install xautodl\n", "from xautodl.models import get_cell_based_tiny_net\n", "# create the network, which is the sub-class of torch.nn.Module\n", "network = get_cell_based_tiny_net(config)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DynamicShapeTinyNet(\n", " DynamicShapeTinyNet(C=(8, 8, 8, 16, 40), N=1, L=5)\n", " (stem): Sequential(\n", " (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (cells): ModuleList(\n", " (0): InferCell(\n", " info :: nodes=4, inC=8, outC=8, [1<-(I0-L0) | 2<-(I0-L1,I1-L2) | 3<-(I0-L3,I1-L4,I2-L5)], |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n", " (layers): ModuleList(\n", " (0): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (2): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (3): Identity()\n", " (4): ReLUConvBN(\n", " (op): 
Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (5): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " )\n", " )\n", " (1): ResNetBasicblock(\n", " ResNetBasicblock(inC=8, outC=8, stride=2)\n", " (conv_a): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (conv_b): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (downsample): Sequential(\n", " (0): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", " (1): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (2): InferCell(\n", " info :: nodes=4, inC=8, outC=8, [1<-(I0-L0) | 2<-(I0-L1,I1-L2) | 3<-(I0-L3,I1-L4,I2-L5)], |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n", " (layers): ModuleList(\n", " (0): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", " )\n", " )\n", " (2): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (3): Identity()\n", " (4): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (5): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " )\n", " )\n", " (3): ResNetBasicblock(\n", " ResNetBasicblock(inC=8, outC=16, stride=2)\n", " (conv_a): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (conv_b): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (downsample): Sequential(\n", " (0): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", " (1): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (4): InferCell(\n", " info :: nodes=4, inC=16, outC=40, [1<-(I0-L0) | 2<-(I0-L1,I1-L2) | 3<-(I0-L3,I1-L4,I2-L5)], |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n", " (layers): ModuleList(\n", " (0): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(16, 40, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(16, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (2): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (3): FactorizedReduce(\n", " C_in=16, C_out=40, stride=1\n", " (relu): ReLU()\n", " (conv): Conv2d(16, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (4): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (5): ReLUConvBN(\n", " (op): Sequential(\n", " (0): ReLU()\n", " (1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (lastact): Sequential(\n", " (0): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): ReLU(inplace=True)\n", " )\n", " (global_pooling): AdaptiveAvgPool2d(output_size=1)\n", " (classifier): Linear(in_features=40, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "print(network)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The model parameters are 0.067378 MB\n" ] } ], 
"source": [ "from xautodl.utils import count_parameters_in_MB\n", "print('The model parameters are {:} MB'.format(count_parameters_in_MB(network)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: notebooks/issue-27.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-10-21 07:08:52] Try to use the default NATS-Bench (topology) path from fast_mode=False and path=/Users/xuanyidong/.torch/NATS-tss-v1_0-3ffb9.pickle.pbz2.\n" ] } ], "source": [ "from nats_bench import create\n", "from nats_bench.api_utils import time_string\n", "import numpy as np\n", "\n", "# Create the API for the topology search space\n", "api_tss = create(None, \"tss\", fast_mode=False, verbose=False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------ImageNet16-120--------------------------------------------------\n", "Best (10676) architecture on validation: |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_1x1~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n", "Best (857) architecture on test: |nor_conv_1x1~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n", "using validation ::: validation = 46.73, test = 46.20\n", "\n", "using test ::: validation = 46.38, test = 47.31\n", "\n" ] } ], "source": [ "def 
get_valid_test_acc(api, arch, dataset):\n", " is_size_space = api.search_space_name == \"size\"\n", " if dataset == \"cifar10\":\n", " xinfo = api.get_more_info(\n", " arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False\n", " )\n", " test_acc = xinfo[\"test-accuracy\"]\n", " xinfo = api.get_more_info(\n", " arch,\n", " dataset=\"cifar10-valid\",\n", " hp=90 if is_size_space else 200,\n", " is_random=False,\n", " )\n", " valid_acc = xinfo[\"valid-accuracy\"]\n", " else:\n", " xinfo = api.get_more_info(\n", " arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False\n", " )\n", " valid_acc = xinfo[\"valid-accuracy\"]\n", " test_acc = xinfo[\"test-accuracy\"]\n", " return (\n", " valid_acc,\n", " test_acc,\n", " \"validation = {:.2f}, test = {:.2f}\\n\".format(valid_acc, test_acc),\n", " )\n", "\n", "def find_best_valid(api, dataset):\n", " all_valid_accs, all_test_accs = [], []\n", " for index, arch in enumerate(api):\n", " valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)\n", " all_valid_accs.append((index, valid_acc))\n", " all_test_accs.append((index, test_acc))\n", " best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0]\n", " best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]\n", "\n", " print(\"-\" * 50 + \"{:10s}\".format(dataset) + \"-\" * 50)\n", " print(\n", " \"Best ({:}) architecture on validation: {:}\".format(\n", " best_valid_index, api[best_valid_index]\n", " )\n", " )\n", " print(\n", " \"Best ({:}) architecture on test: {:}\".format(\n", " best_test_index, api[best_test_index]\n", " )\n", " )\n", " _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)\n", " print(\"using validation ::: {:}\".format(perf_str))\n", " _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)\n", " print(\"using test ::: {:}\".format(perf_str))\n", "\n", "dataset = \"ImageNet16-120\"\n", "find_best_valid(api_tss, dataset)" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: notebooks/issue-30.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-11-15 01:40:00] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n" ] } ], "source": [ "from nats_bench import create\n", "import numpy as np\n", "\n", "# Create the API for the topology search space\n", "api = create(None, 'tss', fast_mode=True, verbose=False)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train-loss': 2.6261614637934385,\n", " 'train-accuracy': 33.934739619476986,\n", " 'train-per-time': 51.38146062405923,\n", " 'train-all-time': 10276.292124811845,\n", " 'valid-loss': 2.6819862944285076,\n", " 'valid-accuracy': 32.83333325195313,\n", " 'valid-per-time': 0.36819872515542285,\n", " 'valid-all-time': 73.63974503108457,\n", " 'test-loss': 2.7110290253957112,\n", " 'test-accuracy': 31.966666661580405,\n", " 'test-per-time': 0.36819872515542285,\n", " 'test-all-time': 73.63974503108457,\n", " 'valtest-loss': 2.6965076230367027,\n", " 'valtest-accuracy': 32.399999994913735,\n", " 'valtest-per-time': 0.7363974503108457,\n", " 'valtest-all-time': 147.27949006216915}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "name = 'ImageNet16-120'\n", "api.get_more_info(index=346, dataset=name, hp='200', 
iepoch=199, is_random=999)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "ename": "ValueError", "evalue": "can not find random seed (999) from [777, 888]", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_more_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m347\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'200'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miepoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m199\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_random\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m999\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/nats_bench/api_topology.py\u001b[0m in \u001b[0;36mget_more_info\u001b[0;34m(self, index, dataset, iepoch, hp, is_random)\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[0mis_random\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseeds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0;31m# collect the training information\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 220\u001b[0;31m \u001b[0mtrain_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marchresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m\"train\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miepoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0miepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_random\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mis_random\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 221\u001b[0m \u001b[0mtotal\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"iepoch\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 222\u001b[0m xinfo = {\n", "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/nats_bench/api_utils.py\u001b[0m in \u001b[0;36mget_metrics\u001b[0;34m(self, dataset, setname, iepoch, is_random)\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_random\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# specify the seed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 771\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_random\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mx_seeds\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 772\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"can not find random seed ({:}) from {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_random\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_seeds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mx_seeds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_random\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 774\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minfos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: can not find random seed (999) from [777, 888]" ] } ], "source": [ "api.get_more_info(index=347, dataset=name, hp='200', iepoch=199, is_random=999)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: notebooks/issue-33.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-12-09 09:32:55] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n", "There are 15625 architectures on the topology search space\n" ] } ], "source": [ "from nats_bench import create\n", "import numpy as np\n", "\n", "# Create the API for the topology search space\n", "api = create(None, 'tss', fast_mode=True, verbose=False)\n", "print('There are {:} architectures on the topology search space'.format(len(api)))" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "(44.82400000488281, 0.014427971421626575, 287.3714566230774, 287.3714566230774)\n", "(19.71, 0.01402865074299, 567.5892686843872, 854.9607253074646)\n" ] } ], "source": [ "print(api.simulate_train_eval(2, 'cifar10', iepoch=24, hp='200'))\n", "print(api.simulate_train_eval(2, 'cifar100', iepoch=24, hp='200'))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(30.21, 0.01402865074299, 2290.6443033218384, 3145.605028629303)\n", "(34.86, 0.01402865074299, 3424.233141899109, 6569.838170528412)\n", "(51.65, 0.01402865074299, 4104.386445045471, 10674.224615573883)\n", "(54.640000018310545, 0.01402865074299, 4535.150203704834, 15209.374819278717)\n" ] } ], "source": [ "print(api.simulate_train_eval(2, 'cifar100', iepoch=100, hp='200'))\n", "print(api.simulate_train_eval(2, 'cifar100', iepoch=150, hp='200'))\n", "print(api.simulate_train_eval(2, 'cifar100', iepoch=180, hp='200'))\n", "print(api.simulate_train_eval(2, 'cifar100', iepoch=199, hp='200'))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(25.98333334350586, 0.012760758399963379, 12450.773810838671, 27660.148630117386)\n", "(27.033333257039388, 0.012760758399963379, 13757.711054611173, 41417.85968472856)\n" ] } ], "source": [ "print(api.simulate_train_eval(2, 'ImageNet16-120', iepoch=180, hp='200'))\n", "print(api.simulate_train_eval(2, 'ImageNet16-120', iepoch=199, hp='200'))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: 
notebooks/issue-36.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from nats_bench import create\n", "from tqdm import tqdm\n", "import numpy as np\n", "\n", "def close_to(a, b, eps=1e-4):\n", " if b != 0 and abs(a-b) / abs(b) > eps:\n", " return False\n", " if a != 0 and abs(a-b) / abs(a) > eps:\n", " return False\n", " return True" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def check_flops_params(xapi):\n", " print(f\"Check {xapi}\")\n", " datasets = (\"cifar10-valid\", \"cifar10\", \"cifar100\", \"ImageNet16-120\")\n", " counts = 0\n", " for index in tqdm(range(len(xapi))):\n", " for dataset in datasets:\n", " info_12 = xapi.get_cost_info(index, dataset, hp=\"12\")\n", " info_full = xapi.get_cost_info(index, dataset, hp=xapi.full_epochs_in_paper)\n", " assert close_to(info_12['flops'], info_full['flops']), f\"The {index}-th \" \\\n", " f\"architecture has issues on {dataset} \" \\\n", " f\"-- {info_12['flops']} vs {info_full['flops']}.\" # check the FLOPs\n", " assert close_to(info_12['params'], info_full['params']), f\"The {index}-th \" \\\n", " f\"architecture has issues on {dataset} \" \\\n", " f\"-- {info_12['params']} vs {info_full['params']}.\" # check the number of parameters\n", " counts += 1\n", " print(f\"Check {xapi} completed -- {counts} arch-dataset pairs.\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Create the API for size search space\n", "api = create(None, 'sss', fast_mode=True, verbose=False)\n", "print(f'There are {len(api)} architectures in the size search space -- {api}')\n", "check_flops_params(api)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/15625 [00:00 highest_accuracy:\n", " highest_accuracy = accuracy\n", " best_arch = 
arch_index\n", " return best_arch" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random Search on ImageNet16-120 : 33.9463 $\\pm$ 9.4213\n" ] } ], "source": [ "# Just a small example, not the full experiment in the paper\n", "dataset = 'ImageNet16-120'\n", "rs_times, accuracies = 100, []\n", "for i in range(rs_times):\n", " arch_index = random_search(api_tss, dataset=dataset)\n", " info = api_tss.get_more_info(arch_index, dataset, hp='200', is_random=False)\n", " accuracies.append(info['test-accuracy'])\n", "print('Random Search on {:} : {:.4f} $\\pm$ {:.4f}'.format(dataset, np.mean(accuracies), np.std(accuracies)))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 }
================================================
FILE: setup.py
================================================
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
#####################################################
"""The setup function for pypi."""
# The following is to make nats_bench available on Python Package Index (PyPI)
#
# conda install -c conda-forge twine
# Use twine to upload nats_bench to pypi
#
# python setup.py sdist bdist_wheel
# python setup.py --help-commands
# twine check dist/*
#
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
# twine upload dist/*
# https://pypi.org/project/nats-bench
#
# NOTE(xuanyidong):
# local install = `pip install . --force`
# local test = `pytest . -s`
#
# TODO(xuanyidong): upload it to conda
#
import os
from setuptools import setup
from nats_bench import version

NAME = "nats_bench"
REQUIRES_PYTHON = ">=3.6"
DESCRIPTION = "API for NATS-Bench (a dataset/benchmark for neural architecture topology and size)."
VERSION = version()


def read(fname="README.md"):
    with open(
        os.path.join(os.path.dirname(__file__), fname), encoding="utf-8"
    ) as cfile:
        return cfile.read()


# What packages are required for this module to be executed?
REQUIRED = ["numpy>=1.16.5"]

setup(
    name=NAME,
    version=VERSION,
    author="Xuanyi Dong",
    author_email="dongxuanyi888@gmail.com",
    description=DESCRIPTION,
    license="MIT Licence",
    keywords="NAS Dataset API DeepLearning",
    url="https://github.com/D-X-Y/NATS-Bench",
    packages=["nats_bench"],
    install_requires=REQUIRED,
    python_requires=REQUIRES_PYTHON,
    long_description=read("README.md"),
    long_description_content_type="text/markdown",
    classifiers=[
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Topic :: Database",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
    ],
)

================================================
FILE: tests/api_test.py
================================================
###############################################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                                           #
###############################################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size, IEEE TPAMI 2021 #
###############################################################################################
# pytest --capture=tee-sys                                                                    #
###############################################################################################
"""This file is used to quickly test the API."""
import os
import pytest
import random

from nats_bench.genotype_utils import topology_str2structure
from nats_bench.api_size import NATSsize
from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names
from nats_bench.api_topology import NATStopology
from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names


def get_fake_torch_home_dir():
    print("This file is {:}".format(os.path.abspath(__file__)))
    print("The current directory is {:}".format(os.path.abspath(os.getcwd())))
    xname = "FAKE_TORCH_HOME"
    if xname in os.environ:
        return os.environ["FAKE_TORCH_HOME"]
    else:
        return os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "..", "fake_torch_dir"
        )


def close_to(a, b, eps=1e-4):
    if b != 0 and abs(a - b) / abs(b) > eps:
        return False
    if a != 0 and abs(a - b) / abs(a) > eps:
        return False
    return True


class TestNATSBench(object):
    """A class to test different functions of the NATS-Bench API."""

    def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True):
        if benchmark_dir is None:
            benchmark_dir = os.path.join(
                get_fake_torch_home_dir(), tss_base_names[-1] + "-simple"
            )
        return _test_nats_bench(benchmark_dir, True, fake_random)

    def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True):
        if benchmark_dir is None:
            benchmark_dir = os.path.join(
                get_fake_torch_home_dir(), sss_base_names[-1] + "-simple"
            )
        return _test_nats_bench(benchmark_dir, False, fake_random)

    def prepare_fake_tss(self):
        tss_benchmark_dir = os.path.join(
            get_fake_torch_home_dir(), tss_base_names[-1] + "-simple"
        )
        api = NATStopology(tss_benchmark_dir, True, False)
        return api

    def prepare_fake_sss(self):
        sss_benchmark_dir = os.path.join(
            get_fake_torch_home_dir(), sss_base_names[-1] + "-simple"
        )
        api = NATSsize(sss_benchmark_dir, True, False)
        return api

    def test_01_th_issue(self):
        # Link: https://github.com/D-X-Y/NATS-Bench/issues/1
        api = self.prepare_fake_tss()
        # The performance of 0-th architecture on CIFAR-10 (trained by 12 epochs)
        info = api.get_more_info(0, "cifar10", hp=12)
        # First of all, the data split in NATS-Bench is different from that in the official CIFAR paper.
        # In NATS-Bench, we split the original CIFAR-10 training set into two parts, i.e., a training set and a validation set.
        # In the following, we will use the splits of NATS-Bench to explain.
        print(info["comment"])
        print(
            "The loss on the training + validation sets of CIFAR-10: {:}".format(
                info["train-loss"]
            )
        )
        print(
            "The total training time for 12 epochs on the training + validation sets of CIFAR-10: {:}".format(
                info["train-all-time"]
            )
        )
        print(
            "The per-epoch training time on CIFAR-10: {:}".format(
                info["train-per-time"]
            )
        )
        print(
            "The total evaluation time on the test set of CIFAR-10 for 12 times: {:}".format(
                info["test-all-time"]
            )
        )
        print(
            "The evaluation time on the test set of CIFAR-10: {:}".format(
                info["test-per-time"]
            )
        )

        cost_info = api.get_cost_info(0, "cifar10")
        xkeys = [
            "T-train@epoch",  # The per epoch training time on the training + validation sets of CIFAR-10.
            "T-train@total",
            "T-ori-test@epoch",  # The time cost for the evaluation on CIFAR-10 test set.
            "T-ori-test@total",
        ]  # T-ori-test@epoch * 12 times.
        for xkey in xkeys:
            print(
                "The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}".format(
                    xkey, cost_info[xkey]
                )
            )

    def test_02_th_issue(self):
        # https://github.com/D-X-Y/NATS-Bench/issues/2
        api = self.prepare_fake_tss()
        data = api.query_by_index(284, dataname="cifar10", hp=200)
        for xkey, xvalue in data.items():
            print("{:} : {:}".format(xkey, xvalue))
        xinfo = data[777].get_train()
        print(xinfo)
        print(data[777].train_acc1es)

        info_012_epochs = api.get_more_info(284, "cifar10", hp=12)
        print(
            "Train accuracy for 12 epochs is {:}".format(
                info_012_epochs["train-accuracy"]
            )
        )
        info_200_epochs = api.get_more_info(284, "cifar10", hp=200)
        print(
            "Train accuracy for 200 epochs is {:}".format(
                info_200_epochs["train-accuracy"]
            )
        )

    def test_07_th_issue(self):
        # https://github.com/D-X-Y/NATS-Bench/issues/7
        apis = [self.prepare_fake_tss(), self.prepare_fake_sss()]
        indexes = [0, 11, 284]
        datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
        for api in apis:
            for index in indexes:
                for dataset in datasets:
                    _ = api.get_cost_info(index, dataset, hp="12")
                    best_arch_index, highest_valid_accuracy = api.find_best(
                        dataset=dataset, metric_on_set="valid", hp="12", enforce_all=False
                    )
                    print(
                        f"api={api}, best_arch_index={best_arch_index}, highest_valid_accuracy={highest_valid_accuracy}"
                    )

    def test_12_th_issue(self):
        # https://github.com/D-X-Y/NATS-Bench/issues/13
        api = self.prepare_fake_tss()
        structures = []
        for arch_index in range(len(api)):
            structures.append(topology_str2structure(api[arch_index]))
        unique_strs = []
        for structure in structures:
            unique_strs.append(structure.to_unique_str(consider_zero=True))
        unique_strs = set(unique_strs)
        assert len(unique_strs) == 6466, "{:} vs {:}".format(len(unique_strs), 6466)

    def test_36_th_issue(self):
        # https://github.com/D-X-Y/NATS-Bench/issues/36
        apis = [self.prepare_fake_tss(), self.prepare_fake_sss()]
        indexes = [0, 11, 284]
        datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
        for api in apis:
            for index in indexes:
                for dataset in datasets:
                    info_12 = api.get_cost_info(index, dataset, hp="12")
                    info_full = api.get_cost_info(
                        index, dataset, hp=api.full_epochs_in_paper
                    )
                    assert close_to(info_12["flops"], info_full["flops"]), (
                        f"The {index}-th "
                        f"architecture has issues on {dataset} "
                        f"-- {info_12['flops']} vs {info_full['flops']}."
                    )  # check the FLOPs
                    assert close_to(info_12["params"], info_full["params"]), (
                        f"The {index}-th "
                        f"architecture has issues on {dataset} "
                        f"-- {info_12['params']} vs {info_full['params']}."
                    )  # check the number of parameters

    def test_44_th_issue(self):
        # https://github.com/D-X-Y/NATS-Bench/issues/44
        benchmark_dir = os.path.join(
            get_fake_torch_home_dir(), tss_base_names[-1] + "-simple"
        )
        return _test_nats_bench(benchmark_dir, True, fake_random=True, hp="200")


def _test_nats_bench(benchmark_dir, is_tss, fake_random, hp="12", verbose=False):
    """The main test entry for NATS-Bench."""
    if is_tss:
        api = NATStopology(benchmark_dir, True, verbose)
    else:
        api = NATSsize(benchmark_dir, True, verbose)

    if fake_random:
        test_indexes = [0, 11, 284]
    else:
        test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]

    key2dataset = {
        "cifar10": "CIFAR-10",
        "cifar100": "CIFAR-100",
        "ImageNet16-120": "ImageNet16-120",
    }

    for index in test_indexes:
        print("\n\nEvaluate the {:5d}-th architecture.".format(index))

        for key, dataset in key2dataset.items():
            # Query the loss / accuracy / time for the `index`-th candidate
            # architecture on the current dataset (`key`)
            # info is a dict, where you can easily figure out the meaning by key
            info = api.get_more_info(index, key)
            print(" -->> The performance on {:}: {:}".format(dataset, info))

            # Query the flops, params, latency. info is a dict.
            info = api.get_cost_info(index, key)
            print(" -->> The cost info on {:}: {:}".format(dataset, info))

            # Simulate the training of the `index`-th candidate:
            (
                validation_accuracy,
                latency,
                time_cost,
                current_total_time_cost,
            ) = api.simulate_train_eval(index, dataset=key, hp=hp)
            print(
                " -->> The validation accuracy={:}, latency={:}, "
                "the current time cost={:} s, accumulated time cost={:} s".format(
                    validation_accuracy, latency, time_cost, current_total_time_cost
                )
            )

            # Print the configuration of the `index`-th architecture on the current dataset
            config = api.get_net_config(index, key)
            print(" -->> The configuration on {:} is {:}".format(dataset, config))

        # Show the information of the `index`-th architecture
        api.show(index)

    with pytest.raises(ValueError):
        api.get_more_info(100000, "cifar10")