Full Code of D-X-Y/NATS-Bench for AI

main 1d4a304ad190 cached
34 files
199.1 KB
55.3k tokens
142 symbols
1 requests
Download .txt
Showing preview only (209K chars total). Download the full file or copy to clipboard to get everything.
Repository: D-X-Y/NATS-Bench
Branch: main
Commit: 1d4a304ad190
Files: 34
Total size: 199.1 KB

Directory structure:
gitextract_fmzpgwc0/

├── .github/
│   ├── CODE-OF-CONDUCT.md
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug-report.md
│   │   └── question.md
│   └── workflows/
│       └── ci.yml
├── .gitignore
├── LICENSE.md
├── README.md
├── fake_torch_dir/
│   ├── NATS-sss-v1_0-50262-simple/
│   │   ├── 000000.pickle.pbz2
│   │   ├── 000011.pickle.pbz2
│   │   ├── 000284.pickle.pbz2
│   │   └── meta.pickle.pbz2
│   └── NATS-tss-v1_0-3ffb9-simple/
│       ├── 000000.pickle.pbz2
│       ├── 000011.pickle.pbz2
│       ├── 000284.pickle.pbz2
│       └── meta.pickle.pbz2
├── nats_bench/
│   ├── __init__.py
│   ├── api_size.py
│   ├── api_topology.py
│   ├── api_utils.py
│   └── genotype_utils.py
├── notebooks/
│   ├── README.md
│   ├── create-query-sss.ipynb
│   ├── find-largest.ipynb
│   ├── issue-11.ipynb
│   ├── issue-12.ipynb
│   ├── issue-21.ipynb
│   ├── issue-27.ipynb
│   ├── issue-30.ipynb
│   ├── issue-33.ipynb
│   ├── issue-36.ipynb
│   ├── issue-7.ipynb
│   └── random-search.ipynb
├── setup.py
└── tests/
    └── api_test.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/CODE-OF-CONDUCT.md
================================================
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
 advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
 address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
 professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at dongxuanyi888@gmail.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


================================================
FILE: .github/ISSUE_TEMPLATE/bug-report.md
================================================
---
name: Bug Report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Please provide a small script to reproduce the behavior:
```
codes to reproduce the bug
```
Please let me know your OS, Python version, PyTorch version.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.


================================================
FILE: .github/ISSUE_TEMPLATE/question.md
================================================
---
name: Questions about NATS-Bench
about: Ask questions about or discuss on NATS-Bench
title: ''
labels: ''
assignees: ''

---

**Describe the Question**
A clear and concise description of the question.
- Is it about the topology search space in NATS-Bench?
- Is it about the size search space in NATS-Bench?
- Which figure or table are you referring to in the paper?


================================================
FILE: .github/workflows/ci.yml
================================================
name: Run Python Tests
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  build:
    strategy:
      matrix:
        os: [ubuntu-18.04, ubuntu-20.04, macos-latest]
        python-version: [3.6, 3.7, 3.8, 3.9]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Lint with Black
        run: |
          cd ..
          if [ "$RUNNER_OS" == "Windows" ]; then
            python.exe -m pip install black
            python.exe -m black NATS-Bench/nats_bench -l 88 --check --diff
            python.exe -m black NATS-Bench/tests -l 88 --check --diff
          else
            python -m pip install black
            python --version
            python -m black --version
            echo $PWD
            ls
            python -m black NATS-Bench/nats_bench -l 88 --check --diff --verbose
            python -m black NATS-Bench/tests -l 88 --check --diff --verbose
          fi
        shell: bash

      - name: Install nats_bench from source
        run: |
          pip install .

      - name: Run tests with pytest
        run: |
          export FAKE_TORCH_HOME="fake_torch_dir"
          python -m pip install pytest
          python -m pytest . --durations=0

      - name: Install nats_bench from pip with tests
        run: |
          pip uninstall -y nats_bench
          python -m pip install nats_bench
          export FAKE_TORCH_HOME="fake_torch_dir"
          python -m pip install pytest
          python -m pytest . --durations=0


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# Pycharm project
.idea
snapshots
*.pytorch
*.tar.bz
data
.*.swp
*.sh
main_main.py
dist
build
*.egg-info

.DS_Store


================================================
FILE: LICENSE.md
================================================
MIT License

Copyright (c) since 2020 Xuanyi Dong (GitHub: https://github.com/D-X-Y)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# [NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size](https://arxiv.org/abs/2009.00437)

Xuanyi Dong, Lu Liu, Katarzyna Musial, Bogdan Gabrys

in IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI), 2021

**Abstract**: Neural architecture search (NAS) has attracted a lot of attention and has been illustrated to bring tangible benefits in a large number of applications in the past few years. Network topology and network size have been regarded as two of the most important aspects for the performance of deep learning models and the community has spawned lots of searching algorithms for both of those aspects of the neural architectures. However, the performance gain from these searching algorithms is achieved under different search spaces and training setups. This makes the overall performance of the algorithms incomparable and the improvement from a sub-module of the searching model unclear.
In this paper, we propose NATS-Bench, a unified benchmark on searching for both topology and size, for (almost) any up-to-date NAS algorithm.
NATS-Bench includes the search space of 15,625 neural cell candidates for architecture topology and 32,768 for architecture size on three datasets.
We analyze the validity of our benchmark in terms of various criteria and performance comparison of all candidates in the search space.
We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided.
This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.

**You can use `pip install nats_bench` to install the library of NATS-Bench
or install from source by `pip install .`.**

If you are seeking how to re-create NATS-Bench from scratch or reproduce benchmarked results, please see use [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) and see these [instructions](https://github.com/D-X-Y/NATS-Bench#how-to-re-create-nats-bench-from-scratch).

If you have questions, please ask at [here](https://github.com/D-X-Y/NATS-Bench/issues) or [email me](mailto:dongxuanyi888@gmail.com) :)

This figure is the main difference between `NATS-Bench`, `NAS-Bench-101`, and `NAS-Bench-201`. The `topology search space` (`$\mathcal{S}_t$`) in `NATS-Bench` is the same as `NAS-Bench-201`, while we upgrade with results of more runs for the architecture candidates, and the benchmarked NAS algorithms have better hyperparameters.
<p align="center">
<img src="https://xuanyidong.com/resources/images/NATS-compare.png" width="700"/>
</p>


## Preparation and Download

**Step-1: download raw vision datasets.** (you can skip this one if you do not use weight-sharing NAS or re-create NATS-Bench).

In NATS-Bench, we (create and) use three image datasets -- CIFAR-10, CIFAR-100, and ImageNet16-120.
For more details, please see Sec-3.2 in [the NATS-Bench paper](https://arxiv.org/pdf/2009.00437.pdf). To download these three datasets, please find them at [Google Drive](https://drive.google.com/drive/folders/1T3UIyZXUhMmIuJLOBMIYKAsJknAtrrO4?usp=sharing).
To create the `ImageNet16-120` PyTorch dataset, please call [AutoDL-Projects/lib/datasets/ImageNet16](https://github.com/D-X-Y/AutoDL-Projects/blob/main/xautodl/datasets/get_dataset_with_transform.py#L222-L225), by using:
```
train_data = ImageNet16(root, True , train_transform, 120)
test_data  = ImageNet16(root, False, test_transform , 120)
```

**Step-2: download benchmark files of NATS-Bench.**

The **latest** benchmark file of NATS-Bench can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zjB6wMANiKwB2A1yil2hQ8H_qyeSe2yt?usp=sharing).
After download `NATS-[tss/sss]-[version]-[md5sum]-simple.tar`, please uncompress it by using `tar xvf [file_name]`.
We highly recommend to put the downloaded benchmark file (`NATS-sss-v1_0-50262.pickle.pbz2` / `NATS-tss-v1_0-3ffb9.pickle.pbz2`) or uncompressed archive (`NATS-sss-v1_0-50262-simple` / `NATS-tss-v1_0-3ffb9-simple`) into `$TORCH_HOME`.
In this way, our api will automatically find the path for these benchmark files, which are convenient for the users. Otherwise, you need to indicate the file when creating the benchmark instance manually.

The history of benchmark files is as follows, `tss` indicates the topology search space and `sss` indicates the size search space.
The benchmark file is used when creating the NATS-Bench instance with `fast_mode=False`.
The archive is used when `fast_mode=True`, where `archive` is a directory containing 15,625 files for tss or contains 32,768 files for sss.
Each file contains all the information for a specific architecture candidate.
The `full archive` is similar to `archive`, while each file in `full archive` contains **the trained weights**.
Since the full archive is too large, we use `split -b 30G file_name file_name` to split it into multiple 30G chunks.
To merge the chunks into the original full archive, you can use `cat file_name* > file_name`.

|   Date     |  benchmark file (tss) | (unpacked benchmark file) archive (tss) | full archive (tss) |       benchmark file (sss)      | (unpacked benchmark file) archive (sss)        | full archive (sss) |
|:-----------|:---------------------:|:-------------:|:------------------:|:-------------------------------:|:--------------------------:|:------------------:|
| 2020.08.31 | [NATS-tss-v1_0-3ffb9.pickle.pbz2](https://drive.google.com/file/d/1vzyK0UVH2D3fTpa1_dSWnp1gvGpAxRul/view?usp=sharing) | [NATS-tss-v1_0-3ffb9-simple.tar](https://drive.google.com/file/d/17_saCsj_krKjlCBLOJEpNtzPXArMCqxU/view?usp=sharing) | [NATS-tss-v1_0-3ffb9-full](https://drive.google.com/drive/folders/17S2Xg_rVkUul4KuJdq0WaWoUuDbo8ZKB?usp=sharing) | [NATS-sss-v1_0-50262.pickle.pbz2](https://drive.google.com/file/d/1IabIvzWeDdDAWICBzFtTCMXxYWPIOIOX/view?usp=sharing) | [NATS-sss-v1_0-50262-simple.tar](https://drive.google.com/file/d/1scOMTUwcQhAMa_IMedp9lTzwmgqHLGgA/view?usp=sharing) | [NATS-sss-v1_0-50262-full](https://drive.google.com/drive/folders/1xutPQJ4bHoUV1EMArsPD0c1bUqvtMuYY?usp=sharing) |
| 2021.04.22 (Baidu-Pan) | [NATS-tss-v1_0-3ffb9.pickle.pbz2](https://pan.baidu.com/s/10z20F5s2RRPzGwRO40fLTw) (code: 8duj) | [NATS-tss-v1_0-3ffb9-simple.tar](https://pan.baidu.com/s/1vOnrHLxCB4y8cxUDrHUYAg) (code: tu1e) | [NATS-tss-v1_0-3ffb9-full](https://pan.baidu.com/s/1qbPNlI8Y1I29qMdxTo_ycA) (code:ssub) | [NATS-sss-v1_0-50262.pickle.pbz2](https://pan.baidu.com/s/1M1UaXr6y1D_RqEYg95YJcA) (code: za2h) | [NATS-sss-v1_0-50262-simple.tar](https://pan.baidu.com/s/1ek-b89Pw2qdm9MP6KKkErA) (code: e4t9) | [NATS-sss-v1_0-50262-full](https://pan.baidu.com/s/1bIruQd9pPeyArej_wttg_A) (code: htif) |

These benchmark files (without pretrained weights) can also be downloaded from [Dropbox](https://www.dropbox.com/sh/ceeo70u1buow681/AAC2M-SbKOxiIqpB0UCgXNxja?dl=0), [OneDrive](https://1drv.ms/u/s!Aqkc27lrowWDf6SvuIkSXx0UQaI?e=nfvM5r) or [Baidu-Pan (extract code: h6pm)](https://pan.baidu.com/s/144VC2BDm6iXbAVzMUpqO7A).

For the full checkpoints in `NATS-*ss-*-full`, we split the file into multiple parts (`NATS-*ss-*-full.tara*`) since they are too large to upload.
Each file is about `30GB`. For Baidu Pan, since they restrict the maximum size of each file, we further split `NATS-*ss-*-full.tara*` into `NATS-*ss-*-full.tara*-aa` and `NATS-*ss-*-full.tara*-ab`.
All splits are created by the command `split`.

**Note:** if you encounter the `quota exceed erros` when download from Google Drive, please try to (1) login your personal Google account, (2) right-click-copy the files to your personal Google Drive, and (3) download from your personal Google Drive.

## Usage

See more examples at [notebooks](notebooks).

#### 1, create the benchmark instance:
```
from nats_bench import create
# Create the API instance for the size search space in NATS
api = create(None, 'sss', fast_mode=True, verbose=True)

# Create the API instance for the topology search space in NATS
api = create(None, 'tss', fast_mode=True, verbose=True)
```

#### 2, query the performance:
```
# Show the architecture topology string of the 12-th architecture
# For the topology search space, the string is interpreted as
# arch = '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(
#         edge_node_0_to_node_1,
#         edge_node_0_to_node_2,
#         edge_node_1_to_node_2,
#         edge_node_0_to_node_3,
#         edge_node_1_to_node_3,
#         edge_node_2_to_node_3,
#         )
# For the size search space, the string is interpreted as
# arch = '{}:{}:{}:{}:{}'.format(out_channel_of_1st_conv_layer,
#                                out_channel_of_1st_cell_stage,
#                                out_channel_of_1st_residual_block,
#                                out_channel_of_2nd_cell_stage,
#                                out_channel_of_2nd_residual_block,
#                                )
architecture_str = api.arch(12)
print(architecture_str)

# Query the loss / accuracy / time for 1234-th candidate architecture on CIFAR-10
# info is a dict, where you can easily figure out the meaning by key
info = api.get_more_info(1234, 'cifar10')

# Query the flops, params, latency. info is a dict.
info = api.get_cost_info(12, 'cifar10')

# Simulate the training of the 1224-th candidate:
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')
```

#### 3, create the instance of an architecture candidate in `NATS-Bench`:
```
# Create the instance of th 12-th candidate for CIFAR-10.
# To keep NATS-Bench repo concise, we did not include any model-related codes here because they rely on PyTorch.
# The package of [models] is defined at https://github.com/D-X-Y/AutoDL-Projects
#   so that one need to first import this package.
import xautodl
from xautodl.models import get_cell_based_tiny_net
config = api.get_net_config(12, 'cifar10')
network = get_cell_based_tiny_net(config)

# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.
params = api.get_net_param(12, 'cifar10', None)
network.load_state_dict(next(iter(params.values())))
```

#### 4, others:
```
# Clear the parameters of the 12-th candidate.
api.clear_params(12)

# Reload all information of the 12-th candidate.
api.reload(index=12)

```

Please see [`api_test.py`](https://github.com/D-X-Y/NATS-Bench/blob/main/tests/api_test.py) for more examples.
```
from nats_bench import api_test
api_test.test_nats_bench_tss('NATS-tss-v1_0-3ffb9-simple')
api_test.test_nats_bench_tss('NATS-sss-v1_0-50262-simple')
```



## How to Re-create NATS-Bench from Scratch

**You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to re-create NATS-Bench from scratch.**

### The Size Search Space

The following command will train all architecture candidate in the size search space with 90 epochs and use the random seed of `777`.
If you want to use a different number of training epochs, please replace `90` with it, such as `01` or `12`.
```
bash ./scripts/NATS-Bench/train-shapes.sh 00000-32767 90 777
```
The checkpoint of all candidates are located at `output/NATS-Bench-size` by default.

After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file.
```
python exps/NATS-Bench/sss-collect.py
```

### The Topology Search Space

The following command will train all architecture candidate in the topology search space with 200 epochs and use the random seed of `777`/`888`/`999`.
If you want to use a different number of training epochs, please replace `200` with it, such as `12`.
```
bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999'
```
The checkpoint of all candidates are located at `output/NATS-Bench-topology` by default.

After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file.
```
python exps/NATS-Bench/tss-collect.py
```


## To Reproduce 13 Baseline NAS Algorithms in NATS-Bench

**You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to run 13 baseline NAS methods.** Here are a brief introduction on how to run each algorithm ([NATS-algos](https://github.com/D-X-Y/AutoDL-Projects/tree/main/exps/NATS-algos)).

### Reproduce NAS methods on the topology search space

Please use the following commands to run different NAS methods on the topology search space:
```
Four multi-trial based methods:
python ./exps/NATS-algos/reinforce.py       --dataset cifar100 --search_space tss --learning_rate 0.01
python ./exps/NATS-algos/regularized_ea.py  --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss
python ./exps/NATS-algos/bohb.py            --dataset cifar100 --search_space tss --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3

DARTS (first order):
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v1
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1

DARTS (second order):
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v2
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2

GDAS:
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16

SETN:
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo setn
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn

Random Search with Weight Sharing:
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo random
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random

ENAS:
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
```

### Reproduce NAS methods on the size search space

Please use the following commands to run different NAS methods on the size search space:
```
Four multi-trial based methods:
python ./exps/NATS-algos/reinforce.py       --dataset cifar100 --search_space sss --learning_rate 0.01
python ./exps/NATS-algos/regularized_ea.py  --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space sss
python ./exps/NATS-algos/bohb.py            --dataset cifar100 --search_space sss --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3


Run Transformable Architecture Search (TAS), proposed in Network Pruning via Transformable Architecture Search, NeurIPS 2019

python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777


Run the channel search strategy in FBNet-V2 -- masking + Gumbel-Softmax :

python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777


Run the channel search strategy in TuNAS -- masking + sampling :

python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
```

### Final Discovered Architectures for Each Algorithm

The architecture index can be found by use `api.query_index_by_arch(architecture_string)`.

The final discovered architecture ID on CIFAR-10:
```
DARTS (first order):
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|

DARTS (second order):
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|

GDAS:
|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|nor_conv_3x3~2|
|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|
|avg_pool_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|nor_conv_3x3~0|nor_conv_1x1~1|nor_conv_1x1~2|
```

The final discovered architecture ID on CIFAR-100:
```
DARTS (V1):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|none~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|none~2|
|skip_connect~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|nor_conv_3x3~2|

DARTS (V2):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|skip_connect~2|
|skip_connect~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|none~2|
|skip_connect~0|+|nor_conv_1x1~0|none~1|+|nor_conv_3x3~0|skip_connect~1|none~2|

GDAS:
|nor_conv_3x3~0|+|nor_conv_1x1~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|
|avg_pool_3x3~0|+|nor_conv_1x1~0|none~1|+|nor_conv_3x3~0|avg_pool_3x3~1|nor_conv_1x1~2|
|avg_pool_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_3x3~0|nor_conv_1x1~1|nor_conv_1x1~2|
```

The final discovered architecture ID on ImageNet-16-120:
```
DARTS (V1):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_1x1~2|

DARTS (V2):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|skip_connect~2|

GDAS:
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|
```

## Others

We use [`black`](https://github.com/psf/black) for Python code formatter.
Please use `black . -l 120`.

## Citation

If you find that NATS-Bench helps your research, please consider citing it:
```
@article{dong2021nats,
  title   = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
  author  = {Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
  doi     = {10.1109/TPAMI.2021.3054824},
  journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
  year    = {2021},
  note    = {\mbox{doi}:\url{10.1109/TPAMI.2021.3054824}}
}
@inproceedings{dong2020nasbench201,
  title     = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
  author    = {Dong, Xuanyi and Yang, Yi},
  booktitle = {International Conference on Learning Representations (ICLR)},
  url       = {https://openreview.net/forum?id=HJxyZkBKDr},
  year      = {2020}
}
```


================================================
FILE: nats_bench/__init__.py
================================================
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
"""The official Application Programming Interface (API) for NATS-Bench."""

from typing import Text, Optional
from nats_bench.api_size import NATSsize
from nats_bench.api_topology import NATStopology
from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import pickle_save
from nats_bench.api_utils import ResultsCount


NATS_BENCH_API_VERSIONs = [
    "v1.0",  # [2020.08.31] initialize
    "v1.1",  # [2020.12.20] add unit tests
    "v1.2",  # [2021.03.17] black re-formulate
    "v1.3",  # [2021.04.08] fix find_best issue for fast_mode=True
    "v1.4",  # [2021.04.30] add topology_str2structure
    "v1.5",  # [2021.12.09] make simulate_train_eval more robust
    "v1.6",  # [2022.01.19] fix the inconsistent flop/params which is caused by a legacy (weight migration) issue
    "v1.7",  # [2022.03.25] relax enforce_all kwargs and add a test
    "v1.8",  # [2022.10.06] fix bugs at issues/44
]
NATS_BENCH_SSS_NAMEs = ("sss", "size")
NATS_BENCH_TSS_NAMEs = ("tss", "topology")


def version():
    return NATS_BENCH_API_VERSIONs[-1]


def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
    """Create the instead for NATS API.

    Args:
      file_path_or_dict: None or a file path or a directory path.
      search_space: This is a string indicates the search space in NATS-Bench.
      fast_mode: If True, we will not load all the data at initialization,
        instead, the data for each candidate architecture will be loaded when
        quering it; If False, we will load all the data during initialization.
      verbose: This is a flag to indicate whether log additional information.

    Raises:
      ValueError: If not find the matched serach space description.

    Returns:
      The created NATS-Bench API.
    """
    if search_space in NATS_BENCH_TSS_NAMEs:
        return NATStopology(file_path_or_dict, fast_mode, verbose)
    elif search_space in NATS_BENCH_SSS_NAMEs:
        return NATSsize(file_path_or_dict, fast_mode, verbose)
    else:
        raise ValueError("invalid search space : {:}".format(search_space))


def search_space_info(main_tag: Text, aux_tag: Optional[Text]):
    """Obtain the search space information."""
    nats_sss = dict(candidates=[8, 16, 24, 32, 40, 48, 56, 64], num_layers=5)
    nats_tss = dict(
        op_names=[
            "none",
            "skip_connect",
            "nor_conv_1x1",
            "nor_conv_3x3",
            "avg_pool_3x3",
        ],
        num_nodes=4,
    )
    if main_tag == "nats-bench":
        if aux_tag in NATS_BENCH_SSS_NAMEs:
            return nats_sss
        elif aux_tag in NATS_BENCH_TSS_NAMEs:
            return nats_tss
        else:
            raise ValueError("Unknown auxiliary tag: {:}".format(aux_tag))
    elif main_tag == "nas-bench-201":
        if aux_tag is not None:
            raise ValueError("For NAS-Bench-201, the auxiliary tag should be None.")
        return nats_tss
    else:
        raise ValueError("Unknown main tag: {:}".format(main_tag))


================================================
FILE: nats_bench/api_size.py
================================================
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# The history of benchmark files are as follows,                             #
# where the format is (the name is NATS-sss-[version]-[md5].pickle.pbz2)     #
# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2                               #
##############################################################################
# pylint: disable=line-too-long
"""The API for size search space in NATS-Bench."""
import collections
import copy
import os
import random
from typing import Dict, Optional, Text, Union, Any

from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import get_torch_home
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import time_string


ALL_BASE_NAMES = ["NATS-sss-v1_0-50262"]


def print_information(information, extra_info=None, show=False):
    """print out the information of a given ArchResults."""
    dataset_names = information.get_dataset_names()
    strings = [
        information.arch_str,
        "datasets : {:}, extra-info : {:}".format(dataset_names, extra_info),
    ]

    def metric2str(loss, acc):
        return "loss = {:.3f} & top1 = {:.2f}%".format(loss, acc)

    for dataset in dataset_names:
        metric = information.get_compute_costs(dataset)
        flop, param, latency = metric["flops"], metric["params"], metric["latency"]
        str1 = "{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.".format(
            dataset,
            flop,
            param,
            "{:.2f}".format(latency * 1000)
            if latency is not None and latency > 0
            else None,
        )
        train_info = information.get_metrics(dataset, "train")
        if dataset == "cifar10-valid":
            valid_info = information.get_metrics(dataset, "x-valid")
            test__info = information.get_metrics(dataset, "ori-test")
            str2 = "{:14s} train : [{:}], valid : [{:}], test : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
                metric2str(valid_info["loss"], valid_info["accuracy"]),
                metric2str(test__info["loss"], test__info["accuracy"]),
            )
        elif dataset == "cifar10":
            test__info = information.get_metrics(dataset, "ori-test")
            str2 = "{:14s} train : [{:}], test  : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
                metric2str(test__info["loss"], test__info["accuracy"]),
            )
        else:
            valid_info = information.get_metrics(dataset, "x-valid")
            test__info = information.get_metrics(dataset, "x-test")
            str2 = "{:14s} train : [{:}], valid : [{:}], test : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
                metric2str(valid_info["loss"], valid_info["accuracy"]),
                metric2str(test__info["loss"], test__info["accuracy"]),
            )
        strings += [str1, str2]
    if show:
        print("\n".join(strings))
    return strings


class NATSsize(NASBenchMetaAPI):
    """This is the class for the API of size search space in NATS-Bench."""

    def __init__(
        self,
        file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
        fast_mode: bool = False,
        verbose: bool = True,
    ):
        """The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
        self._all_base_names = ALL_BASE_NAMES
        self.filename = None
        self._search_space_name = "size"
        self._fast_mode = fast_mode
        self._archive_dir = None
        self._full_train_epochs = 90
        self.reset_time()
        if file_path_or_dict is None:
            if self._fast_mode:
                self._archive_dir = os.path.join(
                    get_torch_home(), "{:}-simple".format(ALL_BASE_NAMES[-1])
                )
            else:
                file_path_or_dict = os.path.join(
                    get_torch_home(), "{:}.{:}".format(ALL_BASE_NAMES[-1], PICKLE_EXT)
                )
            print(
                "{:} Try to use the default NATS-Bench (size) path from "
                "fast_mode={:} and path={:}.".format(
                    time_string(), self._fast_mode, file_path_or_dict
                )
            )
        if isinstance(file_path_or_dict, str):
            file_path_or_dict = str(file_path_or_dict)
            if verbose:
                print(
                    "{:} Try to create the NATS-Bench (size) api "
                    "from {:} with fast_mode={:}".format(
                        time_string(), file_path_or_dict, fast_mode
                    )
                )
            if not nats_is_file(file_path_or_dict) and not nats_is_dir(
                file_path_or_dict
            ):
                raise ValueError(
                    "{:} is neither a file or a dir.".format(file_path_or_dict)
                )
            self.filename = os.path.basename(file_path_or_dict)
            if fast_mode:
                if nats_is_file(file_path_or_dict):
                    raise ValueError(
                        "fast_mode={:} must feed the path for directory "
                        ": {:}".format(fast_mode, file_path_or_dict)
                    )
                else:
                    self._archive_dir = file_path_or_dict
            else:
                if nats_is_dir(file_path_or_dict):
                    raise ValueError(
                        "fast_mode={:} must feed the path for file "
                        ": {:}".format(fast_mode, file_path_or_dict)
                    )
                else:
                    file_path_or_dict = pickle_load(file_path_or_dict)
        elif isinstance(file_path_or_dict, dict):
            file_path_or_dict = copy.deepcopy(file_path_or_dict)
        self.verbose = verbose
        if isinstance(file_path_or_dict, dict):
            keys = ("meta_archs", "arch2infos", "evaluated_indexes")
            for key in keys:
                if key not in file_path_or_dict:
                    raise ValueError("Can not find key[{:}] in the dict".format(key))
            self.meta_archs = copy.deepcopy(file_path_or_dict["meta_archs"])
            # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
            # where the key is #epochs and the value is ArchResults
            self.arch2infos_dict = collections.OrderedDict()
            self._avaliable_hps = set()
            for xkey in sorted(list(file_path_or_dict["arch2infos"].keys())):
                all_infos = file_path_or_dict["arch2infos"][xkey]
                hp2archres = collections.OrderedDict()
                for hp_key, results in all_infos.items():
                    hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
                    self._avaliable_hps.add(
                        hp_key
                    )  # save the avaliable hyper-parameter
                self.arch2infos_dict[xkey] = hp2archres
            self.evaluated_indexes = set(file_path_or_dict["evaluated_indexes"])
        elif self.archive_dir is not None:
            benchmark_meta = pickle_load(
                "{:}/meta.{:}".format(self.archive_dir, PICKLE_EXT)
            )
            self.meta_archs = copy.deepcopy(benchmark_meta["meta_archs"])
            self.arch2infos_dict = collections.OrderedDict()
            self._avaliable_hps = set()
            self.evaluated_indexes = set()
        else:
            raise ValueError(
                "file_path_or_dict [{:}] must be a dict or archive_dir "
                "must be set".format(type(file_path_or_dict))
            )
        self.archstr2index = {}
        for idx, arch in enumerate(self.meta_archs):
            if arch in self.archstr2index:
                raise ValueError(
                    "This [{:}]-th arch {:} already in the "
                    "dict ({:}).".format(idx, arch, self.archstr2index[arch])
                )
            self.archstr2index[arch] = idx
        if self.verbose:
            print(
                "{:} Create NATS-Bench (size) done with {:}/{:} architectures "
                "avaliable.".format(
                    time_string(), len(self.evaluated_indexes), len(self.meta_archs)
                )
            )

    @property
    def is_size(self):
        return True

    @property
    def is_topology(self):
        return False

    @property
    def full_epochs_in_paper(self):
        return 90

    def query_info_str_by_arch(self, arch, hp: Text = "12"):
        """Query the information of a specific architecture.

        Args:
          arch: it can be an architecture index or an architecture string.

          hp: the hyperparamete indicator, could be 01, 12, or 90. The difference
              between these three configurations are the number of training epochs.

        Returns:
          ArchResults instance
        """
        if self.verbose:
            print(
                "{:} Call query_info_str_by_arch with arch={:}"
                "and hp={:}".format(time_string(), arch, hp)
            )
        return self._query_info_str_by_arch(arch, hp, print_information)

    def get_more_info(
        self, index, dataset, iepoch=None, hp: Text = "12", is_random: bool = True
    ):
        """Return the metric for the `index`-th architecture.

        Args:
          index: the architecture index.
          dataset:
              'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set
              'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set
              'cifar100'       : using the proposed train set of CIFAR-100 as the training set
              'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
          iepoch: the index of training epochs from 0 to 11/199.
              When iepoch=None, it will return the metric for the last training epoch
              When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
          hp: indicates different hyper-parameters for training
              When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs
              When hp=12, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 12 epochs
              When hp=90, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 90 epochs
          is_random:
              When is_random=True, the performance of a random architecture will be returned
              When is_random=False, the performanceo of all trials will be averaged.

        Returns:
          a dict, where key is the metric name and value is its value.
        """
        if self.verbose:
            print(
                "{:} Call the get_more_info function with index={:}, dataset={:}, "
                "iepoch={:}, hp={:}, and is_random={:}.".format(
                    time_string(), index, dataset, iepoch, hp, is_random
                )
            )
        index = self.query_index_by_arch(
            index
        )  # To avoid the input is a string or an instance of a arch object
        self._prepare_info(index)
        if index not in self.arch2infos_dict:
            raise ValueError("Did not find {:} from arch2infos_dict.".format(index))
        archresult = self.arch2infos_dict[index][str(hp)]
        # if randomly select one trial, select the seed at first
        if isinstance(is_random, bool) and is_random:
            seeds = archresult.get_dataset_seeds(dataset)
            is_random = random.choice(seeds)
        # collect the training information
        train_info = archresult.get_metrics(
            dataset, "train", iepoch=iepoch, is_random=is_random
        )
        total = train_info["iepoch"] + 1
        xinfo = {
            "train-loss": train_info["loss"],
            "train-accuracy": train_info["accuracy"],
            "train-per-time": train_info["all_time"] / total,
            "train-all-time": train_info["all_time"],
        }
        # collect the evaluation information
        if dataset == "cifar10-valid":
            valid_info = archresult.get_metrics(
                dataset, "x-valid", iepoch=iepoch, is_random=is_random
            )
            try:
                test_info = archresult.get_metrics(
                    dataset, "ori-test", iepoch=iepoch, is_random=is_random
                )
            except Exception as unused_e:  # pylint: disable=broad-except
                test_info = None
            valtest_info = None
            xinfo[
                "comment"
            ] = "In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.".format(
                hp
            )
        else:
            if dataset == "cifar10":
                xinfo[
                    "comment"
                ] = "In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.".format(
                    hp
                )
            try:  # collect results on the proposed test set
                if dataset == "cifar10":
                    test_info = archresult.get_metrics(
                        dataset, "ori-test", iepoch=iepoch, is_random=is_random
                    )
                else:
                    test_info = archresult.get_metrics(
                        dataset, "x-test", iepoch=iepoch, is_random=is_random
                    )
            except Exception as unused_e:  # pylint: disable=broad-except
                test_info = None
            try:  # collect results on the proposed validation set
                valid_info = archresult.get_metrics(
                    dataset, "x-valid", iepoch=iepoch, is_random=is_random
                )
            except Exception as unused_e:  # pylint: disable=broad-except
                valid_info = None
            try:
                if dataset != "cifar10":
                    valtest_info = archresult.get_metrics(
                        dataset, "ori-test", iepoch=iepoch, is_random=is_random
                    )
                else:
                    valtest_info = None
            except Exception as unused_e:  # pylint: disable=broad-except
                valtest_info = None
        if valid_info is not None:
            xinfo["valid-loss"] = valid_info["loss"]
            xinfo["valid-accuracy"] = valid_info["accuracy"]
            xinfo["valid-per-time"] = valid_info["all_time"] / total
            xinfo["valid-all-time"] = valid_info["all_time"]
        if test_info is not None:
            xinfo["test-loss"] = test_info["loss"]
            xinfo["test-accuracy"] = test_info["accuracy"]
            xinfo["test-per-time"] = test_info["all_time"] / total
            xinfo["test-all-time"] = test_info["all_time"]
        if valtest_info is not None:
            xinfo["valtest-loss"] = valtest_info["loss"]
            xinfo["valtest-accuracy"] = valtest_info["accuracy"]
            xinfo["valtest-per-time"] = valtest_info["all_time"] / total
            xinfo["valtest-all-time"] = valtest_info["all_time"]
        return xinfo

    def show(self, index: int = -1) -> None:
        """Print the information of a specific (or all) architecture(s)."""
        self._show(index, print_information)


================================================
FILE: nats_bench/api_topology.py
================================================
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# The history of benchmark files are as follows,                             #
# where the format is (the name is NATS-tss-[version]-[md5].pickle.pbz2)     #
# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2                               #
##############################################################################
# pylint: disable=line-too-long
"""The API for topology search space in NATS-Bench."""
import collections
import copy
import os
import random
from typing import Any, Dict, List, Optional, Text, Union

from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import get_torch_home
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import time_string

from nats_bench.genotype_utils import topology_str2structure


ALL_BASE_NAMES = ["NATS-tss-v1_0-3ffb9"]


def print_information(information, extra_info=None, show=False):
    """print out the information of a given ArchResults."""
    dataset_names = information.get_dataset_names()
    strings = [
        information.arch_str,
        "datasets : {:}, extra-info : {:}".format(dataset_names, extra_info),
    ]

    def metric2str(loss, acc):
        return "loss = {:.3f} & top1 = {:.2f}%".format(loss, acc)

    for dataset in dataset_names:
        metric = information.get_compute_costs(dataset)
        flop, param, latency = metric["flops"], metric["params"], metric["latency"]
        str1 = "{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.".format(
            dataset,
            flop,
            param,
            "{:.2f}".format(latency * 1000)
            if latency is not None and latency > 0
            else None,
        )
        train_info = information.get_metrics(dataset, "train")
        if dataset == "cifar10-valid":
            valid_info = information.get_metrics(dataset, "x-valid")
            str2 = "{:14s} train : [{:}], valid : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
                metric2str(valid_info["loss"], valid_info["accuracy"]),
            )
        elif dataset == "cifar10":
            test__info = information.get_metrics(dataset, "ori-test")
            str2 = "{:14s} train : [{:}], test  : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
                metric2str(test__info["loss"], test__info["accuracy"]),
            )
        else:
            valid_info = information.get_metrics(dataset, "x-valid")
            test__info = information.get_metrics(dataset, "x-test")
            str2 = "{:14s} train : [{:}], valid : [{:}], test : [{:}]".format(
                dataset,
                metric2str(train_info["loss"], train_info["accuracy"]),
                metric2str(valid_info["loss"], valid_info["accuracy"]),
                metric2str(test__info["loss"], test__info["accuracy"]),
            )
        strings += [str1, str2]
    if show:
        print("\n".join(strings))
    return strings


class NATStopology(NASBenchMetaAPI):
    """This is the class for the API of topology search space in NATS-Bench."""

    def __init__(
        self,
        file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
        fast_mode: bool = False,
        verbose: bool = True,
    ):
        """The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
        self._all_base_names = ALL_BASE_NAMES
        self.filename = None
        self._search_space_name = "topology"
        self._fast_mode = fast_mode
        self._archive_dir = None
        self._full_train_epochs = 200
        self.reset_time()
        if file_path_or_dict is None:
            if self._fast_mode:
                self._archive_dir = os.path.join(
                    get_torch_home(), "{:}-simple".format(ALL_BASE_NAMES[-1])
                )
            else:
                file_path_or_dict = os.path.join(
                    get_torch_home(), "{:}.{:}".format(ALL_BASE_NAMES[-1], PICKLE_EXT)
                )
            print(
                "{:} Try to use the default NATS-Bench (topology) path from "
                "fast_mode={:} and path={:}.".format(
                    time_string(), self._fast_mode, file_path_or_dict
                )
            )
        if isinstance(file_path_or_dict, str):
            file_path_or_dict = str(file_path_or_dict)
            if verbose:
                print(
                    "{:} Try to create the NATS-Bench (topology) api "
                    "from {:} with fast_mode={:}".format(
                        time_string(), file_path_or_dict, fast_mode
                    )
                )
            if not nats_is_file(file_path_or_dict) and not nats_is_dir(
                file_path_or_dict
            ):
                raise ValueError(
                    "{:} is neither a file or a dir.".format(file_path_or_dict)
                )
            self.filename = os.path.basename(file_path_or_dict)
            if fast_mode:
                if nats_is_file(file_path_or_dict):
                    raise ValueError(
                        "fast_mode={:} must feed the path for directory "
                        ": {:}".format(fast_mode, file_path_or_dict)
                    )
                else:
                    self._archive_dir = file_path_or_dict
            else:
                if nats_is_dir(file_path_or_dict):
                    raise ValueError(
                        "fast_mode={:} must feed the path for file "
                        ": {:}".format(fast_mode, file_path_or_dict)
                    )
                else:
                    file_path_or_dict = pickle_load(file_path_or_dict)
        elif isinstance(file_path_or_dict, dict):
            file_path_or_dict = copy.deepcopy(file_path_or_dict)
        self.verbose = verbose
        if isinstance(file_path_or_dict, dict):
            keys = ("meta_archs", "arch2infos", "evaluated_indexes")
            for key in keys:
                if key not in file_path_or_dict:
                    raise ValueError("Can not find key[{:}] in the dict".format(key))
            self.meta_archs = copy.deepcopy(file_path_or_dict["meta_archs"])
            # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
            # where the key is #epochs and the value is ArchResults
            self.arch2infos_dict = collections.OrderedDict()
            self._avaliable_hps = set()
            for xkey in sorted(list(file_path_or_dict["arch2infos"].keys())):
                all_infos = file_path_or_dict["arch2infos"][xkey]
                hp2archres = collections.OrderedDict()
                for hp_key, results in all_infos.items():
                    hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
                    self._avaliable_hps.add(
                        hp_key
                    )  # save the avaliable hyper-parameter
                self.arch2infos_dict[xkey] = hp2archres
            self.evaluated_indexes = set(file_path_or_dict["evaluated_indexes"])
        elif self.archive_dir is not None:
            benchmark_meta = pickle_load(
                "{:}/meta.{:}".format(self.archive_dir, PICKLE_EXT)
            )
            self.meta_archs = copy.deepcopy(benchmark_meta["meta_archs"])
            self.arch2infos_dict = collections.OrderedDict()
            self._avaliable_hps = set()
            self.evaluated_indexes = set()
        else:
            raise ValueError(
                "file_path_or_dict [{:}] must be a dict or archive_dir "
                "must be set".format(type(file_path_or_dict))
            )
        self.archstr2index = {}
        for idx, arch in enumerate(self.meta_archs):
            if arch in self.archstr2index:
                raise ValueError(
                    "This [{:}]-th arch {:} already in the "
                    "dict ({:}).".format(idx, arch, self.archstr2index[arch])
                )
            self.archstr2index[arch] = idx
        if self.verbose:
            print(
                "{:} Create NATS-Bench (topology) done with {:}/{:} architectures "
                "avaliable.".format(
                    time_string(), len(self.evaluated_indexes), len(self.meta_archs)
                )
            )

    @property
    def is_size(self):
        return False

    @property
    def is_topology(self):
        return True

    @property
    def full_epochs_in_paper(self):
        return 200

    def get_unique_str(self, arch):
        """Return a unique string for the isomorphism architectures.

        Args:
          arch: it can be an architecture index or an architecture string.

        Returns:
          the unique string.
        """
        index = self.query_index_by_arch(
            arch
        )  # To avoid the arch is a string or an instance of a arch object
        arch_str = self.meta_archs[index]
        structure = topology_str2structure(arch_str)
        return structure.to_unique_str(consider_zero=True)

    def query_info_str_by_arch(self, arch, hp: Text = "12"):
        """Query the information of a specific architecture.

        Args:
          arch: it can be an architecture index or an architecture string.

          hp: the hyperparamete indicator, could be 12 or 200. The difference
              between these three configurations are the number of training epochs.

        Returns:
          ArchResults instance
        """
        if self.verbose:
            print(
                "{:} Call query_info_str_by_arch with arch={:}"
                "and hp={:}".format(time_string(), arch, hp)
            )
        return self._query_info_str_by_arch(arch, hp, print_information)

    def get_more_info(
        self, index, dataset, iepoch=None, hp: Text = "12", is_random: bool = True
    ):
        """Return the metric for the `index`-th architecture."""
        if self.verbose:
            print(
                "{:} Call the get_more_info function with index={:}, dataset={:}, "
                "iepoch={:}, hp={:}, and is_random={:}.".format(
                    time_string(), index, dataset, iepoch, hp, is_random
                )
            )
        index = self.query_index_by_arch(
            index
        )  # To avoid the input is a string or an instance of a arch object
        self._prepare_info(index)
        if index not in self.arch2infos_dict:
            raise ValueError("Did not find {:} from arch2infos_dict.".format(index))
        archresult = self.arch2infos_dict[index][str(hp)]
        # if randomly select one trial, select the seed at first
        if isinstance(is_random, bool) and is_random:
            seeds = archresult.get_dataset_seeds(dataset)
            is_random = random.choice(seeds)
        # collect the training information
        train_info = archresult.get_metrics(
            dataset, "train", iepoch=iepoch, is_random=is_random
        )
        total = train_info["iepoch"] + 1
        xinfo = {
            "train-loss": train_info["loss"],
            "train-accuracy": train_info["accuracy"],
            "train-per-time": train_info["all_time"] / total
            if train_info["all_time"] is not None
            else None,
            "train-all-time": train_info["all_time"],
        }
        # collect the evaluation information
        if dataset == "cifar10-valid":
            valid_info = archresult.get_metrics(
                dataset, "x-valid", iepoch=iepoch, is_random=is_random
            )
            try:
                test_info = archresult.get_metrics(
                    dataset, "ori-test", iepoch=iepoch, is_random=is_random
                )
            except Exception as unused_e:  # pylint: disable=broad-except
                test_info = None
            valtest_info = None
            xinfo[
                "comment"
            ] = "In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.".format(
                hp
            )
        else:
            if dataset == "cifar10":
                xinfo[
                    "comment"
                ] = "In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.".format(
                    hp
                )
            try:  # collect results on the proposed test set
                if dataset == "cifar10":
                    test_info = archresult.get_metrics(
                        dataset, "ori-test", iepoch=iepoch, is_random=is_random
                    )
                else:
                    test_info = archresult.get_metrics(
                        dataset, "x-test", iepoch=iepoch, is_random=is_random
                    )
            except Exception as unused_e:  # pylint: disable=broad-except
                test_info = None
            try:  # collect results on the proposed validation set
                valid_info = archresult.get_metrics(
                    dataset, "x-valid", iepoch=iepoch, is_random=is_random
                )
            except Exception as unused_e:  # pylint: disable=broad-except
                valid_info = None
            try:
                if dataset != "cifar10":
                    valtest_info = archresult.get_metrics(
                        dataset, "ori-test", iepoch=iepoch, is_random=is_random
                    )
                else:
                    valtest_info = None
            except Exception as unused_e:  # pylint: disable=broad-except
                valtest_info = None
        if valid_info is not None:
            xinfo["valid-loss"] = valid_info["loss"]
            xinfo["valid-accuracy"] = valid_info["accuracy"]
            xinfo["valid-per-time"] = (
                valid_info["all_time"] / total
                if valid_info["all_time"] is not None
                else None
            )
            xinfo["valid-all-time"] = valid_info["all_time"]
        if test_info is not None:
            xinfo["test-loss"] = test_info["loss"]
            xinfo["test-accuracy"] = test_info["accuracy"]
            xinfo["test-per-time"] = (
                test_info["all_time"] / total
                if test_info["all_time"] is not None
                else None
            )
            xinfo["test-all-time"] = test_info["all_time"]
        if valtest_info is not None:
            xinfo["valtest-loss"] = valtest_info["loss"]
            xinfo["valtest-accuracy"] = valtest_info["accuracy"]
            xinfo["valtest-per-time"] = (
                valtest_info["all_time"] / total
                if valtest_info["all_time"] is not None
                else None
            )
            xinfo["valtest-all-time"] = valtest_info["all_time"]
        return xinfo

    def show(self, index: int = -1) -> None:
        """This function will print the information of a specific (or all) architecture(s)."""
        self._show(index, print_information)

    @staticmethod
    def str2lists(arch_str: Text) -> List[Any]:
        """Shows how to read the string-based architecture encoding.

        Args:
          arch_str: the input is a string indicates the architecture topology, such as
                        |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
        Returns:
          a list of tuple, contains multiple (op, input_node_index) pairs.

        [USAGE]
        It is the same as the `str2structure` func in AutoDL-Projects:
          `github.com/D-X-Y/AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
        ```
          arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
          print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
          for i, node in enumerate(arch):
            print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
        ```
        """
        node_strs = arch_str.split("+")
        genotypes = []
        for unused_i, node_str in enumerate(node_strs):
            inputs = list(
                filter(lambda x: x != "", node_str.split("|"))
            )  # pylint: disable=g-explicit-bool-comparison
            for xinput in inputs:
                assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
                    xinput
                )
            inputs = (xi.split("~") for xi in inputs)
            input_infos = tuple((op, int(idx)) for (op, idx) in inputs)
            genotypes.append(input_infos)
        return genotypes

    @staticmethod
    def str2matrix(
        arch_str: Text,
        search_space: List[Text] = (
            "none",
            "skip_connect",
            "nor_conv_1x1",
            "nor_conv_3x3",
            "avg_pool_3x3",
        ),
    ):
        """Convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.

        Args:
          arch_str: the input is a string indicates the architecture topology, such as
                        |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
          search_space: a list of operation string, the default list is the topology search space for NATS-BENCH.
            the default value should be be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_operations.py#L24

        Returns:
          the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology

        [USAGE]
          matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
          This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
             [ [0, 0, 0, 0],  # the first line represents the input (0-th) node
               [2, 0, 0, 0],  # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
               [0, 0, 0, 0],  # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
               [0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
          In the topology search space in NATS-BENCH, 0-th-op is 'none', 1-th-op is 'skip_connect',
             2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
        [NOTE]
          If a node has two input-edges from the same node, this function does not work. One edge will be overlapped.
        """
        import numpy as np

        node_strs = arch_str.split("+")
        num_nodes = len(node_strs) + 1
        matrix = np.zeros((num_nodes, num_nodes))
        for i, node_str in enumerate(node_strs):
            inputs = list(
                filter(lambda x: x != "", node_str.split("|"))
            )  # pylint: disable=g-explicit-bool-comparison
            for xinput in inputs:
                assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
                    xinput
                )
            for xi in inputs:
                op, idx = xi.split("~")
                if op not in search_space:
                    raise ValueError(
                        "this op ({:}) is not in {:}".format(op, search_space)
                    )
                op_idx, node_idx = search_space.index(op), int(idx)
                matrix[i + 1, node_idx] = op_idx
        return matrix


================================================
FILE: nats_bench/api_utils.py
================================================
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
"""In this file, we define NASBenchMetaAPI, ArchResults, and ResultsCount.

   NASBenchMetaAPI is the abstract class for benchmark APIs.
   We also define the class ArchResults, which contains all
   information of a single architecture trained by one kind of hyper-parameters
   on three datasets. We also define the class ResultsCount, which contains all
   information of a single trial for a single architecture.
"""
import abc
import bz2
import collections
import copy
import os
import pickle
import random
import time
from typing import Any, Dict, Optional, Text, Union
import warnings


_FILE_SYSTEM = "default"
PICKLE_EXT = "pickle.pbz2"


def mean(xlist):
    return sum(xlist) / len(xlist)


def time_string():
    iso_time_format = "%Y-%m-%d %X"
    string = "[{:}]".format(time.strftime(iso_time_format, time.gmtime(time.time())))
    return string


def reset_file_system(lib: Text = "default"):
    global _FILE_SYSTEM
    _FILE_SYSTEM = lib


def get_file_system():
    return _FILE_SYSTEM


def get_torch_home():
    if "TORCH_HOME" in os.environ:
        return os.environ["TORCH_HOME"]
    elif "HOME" in os.environ:
        return os.path.join(os.environ["HOME"], ".torch")
    else:
        raise ValueError(
            "Did not find HOME in os.environ. "
            "Please at least setup the path of HOME or TORCH_HOME "
            "in the environment."
        )


def nats_is_dir(file_path):
    if _FILE_SYSTEM == "default":
        return os.path.isdir(file_path)
    elif _FILE_SYSTEM == "google":
        import tensorflow as tf  # pylint: disable=g-import-not-at-top

        return tf.io.gfile.isdir(file_path)
    else:
        raise ValueError("Unknown file system lib: {:}".format(_FILE_SYSTEM))


def nats_is_file(file_path):
    if _FILE_SYSTEM == "default":
        return os.path.isfile(file_path)
    elif _FILE_SYSTEM == "google":
        import tensorflow as tf  # pylint: disable=g-import-not-at-top

        return tf.io.gfile.exists(file_path) and not tf.io.gfile.isdir(file_path)
    else:
        raise ValueError("Unknown file system lib: {:}".format(_FILE_SYSTEM))


def pickle_save(obj, file_path, ext=".pbz2", protocol=4):
    """Use pickle to save data (obj) into file_path.

    Args:
      obj: The object to be saved into a path.
      file_path: The target saving path.
      ext: The extension of file name.
      protocol: The pickle protocol. According to this documentation
        (https://docs.python.org/3/library/pickle.html#data-stream-format),
        the protocol version 4 was added in Python 3.4. It adds support for very
        large objects, pickling more kinds of objects, and some data format
        optimizations. It is the default protocol starting with Python 3.8.
    """
    # with open(file_path, 'wb') as cfile:
    if _FILE_SYSTEM == "default":
        with bz2.BZ2File(str(file_path) + ext, "wb") as cfile:
            pickle.dump(
                obj, cfile, protocol=protocol
            )  # pytype: disable=wrong-arg-types
    else:
        raise ValueError("Unknown file system lib: {:}".format(_FILE_SYSTEM))


def pickle_load(file_path, ext=".pbz2"):
    """Use pickle to load the file on different systems."""
    # return pickle.load(open(file_path, "rb"))
    if nats_is_file(str(file_path)):
        xfile_path = str(file_path)
    else:
        xfile_path = str(file_path) + ext
    if _FILE_SYSTEM == "default":
        with bz2.BZ2File(xfile_path, "rb") as cfile:
            return pickle.load(cfile)  # pytype: disable=wrong-arg-types
    elif _FILE_SYSTEM == "google":
        import tensorflow as tf  # pylint: disable=g-import-not-at-top

        file_content = tf.io.gfile.GFile(file_path, mode="rb").read()
        byte_content = bz2.decompress(file_content)
        return pickle.loads(byte_content)
    else:
        raise ValueError("Unknown file system lib: {:}".format(_FILE_SYSTEM))


def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
    """Re-map the metric_on_set to internal keys."""
    if verbose:
        print(
            "Call internal function _remap_dataset_set_names with dataset={:} "
            "and metric_on_set={:}".format(dataset, metric_on_set)
        )
    if dataset == "cifar10" and metric_on_set == "valid":
        dataset, metric_on_set = "cifar10-valid", "x-valid"
    elif dataset == "cifar10" and metric_on_set == "test":
        dataset, metric_on_set = "cifar10", "ori-test"
    elif dataset == "cifar10" and metric_on_set == "train":
        dataset, metric_on_set = "cifar10", "train"
    elif (
        dataset == "cifar100" or dataset == "ImageNet16-120"
    ) and metric_on_set == "valid":
        metric_on_set = "x-valid"
    elif (
        dataset == "cifar100" or dataset == "ImageNet16-120"
    ) and metric_on_set == "test":
        metric_on_set = "x-test"
    if verbose:
        print(
            "  return dataset={:} and metric_on_set={:}".format(dataset, metric_on_set)
        )
    return dataset, metric_on_set


class NASBenchMetaAPI(metaclass=abc.ABCMeta):
    """The abstract class for NATS Bench API."""

    @abc.abstractmethod
    def __init__(
        self,
        file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
        fast_mode: bool = False,
        verbose: bool = True,
    ):
        """The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
        # NOTE(xuanyidong): the following attributes must be initilaized in subclass
        self.meta_archs = None
        self.verbose = None
        self.evaluated_indexes = None
        self.arch2infos_dict = None
        self.filename = None
        self._fast_mode = None
        self._archive_dir = None
        self._avaliable_hps = None
        self.archstr2index = None

    def __getitem__(self, index: int):
        return copy.deepcopy(self.meta_archs[index])

    def arch(self, index: int):
        """Return the topology structure of the `index`-th architecture."""
        if self.verbose:
            print("Call the arch function with index={:}".format(index))
        if index < 0 or index >= len(self.meta_archs):
            raise ValueError(
                "invalid index : {:} vs. {:}.".format(index, len(self.meta_archs))
            )
        return copy.deepcopy(self.meta_archs[index])

    def __len__(self):
        return len(self.meta_archs)

    def __repr__(self):
        return (
            "{name}({num}/{total} architectures, fast_mode={fast_mode}, "
            "file={filename})".format(
                name=self.__class__.__name__,
                num=len(self.evaluated_indexes),
                total=len(self.meta_archs),
                fast_mode=self.fast_mode,
                filename=self.filename,
            )
        )

    @property
    def avaliable_hps(self):
        return list(copy.deepcopy(self._avaliable_hps))

    @property
    def used_time(self):
        return self._used_time

    @property
    def search_space_name(self):
        return self._search_space_name

    @property
    def fast_mode(self):
        return self._fast_mode

    @property
    def archive_dir(self):
        return self._archive_dir

    @property
    def full_train_epochs(self):
        return self._full_train_epochs

    def reset_archive_dir(self, archive_dir):
        self._archive_dir = archive_dir

    def reset_fast_mode(self, fast_mode):
        self._fast_mode = fast_mode

    def reset_time(self):
        self._used_time = 0

    @abc.abstractmethod
    def get_more_info(
        self, index, dataset, iepoch=None, hp: Text = "12", is_random: bool = True
    ):
        """Return the metric for the `index`-th architecture."""

    def simulate_train_eval(
        self, arch, dataset, iepoch=None, hp="12", account_time=True
    ):
        """This function is used to simulate training and evaluating an arch."""
        index = self.query_index_by_arch(arch)
        all_names = ("cifar10", "cifar100", "ImageNet16-120")
        if dataset not in all_names:
            raise ValueError(
                "Invalid dataset name : {:} vs {:}".format(dataset, all_names)
            )
        if dataset == "cifar10":
            info = self.get_more_info(
                index, "cifar10-valid", iepoch=iepoch, hp=hp, is_random=True
            )
        else:
            info = self.get_more_info(
                index, dataset, iepoch=iepoch, hp=hp, is_random=True
            )
        if "valid-accuracy" in info:
            valid_acc, time_cost = (
                info["valid-accuracy"],
                info["train-all-time"] + info["valid-per-time"],
            )
        else:
            valid_acc = info["valtest-accuracy"]
            temp_info = self.get_more_info(
                index, dataset, iepoch=None, hp=hp, is_random=True
            )
            time_cost = info["train-all-time"] + temp_info["valid-per-time"]
        latency = self.get_latency(index, dataset, hp=hp)
        if account_time:
            self._used_time += time_cost
        return valid_acc, latency, time_cost, self._used_time

    def random(self):
        """Return a random index of all architectures."""
        return random.randint(0, len(self.meta_archs) - 1)

    def reload(self, archive_root: Text = None, index: int = None):
        """Overwrite all information of the 'index'-th architecture in search space.

        Args:
          archive_root: If archive_root is None, it will try to load from the
            default path os.environ['TORCH_HOME'] / 'BASE_NAME'-full.
          index: If index is None, overwrite all ckps.
        """
        if self.verbose:
            print(
                "{:} Call clear_params with archive_root={:} and index={:}".format(
                    time_string(), archive_root, index
                )
            )
        if archive_root is None:
            archive_root = os.path.join(
                os.environ["TORCH_HOME"], "{:}-full".format(self._all_base_names[-1])
            )
            if not nats_is_dir(archive_root):
                warnings.warn(
                    "The input archive_root is None and the default "
                    "archive_root path ({:}) does not exist, try to use "
                    "self.archive_dir.".format(archive_root)
                )
                archive_root = self.archive_dir
        if archive_root is None or not nats_is_dir(archive_root):
            raise ValueError("Invalid archive_root : {:}".format(archive_root))
        if index is None:
            indexes = list(range(len(self)))
        else:
            indexes = [index]
        for idx in indexes:
            if not (
                0 <= idx < len(self.meta_archs)
            ):  # pylint: disable=superfluous-parens
                raise ValueError("invalid index of {:}".format(idx))
            xfile_path = os.path.join(
                archive_root, "{:06d}.{:}".format(idx, PICKLE_EXT)
            )
            if not nats_is_file(xfile_path):
                xfile_path = os.path.join(
                    archive_root, "{:d}.{:}".format(idx, PICKLE_EXT)
                )
            assert nats_is_file(xfile_path), "invalid data path : {:}".format(
                xfile_path
            )
            xdata = pickle_load(xfile_path)
            assert isinstance(xdata, dict), "invalid format of data in {:}".format(
                xfile_path
            )
            self.evaluated_indexes.add(idx)
            hp2archres = collections.OrderedDict()
            for hp_key, results in xdata.items():
                hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
                self._avaliable_hps.add(hp_key)
            self.arch2infos_dict[idx] = hp2archres

    def query_index_by_arch(self, arch):
        """Query the index of an architecture in the search space.

        Args:
          arch: For topology search space, the input arch can be an architecture
           string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';  # pylint: disable=line-too-long
              or an instance that has the 'tostr' function that can
                  generate the architecture string;
              or it is directly an architecture index, in this case,
                  we will check whether it is valid or not.
           This function will return the index.
           If return -1, it means this architecture is not in the search space.
           Otherwise, it will return an intenger in
              [0, the-number-of-candidates-in-the-search-space).

        Raises:
          ValueError: If did not find the architecture in this benchmark.

        Returns:
          The index of the architcture in this benchmark.
        """
        if self.verbose:
            print(
                "{:} Call query_index_by_arch with arch={:}".format(time_string(), arch)
            )
        if isinstance(arch, int):
            if 0 <= arch < len(self):
                return arch
            else:
                raise ValueError(
                    "Invalid architecture index {:} vs [{:}, {:}].".format(
                        arch, 0, len(self)
                    )
                )
        elif isinstance(arch, str):
            if arch in self.archstr2index:
                arch_index = self.archstr2index[arch]
            else:
                arch_index = -1
        elif hasattr(arch, "tostr"):
            if arch.tostr() in self.archstr2index:
                arch_index = self.archstr2index[arch.tostr()]
            else:
                arch_index = -1
        else:
            arch_index = -1
        return arch_index

    def query_by_arch(self, arch, hp):
        """Make the current version be compatible with the old NAS-Bench-201 version."""
        return self.query_info_str_by_arch(arch, hp)

    def _prepare_info(self, index):
        """This is a function to load the data from disk when using fast mode."""
        if index not in self.arch2infos_dict:
            if self.fast_mode and self.archive_dir is not None:
                self.reload(self.archive_dir, index)
            elif not self.fast_mode:
                if self.verbose:
                    print(
                        "{:} Call _prepare_info with index={:} skip because it is not"
                        "the fast mode.".format(time_string(), index)
                    )
            else:
                raise ValueError(
                    "Invalid status: fast_mode={:} and "
                    "archive_dir={:}".format(self.fast_mode, self.archive_dir)
                )
        else:
            if index not in self.evaluated_indexes:
                raise ValueError(
                    "The index of {:} is not in self.evaluated_indexes, "
                    "there must be something wrong.".format(index)
                )
            if self.verbose:
                print(
                    "{:} Call _prepare_info with index={:} skip because it is in "
                    "arch2infos_dict".format(time_string(), index)
                )

    def clear_params(self, index: int, hp: Optional[Text] = None):
        """Remove the architecture's weights to save memory.

        Args:
          index: the index of the target architecture
          hp: a flag to controll how to clear the parameters.
            -- None: clear all the weights in '01'/'12'/'90', which indicates
                 the number of training epochs.
            -- '01' or '12' or '90': clear all the weights in
                 arch2infos_dict[index][hp].
        """
        if self.verbose:
            print(
                "{:} Call clear_params with index={:} and hp={:}".format(
                    time_string(), index, hp
                )
            )
        if index not in self.arch2infos_dict:
            warnings.warn(
                "The {:}-th architecture is not in the benchmark data yet, "
                "no need to clear params.".format(index)
            )
        elif hp is None:
            for key, result in self.arch2infos_dict[index].items():
                result.clear_params()
        else:
            if str(hp) not in self.arch2infos_dict[index]:
                raise ValueError(
                    "The {:}-th architecture only has hyper-parameters "
                    "of {:} instead of {:}.".format(
                        index, list(self.arch2infos_dict[index].keys()), hp
                    )
                )
            self.arch2infos_dict[index][str(hp)].clear_params()

    @abc.abstractmethod
    def query_info_str_by_arch(self, arch, hp: Text = "12"):
        """This function is used to query the information of a specific architecture."""

    def _query_info_str_by_arch(self, arch, hp: Text = "12", print_information=None):
        """Internal function to query the information of `arch` when using `hp`."""
        arch_index = self.query_index_by_arch(arch)
        self._prepare_info(arch_index)
        if arch_index in self.arch2infos_dict:
            if hp not in self.arch2infos_dict[arch_index]:
                raise ValueError(
                    "The {:}-th architecture only has hyper-parameters of "
                    "{:} instead of {:}.".format(
                        arch_index, list(self.arch2infos_dict[arch_index].keys()), hp
                    )
                )
            info = self.arch2infos_dict[arch_index][hp]
            strings = print_information(info, "arch-index={:}".format(arch_index))
            return "\n".join(strings)
        else:
            warnings.warn(
                "Find this arch-index : {:}, but this arch is not "
                "evaluated.".format(arch_index)
            )
            return None

    def query_meta_info_by_index(self, arch_index, hp: Text = "12"):
        """Return ArchResults for the 'arch_index'-th architecture."""
        if self.verbose:
            print(
                "Call query_meta_info_by_index with arch_index={:}, hp={:}".format(
                    arch_index, hp
                )
            )
        self._prepare_info(arch_index)
        if arch_index in self.arch2infos_dict:
            if str(hp) not in self.arch2infos_dict[arch_index]:
                raise ValueError(
                    "The {:}-th architecture only has hyper-parameters of "
                    "{:} instead of {:}.".format(
                        arch_index, list(self.arch2infos_dict[arch_index].keys()), hp
                    )
                )
            info = self.arch2infos_dict[arch_index][str(hp)]
        else:
            raise ValueError(
                "arch_index [{:}] does not in arch2infos".format(arch_index)
            )
        return copy.deepcopy(info)

    def query_by_index(
        self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = "12"
    ):
        """Query the information with the training of 01/12/90/200 epochs.

        Args:
          arch_index: The architecture index in this benchmark.
          dataname: If dataname is None, return the ArchResults; otherwise, we will
                    return a dict with all trials on that dataset
                    (the key is the seed).
                    Options are 'cifar10-valid', 'cifar10', 'cifar100',
                      and 'ImageNet16-120'.
              -- cifar10-valid : train the model on CIFAR-10 training set.
              -- cifar10 : train the model on CIFAR-10 training + validation set.
              -- cifar100 : train the model on CIFAR-100 training set.
              -- ImageNet16-120 : train the model on ImageNet16-120 training set.
          hp: The hyperparameters.
            If hp=01, we train the model by 01 epochs.
            If hp=12, we train the model by 01 epochs.
            If hp=90, we train the model by 01 epochs.
            If hp=200, we train the model by 01 epochs.
            See github.com/D-X-Y/AutoDL-Projects/configs/nas-benchmark/hyper-opts
              for more details.

        Raises:
          ValueError: If not find the matched serach space description.

        Returns:
          An instance fo ArchResults.
        """
        if self.verbose:
            print(
                "{:} Call query_by_index with arch_index={:}, dataname={:}, "
                "hp={:}".format(time_string(), arch_index, dataname, hp)
            )
        info = self.query_meta_info_by_index(arch_index, str(hp))
        if dataname is None:
            return info
        else:
            if dataname not in info.get_dataset_names():
                raise ValueError(
                    "invalid dataset-name : {:} vs. {:}".format(
                        dataname, info.get_dataset_names()
                    )
                )
            return info.query(dataname)

    def find_best(
        self,
        dataset,
        metric_on_set,
        flop_max=None,
        param_max=None,
        hp: Text = "12",
        enforce_all: Optional[bool] = None,
    ):
        """Find the architecture with the highest accuracy based on some constraints."""
        # Please see how to set the `dataset` and `metric_on_set` (setname) at here:
        # https://github.com/D-X-Y/NATS-Bench/blob/main/nats_bench/api_utils.py#L702
        if self.verbose:
            print(
                "{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} "
                "| with #FLOPs < {:} and #Params < {:}".format(
                    time_string(), dataset, metric_on_set, hp, flop_max, param_max
                )
            )
        dataset, metric_on_set = remap_dataset_set_names(
            dataset, metric_on_set, self.verbose
        )
        best_index, highest_accuracy = -1, None
        if enforce_all is None:
            enforce_all = True if self.fast_mode else False
        if enforce_all:
            # We set this arg `enforce_all` because in the fast mode, evaluated_indexes will be empty
            # `evaluated_indexes` is dynamically inserted with architectures along with the query
            assert (
                self.fast_mode
            ), "enforce_all can only be set when fast_mode=True; if you are using non-fast-mode, please set it as False"
            evaluated_indexes = list(range(len(self)))
        else:
            evaluated_indexes = sorted(list(self.evaluated_indexes))
        for arch_index in evaluated_indexes:
            self._prepare_info(arch_index)
            arch_info = self.arch2infos_dict[arch_index][hp]
            info = arch_info.get_compute_costs(dataset)  # the information of costs
            flop, param, latency = info["flops"], info["params"], info["latency"]
            if flop_max is not None and flop > flop_max:
                continue
            if param_max is not None and param > param_max:
                continue
            xinfo = arch_info.get_metrics(
                dataset, metric_on_set
            )  # the information of loss and accuracy
            loss, accuracy = xinfo["loss"], xinfo["accuracy"]
            if best_index == -1:
                best_index, highest_accuracy = arch_index, accuracy
            elif highest_accuracy < accuracy:
                best_index, highest_accuracy = arch_index, accuracy
            del latency, loss
        if self.verbose:
            if not evaluated_indexes:
                print(
                    "The evaluated_indexes is empty, please fill it before call find_best."
                )
            else:
                print(
                    "  the best architecture : [{:}] {:} with accuracy={:.3f}%".format(
                        best_index, self.arch(best_index), highest_accuracy
                    )
                )
        return best_index, highest_accuracy

    def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = "12"):
        """Obtain the trained weights of the `index`-th arch on `dataset`.

        Args:
          index: The architecture index.
          dataset: The training dataset name.
          seed:
            -- None : return a dict containing the trained weights of all trials,
                      where each key is a seed and its corresponding value
                      is the weights.
            -- Interger : return the weights of a specific trial, whose seed
                      is this interger.
          hp:
            -- 01 : train the model by 01 epochs
            -- 12 : train the model by 12 epochs
            -- 90 : train the model by 90 epochs
            -- 200 : train the model by 200 epochs
        Returns:
          PyTorch weights.
        """
        if self.verbose:
            print(
                "{:} Call the get_net_param function with index={:}, dataset={:}, "
                "seed={:}, hp={:}".format(time_string(), index, dataset, seed, hp)
            )
        info = self.query_meta_info_by_index(index, hp)
        return info.get_net_param(dataset, seed)

    def get_net_config(self, index: int, dataset: Text):
        """Obtain the configuration for the `index`-th architecture on `dataset`.

        Args:
          index: The architecture index.
          dataset: 4 possible options as follows,
            -- cifar10-valid : train the model on the CIFAR-10 training set.
            -- cifar10 : train the model on the CIFAR-10 training + validation set.
            -- cifar100 : train the model on the CIFAR-100 training set.
            -- ImageNet16-120 : train the model on the ImageNet16-120 training set.
        Returns:
          A dict.

        Note: some examlpes for using this function:
          config = api.get_net_config(128, 'cifar10')
        """
        if self.verbose:
            print(
                "{:} Call the get_net_config function with index={:}, "
                "dataset={:}.".format(time_string(), index, dataset)
            )
        self._prepare_info(index)
        if index in self.arch2infos_dict:
            info = self.arch2infos_dict[index]
        else:
            raise ValueError(
                "The arch_index={:} is not in arch2infos_dict.".format(index)
            )
        info = next(iter(info.values()))
        results = info.query(dataset, None)
        results = next(iter(results.values()))
        return results.get_config(None)

    def get_cost_info(
        self, index: int, dataset: Text, hp: Text = "12"
    ) -> Dict[Text, float]:
        """To obtain the cost metric for the `index`-th architecture on a dataset."""
        if self.verbose:
            print(
                "{:} Call the get_cost_info function with index={:}, "
                "dataset={:}, and hp={:}.".format(time_string(), index, dataset, hp)
            )
        self._prepare_info(index)
        info = self.query_meta_info_by_index(index, hp)
        return info.get_compute_costs(dataset)

    def get_latency(self, index: int, dataset: Text, hp: Text = "12") -> float:
        """Obtain the latency of the network.

        Note: by default it will return the latency with the batch size of 256.
        Args:
          index: the index of the target architecture
          dataset: the dataset name (cifar10-valid, cifar10, cifar100,
                                     and ImageNet16-120)
          hp: the hyperparamete indicator.

        Returns:
          return a float value in seconds
        """
        if self.verbose:
            print(
                "{:} Call the get_latency function with index={:}, "
                "dataset={:}, and hp={:}.".format(time_string(), index, dataset, hp)
            )
        cost_dict = self.get_cost_info(index, dataset, hp)
        return cost_dict["latency"]

    @abc.abstractmethod
    def show(self, index=-1):
        """This function will print the information of a specific (or all) architecture(s)."""

    def _show(self, index=-1, print_information=None) -> None:
        """Print the information of a specific (or all) architecture(s).

        Args:
          index: If the index < 0: it will loop for all architectures and print
                 their information one by one. Else: it will print the information
                 of the 'index'-th architecture.

          print_information: A function to print result.

        Returns: None
        """
        if index < 0:  # show all architectures
            print(self)
            evaluated_indexes = sorted(list(self.evaluated_indexes))
            for i, idx in enumerate(evaluated_indexes):
                print(
                    "\n" + "-" * 10 + " The ({:5d}/{:5d}) {:06d}-th "
                    "architecture! ".format(i, len(evaluated_indexes), idx) + "-" * 10
                )
                print("arch : {:}".format(self.meta_archs[idx]))
                for unused_key, result in self.arch2infos_dict[index].items():
                    strings = print_information(result)
                    print(
                        ">" * 40
                        + " {:03d} epochs ".format(result.get_total_epoch())
                        + ">" * 40
                    )
                    print("\n".join(strings))
                print("<" * 40 + "------------" + "<" * 40)
        else:
            if 0 <= index < len(self.meta_archs):
                if index not in self.evaluated_indexes:
                    self._prepare_info(index)
                if index not in self.evaluated_indexes:
                    print(
                        "The {:}-th architecture has not been evaluated "
                        "or not saved.".format(index)
                    )
                else:
                    # arch_info = self.arch2infos_dict[index]
                    for unused_key, result in self.arch2infos_dict[index].items():
                        strings = print_information(result)
                        print(
                            ">" * 40
                            + " {:03d} epochs ".format(result.get_total_epoch())
                            + ">" * 40
                        )
                        print("\n".join(strings))
                    print("<" * 40 + "------------" + "<" * 40)
            else:
                print(
                    "This index ({:}) is out of range (0~{:}).".format(
                        index, len(self.meta_archs)
                    )
                )

    def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]:
        """This function will count the number of total trials."""
        if self.verbose:
            print(
                "Call the statistics function with dataset={:} and hp={:}.".format(
                    dataset, hp
                )
            )
        valid_datasets = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"]
        if dataset not in valid_datasets:
            raise ValueError("{:} not in {:}".format(dataset, valid_datasets))
        nums, hp = collections.defaultdict(lambda: 0), str(hp)
        # for index in range(len(self)):
        for index in self.evaluated_indexes:
            arch_info = self.arch2infos_dict[index][hp]
            dataset_seed = arch_info.dataset_seed
            if dataset not in dataset_seed:
                nums[0] += 1
            else:
                nums[len(dataset_seed[dataset])] += 1
        return dict(nums)


class ArchResults(object):
    """A class to maintain the results of an architecture under different settings."""

    def __init__(self, arch_index, arch_str):
        self.arch_index = int(arch_index)
        self.arch_str = copy.deepcopy(arch_str)
        self.all_results = dict()
        self.dataset_seed = dict()
        self.clear_net_done = False

    def get_compute_costs(self, dataset):
        """Return the computation cost on the input dataset."""
        x_seeds = self.dataset_seed[dataset]
        results = [self.all_results[(dataset, seed)] for seed in x_seeds]

        flops = [result.flop for result in results]
        params = [result.params for result in results]
        # NOTE(xuanyidong):
        # Due to the legacy issue, flops and params may be incorrect.
        # This is a quick fix for this legacy issue.
        def fix_legacy_issue(raw_list):
            xlist = [f"{x:.5f}" for x in raw_list]
            if len(xlist) > 1 and len(set(xlist)) > 1:  # inconsistent
                return [raw_list[x_seeds.index(888)]]
            else:
                return raw_list

        flops = fix_legacy_issue(flops)
        params = fix_legacy_issue(params)
        latencies = [result.get_latency() for result in results]
        latencies = [x for x in latencies if x > 0]
        mean_latency = mean(latencies) if len(latencies) else None
        time_infos = collections.defaultdict(list)
        for result in results:
            time_info = result.get_times()
            for key, value in time_info.items():
                time_infos[key].append(value)

        info = {
            "flops": mean(flops),
            "params": mean(params),
            "latency": mean_latency,
        }
        for key, value in time_infos.items():
            if len(value) and value[0] is not None:
                info[key] = mean(value)
            else:
                info[key] = None
        return info

    def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
        """Obtain the loss, accuracy, etc information on a specific dataset.

          If not specify, each set refer to the proposed split in NAS-Bench-201.
          If some args return None or raise error, then it is not avaliable.
          ========================================

        Args:
          dataset: 4 possible options as follows
            -- cifar10-valid : train the model on the CIFAR-10 training set.
            -- cifar10 : train the model on the CIFAR-10 training + validation set.
            -- cifar100 : train the model on the CIFAR-100 training set.
            -- ImageNet16-120 : train the model on the ImageNet16-120 training set.
          setname: each dataset has different setnames
            -- When dataset = cifar10-valid, you can use 'train',
                                       'x-valid', and 'ori-test'
            ------ 'train' : the metric on the training set.
            ------ 'x-valid' : the metric on the validation set.
            ------ 'ori-test' : the metric on the test set.
            -- When dataset = cifar10, you can use 'train', 'ori-test'.
            ------ 'train' : the metric on the training + validation set.
            ------ 'ori-test' : the metric on the test set.
            -- When dataset = cifar100 or ImageNet16-120, you can use 'train',
                                          'ori-test', 'x-valid', and 'x-test'
            ------ 'train' : the metric on the training set.
            ------ 'x-valid' : the metric on the validation set.
            ------ 'x-test' : the metric on the test set.
            ------ 'ori-test' : the metric on the validation + test set.
          iepoch: (None or an integer in [0, the-number-of-total-training-epochs)
            ------ None : return the metric after the last training epoch.
            ------ an integer i : return the metric after the i-th training epoch.
          is_random:
            ------ True : return the metric of a randomly selected trial.
            ------ False : return the averaged metric of all avaliable trials.
            ------ an integer indicating the 'seed' value : return the metric of a
                   specific trial (whose random seed is 'is_random').

        Returns:
          All the metrics given the input setting.
        """
        x_seeds = self.dataset_seed[dataset]
        results = [self.all_results[(dataset, seed)] for seed in x_seeds]
        infos = collections.defaultdict(list)
        for result in results:
            if setname == "train":
                info = result.get_train(iepoch)
            else:
                info = result.get_eval(setname, iepoch)
            for key, value in info.items():
                infos[key].append(value)
        return_info = dict()
        if isinstance(is_random, bool) and is_random:  # randomly select one
            index = random.randint(0, len(results) - 1)
            for key, value in infos.items():
                return_info[key] = value[index]
        elif isinstance(is_random, bool) and not is_random:  # average
            for key, value in infos.items():
                if len(value) and value[0] is not None:
                    return_info[key] = mean(value)
                else:
                    return_info[key] = None
        elif isinstance(is_random, int):  # specify the seed
            if is_random not in x_seeds:
                raise ValueError(
                    "can not find random seed ({:}) from {:}".format(is_random, x_seeds)
                )
            index = x_seeds.index(is_random)
            for key, value in infos.items():
                return_info[key] = value[index]
        else:
            raise ValueError("invalid value for is_random: {:}".format(is_random))
        return return_info

    # def show(self, is_print=False):
    #   return print_information(self, None, is_print)

    def get_dataset_names(self):
        return list(self.dataset_seed.keys())

    def get_dataset_seeds(self, dataset):
        return copy.deepcopy(self.dataset_seed[dataset])

    def get_net_param(self, dataset: Text, seed: Union[None, int] = None):
        """Return the trained network's weights on the 'dataset'.

        Args:
          dataset: 'cifar10-valid', 'cifar10', 'cifar100', or 'ImageNet16-120'.
          seed: an integer indicates the seed value
                or None that indicates returing all trials.

        Returns:
          The trained weights (parameters).
        """
        if seed is None:
            x_seeds = self.dataset_seed[dataset]
            return {
                seed: self.all_results[(dataset, seed)].get_net_param()
                for seed in x_seeds
            }
        else:
            xkey = (dataset, seed)
            if xkey in self.all_results:
                return self.all_results[xkey].get_net_param()
            else:
                raise ValueError(
                    "key={:} not in {:}".format(xkey, list(self.all_results.keys()))
                )

    def reset_latency(
        self, dataset: Text, seed: Union[None, Text], latency: float
    ) -> None:
        """This function is used to reset the latency in all corresponding ResultsCount(s)."""
        if seed is None:
            for seed in self.dataset_seed[dataset]:
                self.all_results[(dataset, seed)].update_latency([latency])
        else:
            self.all_results[(dataset, seed)].update_latency([latency])

    def reset_pseudo_train_times(
        self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float
    ) -> None:
        """This function is used to reset the train-times in all corresponding ResultsCount(s)."""
        if seed is None:
            for seed in self.dataset_seed[dataset]:
                self.all_results[(dataset, seed)].reset_pseudo_train_times(
                    estimated_per_epoch_time
                )
        else:
            self.all_results[(dataset, seed)].reset_pseudo_train_times(
                estimated_per_epoch_time
            )

    def reset_pseudo_eval_times(
        self,
        dataset: Text,
        seed: Union[None, Text],
        eval_name: Text,
        estimated_per_epoch_time: float,
    ) -> None:
        """This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
        if seed is None:
            for seed in self.dataset_seed[dataset]:
                self.all_results[(dataset, seed)].reset_pseudo_eval_times(
                    eval_name, estimated_per_epoch_time
                )
        else:
            self.all_results[(dataset, seed)].reset_pseudo_eval_times(
                eval_name, estimated_per_epoch_time
            )

    def get_latency(self, dataset: Text) -> float:
        """Get the latency of a model on the target dataset."""
        latencies = []
        for seed in self.dataset_seed[dataset]:
            latency = self.all_results[(dataset, seed)].get_latency()
            if not isinstance(latency, float) or latency <= 0:
                raise ValueError(
                    "invalid latency of {:} with seed={:} : {:}".format(
                        dataset, seed, latency
                    )
                )
            latencies.append(latency)
        return sum(latencies) / len(latencies)

    def get_total_epoch(self, dataset=None):
        """Return the total number of training epochs."""
        if dataset is None:
            epochss = []
            for xdata, x_seeds in self.dataset_seed.items():
                epochss += [
                    self.all_results[(xdata, seed)].get_total_epoch()
                    for seed in x_seeds
                ]
        elif isinstance(dataset, str):
            x_seeds = self.dataset_seed[dataset]
            epochss = [
                self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds
            ]
        else:
            raise ValueError("invalid dataset={:}".format(dataset))
        if len(set(epochss)) > 1:
            raise ValueError(
                "Each trial mush have the same number of training epochs : {:}".format(
                    epochss
                )
            )
        return epochss[-1]

    def query(self, dataset, seed=None):
        """Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'."""
        if seed is None:
            x_seeds = self.dataset_seed[dataset]
            return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
        else:
            return self.all_results[(dataset, seed)]

    def arch_idx_str(self):
        return "{:06d}".format(self.arch_index)

    def update(self, dataset_name, seed, result):
        """Update the result for the given dataset and seed."""
        if dataset_name not in self.dataset_seed:
            self.dataset_seed[dataset_name] = []
        if seed in self.dataset_seed[dataset_name]:
            raise ValueError(
                "{:}-th arch alreadly has this seed ({:}) on {:}".format(
                    self.arch_index, seed, dataset_name
                )
            )
        self.dataset_seed[dataset_name].append(seed)
        self.dataset_seed[dataset_name] = sorted(self.dataset_seed[dataset_name])
        assert (dataset_name, seed) not in self.all_results
        self.all_results[(dataset_name, seed)] = result
        self.clear_net_done = False

    def state_dict(self):
        """Return a dict that can be used to re-create this instance."""
        state_dict = dict()
        for key, value in self.__dict__.items():
            if key == "all_results":  # contain the class of ResultsCount
                xvalue = dict()
                if not isinstance(value, dict):
                    raise ValueError(
                        "invalid type of value for {:} : {:}".format(key, type(value))
                    )
                for cur_k, cur_v in value.items():
                    if not isinstance(cur_v, ResultsCount):
                        raise ValueError(
                            "invalid type of value for {:}/{:} : {:}".format(
                                key, cur_k, type(cur_v)
                            )
                        )
                    xvalue[cur_k] = cur_v.state_dict()
            else:
                xvalue = value
            state_dict[key] = xvalue
        return state_dict

    def load_state_dict(self, state_dict):
        """Update self based on the input dict."""
        new_state_dict = dict()
        for key, value in state_dict.items():
            if key == "all_results":  # To convert to the class of ResultsCount
                xvalue = dict()
                if not isinstance(value, dict):
                    raise ValueError(
                        "invalid type of value for {:} : {:}".format(key, type(value))
                    )
                for cur_k, cur_v in value.items():
                    xvalue[cur_k] = ResultsCount.create_from_state_dict(cur_v)
            else:
                xvalue = value
            new_state_dict[key] = xvalue
        self.__dict__.update(new_state_dict)

    @staticmethod
    def create_from_state_dict(state_dict_or_file):
        """Create the ArchResults instance from a dict or a file."""
        x = ArchResults(-1, -1)
        if isinstance(state_dict_or_file, str):  # a file path
            state_dict = pickle_load(state_dict_or_file)
        elif isinstance(state_dict_or_file, dict):
            state_dict = state_dict_or_file
        else:
            raise ValueError(
                "invalid type of state_dict_or_file : {:}".format(
                    type(state_dict_or_file)
                )
            )
        x.load_state_dict(state_dict)
        return x

    def clear_params(self):
        """Clear the weights saved in each 'result'."""
        # NOTE(xuanyidong): This can help reduce the memory footprint.
        for unused_key, result in self.all_results.items():
            del result.net_state_dict
            result.net_state_dict = None
        self.clear_net_done = True

    def debug_test(self):
        """Help debug and test, which will call most methods."""
        all_dataset = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"]
        for dataset in all_dataset:
            print("---->>>> {:}".format(dataset))
            print(
                "The latency on {:} is {:} s".format(dataset, self.get_latency(dataset))
            )
            for seed in self.dataset_seed[dataset]:
                result = self.all_results[(dataset, seed)]
                print("  ==>> result = {:}".format(result))
                print("  ==>> cost = {:}".format(result.get_times()))

    def __repr__(self):
        return (
            "{name}(arch-index={index}, arch={arch}, "
            "{num} runs, clear={clear})".format(
                name=self.__class__.__name__,
                index=self.arch_index,
                arch=self.arch_str,
                num=len(self.all_results),
                clear=self.clear_net_done,
            )
        )


class ResultsCount(object):
    """ResultsCount is to save the information of one trial for a single architecture."""

    def __init__(
        self,
        name,
        state_dict,
        train_accs,
        train_losses,
        params,
        flop,
        arch_config,
        seed,
        epochs,
        latency,
    ):
        self.name = name
        self.net_state_dict = state_dict
        self.train_acc1es = copy.deepcopy(train_accs)
        self.train_acc5es = None
        self.train_losses = copy.deepcopy(train_losses)
        self.train_times = None
        self.arch_config = copy.deepcopy(arch_config)
        self.params = params
        self.flop = flop
        self.seed = seed
        self.epochs = epochs
        self.latency = latency
        # evaluation results
        self.reset_eval()

    def update_train_info(
        self, train_acc1es, train_acc5es, train_losses, train_times
    ) -> None:
        self.train_acc1es = train_acc1es
        self.train_acc5es = train_acc5es
        self.train_losses = train_losses
        self.train_times = train_times

    def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
        """Assign the training times."""
        train_times = collections.OrderedDict()
        for i in range(self.epochs):
            train_times[i] = estimated_per_epoch_time
        self.train_times = train_times

    def reset_pseudo_eval_times(
        self, eval_name: Text, estimated_per_epoch_time: float
    ) -> None:
        """Assign the evaluation times."""
        if eval_name not in self.eval_names:
            raise ValueError("invalid eval name : {:}".format(eval_name))
        for i in range(self.epochs):
            self.eval_times["{:}@{:}".format(eval_name, i)] = estimated_per_epoch_time

    def reset_eval(self):
        self.eval_names = []
        self.eval_acc1es = {}
        self.eval_times = {}
        self.eval_losses = {}

    def update_latency(self, latency):
        self.latency = copy.deepcopy(latency)

    def get_latency(self) -> float:
        """Return the latency value in seconds."""
        # NOTE(xuanyidong): -1 represents not avaliable,
        # NOTE(xuanyidong): otherwise it should be a float value.
        if self.latency is None:
            return -1.0
        else:
            return sum(self.latency) / len(self.latency)

    def update_eval(self, accs, losses, times):
        """To update the evaluataion results."""
        data_names = set([x.split("@")[0] for x in accs.keys()])
        for data_name in data_names:
            if data_name in self.eval_names:
                raise ValueError(
                    "{:} has already been added into " "eval-names".format(data_name)
                )
            self.eval_names.append(data_name)
            for iepoch in range(self.epochs):
                xkey = "{:}@{:}".format(data_name, iepoch)
                self.eval_acc1es[xkey] = accs[xkey]
                self.eval_losses[xkey] = losses[xkey]
                self.eval_times[xkey] = times[xkey]

    def update_OLD_eval(self, name, accs, losses):  # pylint: disable=invalid-name
        """To update the evaluataion results (old NAS-Bench-201 version)."""
        assert name not in self.eval_names, "{:} has already added".format(name)
        self.eval_names.append(name)
        for iepoch in range(self.epochs):
            if iepoch in accs:
                self.eval_acc1es["{:}@{:}".format(name, iepoch)] = accs[iepoch]
                self.eval_losses["{:}@{:}".format(name, iepoch)] = losses[iepoch]

    def __repr__(self):
        num_eval = len(self.eval_names)
        set_name = "[" + ", ".join(self.eval_names) + "]"
        return (
            "{name}({xname}, arch={arch}, FLOP={flop:.2f}M, "
            "Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: "
            "{set_name})".format(
                name=self.__class__.__name__,
                xname=self.name,
                arch=self.arch_config["arch_str"],
                flop=self.flop,
                param=self.params,
                seed=self.seed,
                num_eval=num_eval,
                set_name=set_name,
            )
        )

    def get_total_epoch(self):
        return copy.deepcopy(self.epochs)

    def get_times(self):
        """Obtain the information regarding both training and evaluation time."""
        if self.train_times is not None and isinstance(self.train_times, dict):
            train_times = list(self.train_times.values())
            time_info = {
                "T-train@epoch": mean(train_times),
                "T-train@total": sum(train_times),
            }
        else:
            time_info = {"T-train@epoch": None, "T-train@total": None}
        for name in self.eval_names:
            try:
                xtimes = [
                    self.eval_times["{:}@{:}".format(name, i)]
                    for i in range(self.epochs)
                ]
                time_info["T-{:}@epoch".format(name)] = mean(xtimes)
                time_info["T-{:}@total".format(name)] = sum(xtimes)
            except Exception as unused_e:  # pylint: disable=broad-except
                time_info["T-{:}@epoch".format(name)] = None
                time_info["T-{:}@total".format(name)] = None
        return time_info

    def get_eval_set(self):
        return self.eval_names

    def judge_valid(self, iepoch: Optional[int]):
        if iepoch < 0 or iepoch >= self.epochs:
            raise ValueError("invalid iepoch={:} < {:}".format(iepoch, self.epochs))

    def get_train(self, iepoch: Optional[int] = None):
        """Get the training information."""
        if iepoch is None:
            iepoch = self.epochs - 1
        self.judge_valid(iepoch)
        if self.train_times is not None:
            xtime = self.train_times[iepoch]
            atime = sum([self.train_times[i] for i in range(iepoch + 1)])
        else:
            xtime, atime = None, None
        return {
            "iepoch": iepoch,
            "loss": self.train_losses[iepoch],
            "accuracy": self.train_acc1es[iepoch],
            "cur_time": xtime,
            "all_time": atime,
        }

    def get_eval(self, name, iepoch: Optional[int] = None):
        """Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
        if iepoch is None:
            iepoch = self.epochs - 1
        self.judge_valid(iepoch)

        def _internal_query(xname):
            if isinstance(self.eval_times, dict) and len(self.eval_times):
                key = "{:}@{:}".format(xname, iepoch)
                if key in self.eval_times:
                    xtime = self.eval_times[key]
                else:
                    raise ValueError(
                        "{:} not in {:}".format(key, self.eval_times.keys())
                    )
                atime = sum(
                    [
                        self.eval_times["{:}@{:}".format(xname, i)]
                        for i in range(iepoch + 1)
                    ]
                )
            else:
                xtime, atime = None, None
            return {
                "iepoch": iepoch,
                "loss": self.eval_losses["{:}@{:}".format(xname, iepoch)],
                "accuracy": self.eval_acc1es["{:}@{:}".format(xname, iepoch)],
                "cur_time": xtime,
                "all_time": atime,
            }

        if name == "valid":
            return _internal_query("x-valid")
        else:
            return _internal_query(name)

    def get_net_param(self, clone: bool = False):
        if clone:
            return copy.deepcopy(self.net_state_dict)
        else:
            return self.net_state_dict

    def get_config(self, str2structure):
        """This function is used to obtain the config dict for this architecture."""
        if str2structure is None:
            # In this case, this is an arch in size search space of NATS-BENCH.
            if (
                "name" in self.arch_config
                and self.arch_config["name"] == "infer.shape.tiny"
            ):
                return {
                    "name": "infer.shape.tiny",
                    "channels": self.arch_config["channels"],
                    "genotype": self.arch_config["genotype"],
                    "num_classes": self.arch_config["class_num"],
                }
            else:  # This is an arch in NATS-BENCH's topology search space.
                return {
                    "name": "infer.tiny",
                    "C": self.arch_config["channel"],
                    "N": self.arch_config["num_cells"],
                    "arch_str": self.arch_config["arch_str"],
                    "num_classes": self.arch_config["class_num"],
                }
        else:  # This is an arch in the size search space of NATS-BENCH.
            if (
                "name" in self.arch_config
                and self.arch_config["name"] == "infer.shape.tiny"
            ):
                return {
                    "name": "infer.shape.tiny",
                    "channels": self.arch_config["channels"],
                    "genotype": str2structure(self.arch_config["genotype"]),
                    "num_classes": self.arch_config["class_num"],
                }
            else:  # This is an arch in the topology search space of NATS-BENCH.
                return {
                    "name": "infer.tiny",
                    "C": self.arch_config["channel"],
                    "N": self.arch_config["num_cells"],
                    "genotype": str2structure(self.arch_config["arch_str"]),
                    "num_classes": self.arch_config["class_num"],
                }

    def state_dict(self):
        collected_state_dict = {key: value for key, value in self.__dict__.items()}
        return collected_state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    @staticmethod
    def create_from_state_dict(state_dict):
        x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
        x.load_state_dict(state_dict)
        return x


================================================
FILE: nats_bench/genotype_utils.py
================================================
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.09 #
#####################################################
# Moved from https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_searchs/genotypes.py
#####################################################
from copy import deepcopy


def topology_str2structure(xstr):
    return TopologyStructure.str2structure(xstr)


class TopologyStructure:
    """A class to describe the topology, especially that used in NATS-Bench."""

    def __init__(self, genotype):
        assert isinstance(genotype, list) or isinstance(
            genotype, tuple
        ), "invalid class of genotype : {:}".format(type(genotype))
        self.node_num = len(genotype) + 1
        self.nodes = []
        self.node_N = []
        for idx, node_info in enumerate(genotype):
            assert isinstance(node_info, list) or isinstance(
                node_info, tuple
            ), "invalid class of node_info : {:}".format(type(node_info))
            assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info))
            for node_in in node_info:
                assert isinstance(node_in, list) or isinstance(
                    node_in, tuple
                ), "invalid class of in-node : {:}".format(type(node_in))
                assert (
                    len(node_in) == 2 and node_in[1] <= idx
                ), "invalid in-node : {:}".format(node_in)
            self.node_N.append(len(node_info))
            self.nodes.append(tuple(deepcopy(node_info)))

    def tolist(self, remove_str):
        # convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
        # note that we re-order the input node in this function
        # return the-genotype-list and success [if unsuccess, it is not a connectivity]
        genotypes = []
        for node_info in self.nodes:
            node_info = list(node_info)
            node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
            node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
            if len(node_info) == 0:
                return None, False
            genotypes.append(node_info)
        return genotypes, True

    def node(self, index):
        assert index > 0 and index <= len(self), "invalid index={:} < {:}".format(
            index, len(self)
        )
        return self.nodes[index]

    def tostr(self):
        strings = []
        for node_info in self.nodes:
            string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info])
            string = "|{:}|".format(string)
            strings.append(string)
        return "+".join(strings)

    def check_valid(self):
        nodes = {0: True}
        for i, node_info in enumerate(self.nodes):
            sums = []
            for op, xin in node_info:
                if op == "none" or nodes[xin] is False:
                    x = False
                else:
                    x = True
                sums.append(x)
            nodes[i + 1] = sum(sums) > 0
        return nodes[len(self.nodes)]

    def to_unique_str(self, consider_zero=False):
        # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
        # two operations are special, i.e., none and skip_connect
        nodes = {0: "0"}
        for i_node, node_info in enumerate(self.nodes):
            cur_node = []
            for op, xin in node_info:
                if consider_zero is None:
                    x = "(" + nodes[xin] + ")" + "@{:}".format(op)
                elif consider_zero:
                    if op == "none" or nodes[xin] == "#":
                        x = "#"  # zero
                    elif op == "skip_connect":
                        x = nodes[xin]
                    else:
                        x = "(" + nodes[xin] + ")" + "@{:}".format(op)
                else:
                    if op == "skip_connect":
                        x = nodes[xin]
                    else:
                        x = "(" + nodes[xin] + ")" + "@{:}".format(op)
                cur_node.append(x)
            nodes[i_node + 1] = "+".join(sorted(cur_node))
        return nodes[len(self.nodes)]

    def check_valid_op(self, op_names):
        for node_info in self.nodes:
            for inode_edge in node_info:
                # assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
                if inode_edge[0] not in op_names:
                    return False
        return True

    def __repr__(self):
        return "{name}({node_num} nodes with {node_info})".format(
            name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__
        )

    def __len__(self):
        return len(self.nodes) + 1

    def __getitem__(self, index):
        return self.nodes[index]

    @staticmethod
    def str2structure(xstr):
        if isinstance(xstr, TopologyStructure):
            return xstr
        assert isinstance(xstr, str), "must take string (not {:}) as input".format(
            type(xstr)
        )
        nodestrs = xstr.split("+")
        genotypes = []
        for i, node_str in enumerate(nodestrs):
            inputs = list(filter(lambda x: x != "", node_str.split("|")))
            for xinput in inputs:
                assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
                    xinput
                )
            inputs = (xi.split("~") for xi in inputs)
            input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
            genotypes.append(input_infos)
        return TopologyStructure(genotypes)

    @staticmethod
    def str2fullstructure(xstr, default_name="none"):
        assert isinstance(xstr, str), "must take string (not {:}) as input".format(
            type(xstr)
        )
        nodestrs = xstr.split("+")
        genotypes = []
        for i, node_str in enumerate(nodestrs):
            inputs = list(filter(lambda x: x != "", node_str.split("|")))
            for xinput in inputs:
                assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
                    xinput
                )
            inputs = (xi.split("~") for xi in inputs)
            input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
            all_in_nodes = list(x[1] for x in input_infos)
            for j in range(i):
                if j not in all_in_nodes:
                    input_infos.append((default_name, j))
            node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
            genotypes.append(tuple(node_info))
        return TopologyStructure(genotypes)

    @staticmethod
    def gen_all(search_space, num, return_ori):
        assert isinstance(search_space, list) or isinstance(
            search_space, tuple
        ), "invalid class of search-space : {:}".format(type(search_space))
        assert (
            num >= 2
        ), "There should be at least two nodes in a neural cell instead of {:}".format(
            num
        )
        all_archs = get_combination(search_space, 1)
        for i, arch in enumerate(all_archs):
            all_archs[i] = [tuple(arch)]

        for inode in range(2, num):
            cur_nodes = get_combination(search_space, inode)
            new_all_archs = []
            for previous_arch in all_archs:
                for cur_node in cur_nodes:
                    new_all_archs.append(previous_arch + [tuple(cur_node)])
            all_archs = new_all_archs
        if return_ori:
            return all_archs
        else:
            return [TopologyStructure(x) for x in all_archs]


================================================
FILE: notebooks/README.md
================================================
# Notebooks for NATS-Bench

We provide some examples on how to use NATS-Bench in the following notebooks:

- [`create-query-sss.ipynb`](create-query-sss.ipynb) : create the size search space and query the performance of some architectures.
- [`find-largest.ipynb`](find-largest.ipynb) : find the largest model and report its performance on each dataset.
- [`random-search.ipynb`](random-search.ipynb) : use random search on the topology search space of NATS-Bench
- [`issue-7.ipynb`](issue-7.ipynb) : show how to use `find_best` function.
- [`issue-11.ipynb`](issue-11.ipynb) : show how to use `get_more_info` function.
- [`issue-12.ipynb`](issue-12.ipynb) : show the unique string of each candidate and remove duplicated candidates.
- [`issue-21.ipynb`](issue-21.ipynb) : create the PyTorch model instance of an architecture candidate.


FYI, to open and run them in your machine, execute `jupyter notebook` under the `NATS-Bench/notebooks` folder.


================================================
FILE: notebooks/create-query-sss.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:08] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n",
      "[2021-03-29 08:23:08] Create NATS-Bench (size) done with 0/32768 architectures avaliable.\n",
      "\n",
      "API create done: NATSsize(0/32768 architectures, fast_mode=True, file=None)\n",
      "\n",
      "[2021-03-29 08:23:08] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n",
      "[2021-03-29 08:23:08] Call query_index_by_arch with arch=1234\n",
      "[2021-03-29 08:23:08] Call clear_params with archive_root=/Users/xuanyidong/.torch/NATS-sss-v1_0-50262-simple and index=1234\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 12 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 83.87,\n",
      " 'test-all-time': 8.31445026397705,\n",
      " 'test-loss': 0.4872739363670349,\n",
      " 'test-per-time': 0.6928708553314209,\n",
      " 'train-accuracy': 85.74,\n",
      " 'train-all-time': 69.73253917694092,\n",
      " 'train-loss': 0.4183172229385376,\n",
      " 'train-per-time': 5.811044931411743}\n"
     ]
    }
   ],
   "source": [
    "from nats_bench import create\n",
    "from pprint import pprint\n",
    "\n",
    "# Create the API instance for the size search space in NATS\n",
    "api = create(None, 'sss', fast_mode=True, verbose=True)\n",
    "print('\\nAPI create done: {:}\\n'.format(api))\n",
    "\n",
    "\n",
    "info = api.get_more_info(1234, 'cifar10')\n",
    "pprint(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:12] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=90, and is_random=True.\n",
      "[2021-03-29 08:23:12] Call query_index_by_arch with arch=1234\n",
      "[2021-03-29 08:23:12] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 90 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 89.4,\n",
      " 'test-all-time': 62.35837697982788,\n",
      " 'test-loss': 0.3388326271057129,\n",
      " 'test-per-time': 0.6928708553314209,\n",
      " 'train-accuracy': 95.206,\n",
      " 'train-all-time': 522.9940438270569,\n",
      " 'train-loss': 0.14320597895622253,\n",
      " 'train-per-time': 5.811044931411743}\n"
     ]
    }
   ],
   "source": [
    "info = api.get_more_info(1234, 'cifar10', hp='90')\n",
    "pprint(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:15] Call the get_cost_info function with index=12, dataset=cifar10, and hp=12.\n",
      "[2021-03-29 08:23:15] Call clear_params with archive_root=/Users/xuanyidong/.torch/NATS-sss-v1_0-50262-simple and index=12\n",
      "Call query_meta_info_by_index with arch_index=12, hp=12\n",
      "[2021-03-29 08:23:15] Call _prepare_info with index=12 skip because it is in arch2infos_dict\n",
      "{'T-ori-test@epoch': 0.6709375381469727,\n",
      " 'T-ori-test@total': 8.051250457763672,\n",
      " 'T-train@epoch': 5.539922475814819,\n",
      " 'T-train@total': 66.47906970977783,\n",
      " 'flops': 7.991706,\n",
      " 'latency': 0.014862352974560795,\n",
      " 'params': 0.067378}\n"
     ]
    }
   ],
   "source": [
    "# Query the flops, params, latency. info is a dict.\n",
    "info = api.get_cost_info(12, 'cifar10')\n",
    "pprint(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:17] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n",
      "[2021-03-29 08:23:17] Call query_index_by_arch with arch=1234\n",
      "[2021-03-29 08:23:17] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 12 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 84.28,\n",
      " 'test-all-time': 8.31445026397705,\n",
      " 'test-loss': 0.46498328766822816,\n",
      " 'test-per-time': 0.6928708553314209,\n",
      " 'train-accuracy': 86.004,\n",
      " 'train-all-time': 69.73253917694092,\n",
      " 'train-loss': 0.405061281375885,\n",
      " 'train-per-time': 5.811044931411743}\n"
     ]
    }
   ],
   "source": [
    "info = api.get_more_info(1234, 'cifar10', hp='12', is_random=True)\n",
    "pprint(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:20] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n",
      "[2021-03-29 08:23:20] Call query_index_by_arch with arch=1234\n",
      "[2021-03-29 08:23:20] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 12 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 83.87,\n",
      " 'test-all-time': 8.31445026397705,\n",
      " 'test-loss': 0.4872739363670349,\n",
      " 'test-per-time': 0.6928708553314209,\n",
      " 'train-accuracy': 85.74,\n",
      " 'train-all-time': 69.73253917694092,\n",
      " 'train-loss': 0.4183172229385376,\n",
      " 'train-per-time': 5.811044931411743}\n"
     ]
    }
   ],
   "source": [
    "# The same code as above, but return the different performance because we set is_random=True\n",
    "info = api.get_more_info(1234, 'cifar10', hp='12', is_random=True)\n",
    "pprint(info)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/find-largest.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:22:18] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "from nats_bench import create\n",
    "from pprint import pprint\n",
    "# Create the API for tologoy search space\n",
    "api = create(None, 'tss', fast_mode=True, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The architecture-index for the largest model is 1462\n",
      "Its performance on cifar10 with 12-epoch-training\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 12 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 82.2,\n",
      " 'test-all-time': 25.68491765430996,\n",
      " 'test-loss': 0.5235109260559082,\n",
      " 'test-per-time': 2.14040980452583,\n",
      " 'train-accuracy': 83.78,\n",
      " 'train-all-time': 415.2997846603394,\n",
      " 'train-loss': 0.4719834935951233,\n",
      " 'train-per-time': 34.60831538836162}\n",
      "Its performance on cifar10 with 200-epoch-training\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 200 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 93.76,\n",
      " 'test-all-time': 428.08196090516384,\n",
      " 'test-loss': 0.29643801399866737,\n",
      " 'test-per-time': 2.1404098045258193,\n",
      " 'train-accuracy': 99.968,\n",
      " 'train-all-time': 6921.6630776723405,\n",
      " 'train-loss': 0.0021994023492435616,\n",
      " 'train-per-time': 34.6083153883617}\n",
      "Its performance on cifar100 with 12-epoch-training\n",
      "{'test-accuracy': 44.97999995727539,\n",
      " 'test-all-time': 12.84245882715498,\n",
      " 'test-loss': 2.069740362548828,\n",
      " 'test-per-time': 1.070204902262915,\n",
      " 'train-accuracy': 46.014,\n",
      " 'train-all-time': 415.2997846603394,\n",
      " 'train-loss': 1.9952968555450439,\n",
      " 'train-per-time': 34.60831538836162,\n",
      " 'valid-accuracy': 44.05999992675781,\n",
      " 'valid-all-time': 12.84245882715498,\n",
      " 'valid-loss': 2.077388186645508,\n",
      " 'valid-per-time': 1.070204902262915,\n",
      " 'valtest-accuracy': 44.52,\n",
      " 'valtest-all-time': 25.68491765430996,\n",
      " 'valtest-loss': 2.073564303588867,\n",
      " 'valtest-per-time': 2.14040980452583}\n",
      "Its performance on cifar100 with 200-epoch-training\n",
      "{'test-accuracy': 71.10666660563152,\n",
      " 'test-all-time': 214.04098045258192,\n",
      " 'test-loss': 1.3540414614995322,\n",
      " 'test-per-time': 1.0702049022629097,\n",
      " 'train-accuracy': 99.79133333333334,\n",
      " 'train-all-time': 6921.6630776723405,\n",
      " 'train-loss': 0.02413411712328593,\n",
      " 'train-per-time': 34.6083153883617,\n",
      " 'valid-accuracy': 70.70666666259766,\n",
      " 'valid-all-time': 214.04098045258192,\n",
      " 'valid-loss': 1.3654081104278564,\n",
      " 'valid-per-time': 1.0702049022629097,\n",
      " 'valtest-accuracy': 70.90666666666667,\n",
      " 'valtest-all-time': 428.08196090516384,\n",
      " 'valtest-loss': 1.3597248032251994,\n",
      " 'valtest-per-time': 2.1404098045258193}\n",
      "Its performance on ImageNet16-120 with 12-epoch-training\n",
      "{'test-accuracy': 22.39999992879232,\n",
      " 'test-all-time': 7.7054752962929856,\n",
      " 'test-loss': 3.1626377182006835,\n",
      " 'test-per-time': 0.6421229413577488,\n",
      " 'train-accuracy': 21.68885959195242,\n",
      " 'train-all-time': 1260.0195466594694,\n",
      " 'train-loss': 3.1863493608815463,\n",
      " 'train-per-time': 105.00162888828912,\n",
      " 'valid-accuracy': 23.266666631062826,\n",
      " 'valid-all-time': 7.7054752962929856,\n",
      " 'valid-loss': 3.1219845104217527,\n",
      " 'valid-per-time': 0.6421229413577488,\n",
      " 'valtest-accuracy': 22.833333323160808,\n",
      " 'valtest-all-time': 15.410950592585971,\n",
      " 'valtest-loss': 3.142311067581177,\n",
      " 'valtest-per-time': 1.2842458827154977}\n",
      "Its performance on ImageNet16-120 with 200-epoch-training\n",
      "{'test-accuracy': 41.44444444783529,\n",
      " 'test-all-time': 128.4245882715503,\n",
      " 'test-loss': 2.3114658287896046,\n",
      " 'test-per-time': 0.6421229413577515,\n",
      " 'train-accuracy': 50.604262800759415,\n",
      " 'train-all-time': 21000.325777657865,\n",
      " 'train-loss': 1.8626367051877495,\n",
      " 'train-per-time': 105.00162888828932,\n",
      " 'valid-accuracy': 40.777777659098305,\n",
      " 'valid-all-time': 128.4245882715503,\n",
      " 'valid-loss': 2.3157107713487415,\n",
      " 'valid-per-time': 0.6421229413577515,\n",
      " 'valtest-accuracy': 41.11111109754774,\n",
      " 'valtest-all-time': 256.8491765431006,\n",
      " 'valtest-loss': 2.313588462617662,\n",
      " 'valtest-per-time': 1.284245882715503}\n"
     ]
    }
   ],
   "source": [
    "# query the largest model's performance\n",
    "largest_candidate_tss = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'\n",
    "arch_index = api.query_index_by_arch(largest_candidate_tss)\n",
    "print('The architecture-index for the largest model is {:}'.format(arch_index))\n",
    "datasets = ('cifar10', 'cifar100', 'ImageNet16-120')\n",
    "for dataset in datasets:\n",
    "    print('Its performance on {:} with 12-epoch-training'.format(dataset))\n",
    "    info = api.get_more_info(arch_index, dataset, hp='12', is_random=False)\n",
    "    pprint(info)\n",
    "    print('Its performance on {:} with 200-epoch-training'.format(dataset))\n",
    "    info = api.get_more_info(arch_index, dataset, hp='200', is_random=False)\n",
    "    pprint(info)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/issue-11.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-09 08:44:19] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n"
     ]
    }
   ],
   "source": [
    "from nats_bench import create\n",
    "import numpy as np\n",
    "\n",
    "# Create the API for size search space\n",
    "api = create(None, 'sss', fast_mode=True, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 32768 architectures on the size search space\n"
     ]
    }
   ],
   "source": [
    "print('There are {:} architectures on the size search space'.format(len(api)))\n",
    "\n",
    "c2acc = dict()\n",
    "for index in range(len(api)):\n",
    "    info = api.get_more_info(index, 'cifar10', hp='90')\n",
    "    config = api.get_net_config(index, 'cifar10')\n",
    "    c2acc[config['channels']] = info['test-accuracy']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "91.08546417236329\n"
     ]
    }
   ],
   "source": [
    "print(np.mean(list(c2acc.values())))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/issue-12.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-04-30 03:41:23] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n"
     ]
    }
   ],
   "source": [
    "from nats_bench import create\n",
    "from nats_bench.api_utils import time_string\n",
    "import numpy as np\n",
    "\n",
    "# Create the API for topology search space\n",
    "api = create(None, 'tss', fast_mode=True, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-04-30 03:41:23] There are 15625 architectures on the topology search space\n",
      "[2021-04-30 03:41:24] There are 6466 isomorphism architectures on the topology search space\n"
     ]
    }
   ],
   "source": [
    "print('{:} There are {:} architectures on the topology search space'.format(time_string(), len(api)))\n",
    "\n",
    "unique_strs = []\n",
    "for index in range(len(api)):\n",
    "    unique_str = api.get_unique_str(index)\n",
    "    unique_strs.append(unique_str)\n",
    "print('{:} There are {:} isomorphism architectures on the topology search space'.format(time_string(), len(set(unique_strs))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/issue-21.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-08-07 21:51:10] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n"
     ]
    }
   ],
   "source": [
    "from nats_bench import create\n",
    "from nats_bench.api_utils import time_string\n",
    "import numpy as np\n",
    "\n",
    "# Create the API for size search space\n",
    "api = create(None, 'sss', fast_mode=True, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-08-07 21:51:10] There are 32768 architectures on the size search space\n",
      "{'name': 'infer.shape.tiny', 'channels': '8:8:8:16:40', 'genotype': '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', 'num_classes': 10}\n"
     ]
    }
   ],
   "source": [
    "print('{:} There are {:} architectures on the size search space'.format(time_string(), len(api)))\n",
    "\n",
    "# Obtain the 12-th candidate's configureation on CIFAR-10\n",
    "config = api.get_net_config(12, 'cifar10')\n",
    "print(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xautodl  # import this lib -- \"https://github.com/D-X-Y/AutoDL-Projects\", you can use pip install xautodl\n",
    "from xautodl.models import get_cell_based_tiny_net\n",
    "# create the network, which is the sub-class of torch.nn.Module\n",
    "network = get_cell_based_tiny_net(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DynamicShapeTinyNet(\n",
      "  DynamicShapeTinyNet(C=(8, 8, 8, 16, 40), N=1, L=5)\n",
      "  (stem): Sequential(\n",
      "    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  )\n",
      "  (cells): ModuleList(\n",
      "    (0): InferCell(\n",
      "      info :: nodes=4, inC=8, outC=8, [1<-(I0-L0) | 2<-(I0-L1,I1-L2) | 3<-(I0-L3,I1-L4,I2-L5)], |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
      "      (layers): ModuleList(\n",
      "        (0): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (1): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (2): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (3): Identity()\n",
      "        (4): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (5): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (1): ResNetBasicblock(\n",
      "      ResNetBasicblock(inC=8, outC=8, stride=2)\n",
      "      (conv_a): ReLUConvBN(\n",
      "        (op): Sequential(\n",
      "          (0): ReLU()\n",
      "          (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "          (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (conv_b): ReLUConvBN(\n",
      "        (op): Sequential(\n",
      "          (0): ReLU()\n",
      "          (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "          (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (downsample): Sequential(\n",
      "        (0): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
      "        (1): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      )\n",
      "    )\n",
      "    (2): InferCell(\n",
      "      info :: nodes=4, inC=8, outC=8, [1<-(I0-L0) | 2<-(I0-L1,I1-L2) | 3<-(I0-L3,I1-L4,I2-L5)], |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
      "      (layers): ModuleList(\n",
      "        (0): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (1): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (2): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (3): Identity()\n",
      "        (4): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (5): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (3): ResNetBasicblock(\n",
      "      ResNetBasicblock(inC=8, outC=16, stride=2)\n",
      "      (conv_a): ReLUConvBN(\n",
      "        (op): Sequential(\n",
      "          (0): ReLU()\n",
      "          (1): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (conv_b): ReLUConvBN(\n",
      "        (op): Sequential(\n",
      "          (0): ReLU()\n",
      "          (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (downsample): Sequential(\n",
      "        (0): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
      "        (1): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      )\n",
      "    )\n",
      "    (4): InferCell(\n",
      "      info :: nodes=4, inC=16, outC=40, [1<-(I0-L0) | 2<-(I0-L1,I1-L2) | 3<-(I0-L3,I1-L4,I2-L5)], |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
      "      (layers): ModuleList(\n",
      "        (0): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(16, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (1): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(16, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (2): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (3): FactorizedReduce(\n",
      "          C_in=16, C_out=40, stride=1\n",
      "          (relu): ReLU()\n",
      "          (conv): Conv2d(16, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "          (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "        (4): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "        (5): ReLUConvBN(\n",
      "          (op): Sequential(\n",
      "            (0): ReLU()\n",
      "            (1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "            (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (lastact): Sequential(\n",
      "    (0): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (1): ReLU(inplace=True)\n",
      "  )\n",
      "  (global_pooling): AdaptiveAvgPool2d(output_size=1)\n",
      "  (classifier): Linear(in_features=40, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model parameters are 0.067378 MB\n"
     ]
    }
   ],
   "source": [
    "from xautodl.utils import count_parameters_in_MB\n",
    "print('The model parameters are {:} MB'.format(count_parameters_in_MB(network)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/issue-27.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-10-21 07:08:52] Try to use the default NATS-Bench (topology) path from fast_mode=False and path=/Users/xuanyidong/.torch/NATS-tss-v1_0-3ffb9.pickle.pbz2.\n"
     ]
    }
   ],
   "source": [
    "from nats_bench import create\n",
    "from nats_bench.api_utils import time_string\n",
    "import numpy as np\n",
    "\n",
    "# Create the API for size search space\n",
    "api_tss = create(None, \"tss\", fast_mode=False, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------ImageNet16-120--------------------------------------------------\n",
      "Best (10676) architecture on validation: |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_1x1~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
      "Best (857) architecture on       test: |nor_conv_1x1~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
      "using validation ::: validation = 46.73, test = 46.20\n",
      "\n",
      "using test       ::: validation = 46.38, test = 47.31\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def get_valid_test_acc(api, arch, dataset):\n",
    "    is_size_space = api.search_space_name == \"size\"\n",
    "    if dataset == \"cifar10\":\n",
    "        xinfo = api.get_more_info(\n",
    "            arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False\n",
    "        )\n",
    "        test_acc = xinfo[\"test-accuracy\"]\n",
    "        xinfo = api.get_more_info(\n",
    "            arch,\n",
    "            dataset=\"cifar10-valid\",\n",
    "            hp=90 if is_size_space else 200,\n",
    "            is_random=False,\n",
    "        )\n",
    "        valid_acc = xinfo[\"valid-accuracy\"]\n",
    "    else:\n",
    "        xinfo = api.get_more_info(\n",
    "            arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False\n",
    "        )\n",
    "        valid_acc = xinfo[\"valid-accuracy\"]\n",
    "        test_acc = xinfo[\"test-accuracy\"]\n",
    "    return (\n",
    "        valid_acc,\n",
    "        test_acc,\n",
    "        \"validation = {:.2f}, test = {:.2f}\\n\".format(valid_acc, test_acc),\n",
    "    )\n",
    "\n",
    "def find_best_valid(api, dataset):\n",
    "    all_valid_accs, all_test_accs = [], []\n",
    "    for index, arch in enumerate(api):\n",
    "        valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)\n",
    "        all_valid_accs.append((index, valid_acc))\n",
    "        all_test_accs.append((index, test_acc))\n",
    "    best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0]\n",
    "    best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]\n",
    "\n",
    "    print(\"-\" * 50 + \"{:10s}\".format(dataset) + \"-\" * 50)\n",
    "    print(\n",
    "        \"Best ({:}) architecture on validation: {:}\".format(\n",
    "            best_valid_index, api[best_valid_index]\n",
    "        )\n",
    "    )\n",
    "    print(\n",
    "        \"Best ({:}) architecture on       test: {:}\".format(\n",
    "            best_test_index, api[best_test_index]\n",
    "        )\n",
    "    )\n",
    "    _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)\n",
    "    print(\"using validation ::: {:}\".format(perf_str))\n",
    "    _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)\n",
    "    print(\"using test       ::: {:}\".format(perf_str))\n",
    "\n",
    "dataset = \"ImageNet16-120\"\n",
    "find_best_valid(api_tss, dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/issue-30.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-11-15 01:40:00] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n"
     ]
    }
   ],
   "source": [
    "from nats_bench import create\n",
    "import numpy as np\n",
    "\n",
    "# Create the API for the topology search space\n",
    "api = create(None, 'tss', fast_mode=True, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train-loss': 2.6261614637934385,\n",
       " 'train-accuracy': 33.934739619476986,\n",
       " 'train-per-time': 51.38146062405923,\n",
       " 'train-all-time': 10276.292124811845,\n",
       " 'valid-loss': 2.6819862944285076,\n",
       " 'valid-accuracy': 32.83333325195313,\n",
       " 'valid-per-time': 0.36819872515542285,\n",
       " 'valid-all-time': 73.63974503108457,\n",
       " 'test-loss': 2.7110290253957112,\n",
       " 'test-accuracy': 31.966666661580405,\n",
       " 'test-per-time': 0.36819872515542285,\n",
       " 'test-all-time': 73.63974503108457,\n",
       " 'valtest-loss': 2.6965076230367027,\n",
       " 'valtest-accuracy': 32.399999994913735,\n",
       " 'valtest-per-time': 0.7363974503108457,\n",
       " 'valtest-all-time': 147.27949006216915}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "name = 'ImageNet16-120'\n",
    "api.get_more_info(index=346, dataset=name, hp='200', iepoch=199, is_random=999)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "can not find random seed (999) from [777, 888]",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-9128042f185f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_more_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m347\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'200'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miepoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m199\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_random\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m999\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/nats_bench/api_topology.py\u001b[0m in \u001b[0;36mget_more_info\u001b[0;34m(self, index, dataset, iepoch, hp, is_random)\u001b[0m\n\u001b[1;32m    218\u001b[0m             \u001b[0mis_random\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseeds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    219\u001b[0m         \u001b[0;31m# collect the training information\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 220\u001b[0;31m         \u001b[0mtrain_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marchresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"train\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miepoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0miepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_random\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mis_random\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    221\u001b[0m         \u001b[0mtotal\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"iepoch\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    222\u001b[0m         xinfo = {\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/nats_bench/api_utils.py\u001b[0m in \u001b[0;36mget_metrics\u001b[0;34m(self, dataset, setname, iepoch, is_random)\u001b[0m\n\u001b[1;32m    770\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_random\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# specify the seed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    771\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mis_random\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mx_seeds\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 772\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"can not find random seed ({:}) from {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_random\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_seeds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    773\u001b[0m             \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx_seeds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_random\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    774\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minfos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: can not find random seed (999) from [777, 888]"
     ]
    }
   ],
   "source": [
    "api.get_more_info(index=347, dataset=name, hp='200', iepoch=199, is_random=999)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/issue-33.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-12-09 09:32:55] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n",
      "There are 15625 architectures on the size search space\n"
     ]
    }
   ],
   "source": [
    "from nats_bench import create\n",
    "import numpy as np\n",
    "\n",
    "# Create the API for size search space\n",
    "api = create(None, 'tss', fast_mode=True, verbose=False)\n",
    "print('There are {:} architectures on the size search space'.format(len(api)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(44.82400000488281, 0.014427971421626575, 287.3714566230774, 287.3714566230774)\n",
      "(19.71, 0.01402865074299, 567.5892686843872, 854.9607253074646)\n"
     ]
    }
   ],
   "source": [
    "print(api.simulate_train_eval(2, 'cifar10', iepoch=24, hp='200'))\n",
    "print(api.simulate_train_eval(2, 'cifar100', iepoch=24, hp='200'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(30.21, 0.01402865074299, 2290.6443033218384, 3145.605028629303)\n",
      "(34.86, 0.01402865074299, 3424.233141899109, 6569.838170528412)\n",
      "(51.65, 0.01402865074299, 4104.386445045471, 10674.224615573883)\n",
      "(54.640000018310545, 0.01402865074299, 4535.150203704834, 15209.374819278717)\n"
     ]
    }
   ],
   "source": [
    "print(api.simulate_train_eval(2, 'cifar100', iepoch=100, hp='200'))\n",
    "print(api.simulate_train_eval(2, 'cifar100', iepoch=150, hp='200'))\n",
    "print(api.simulate_train_eval(2, 'cifar100', iepoch=180, hp='200'))\n",
    "print(api.simulate_train_eval(2, 'cifar100', iepoch=199, hp='200'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(25.98333334350586, 0.012760758399963379, 12450.773810838671, 27660.148630117386)\n",
      "(27.033333257039388, 0.012760758399963379, 13757.711054611173, 41417.85968472856)\n"
     ]
    }
   ],
   "source": [
    "print(api.simulate_train_eval(2, 'ImageNet16-120', iepoch=180, hp='200'))\n",
    "print(api.simulate_train_eval(2, 'ImageNet16-120', iepoch=199, hp='200'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/issue-36.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nats_bench import create\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "def close_to(a, b, eps=1e-4):\n",
    "    if b != 0 and abs(a-b) / abs(b) > eps:\n",
    "        return False\n",
    "    if a != 0 and abs(a-b) / abs(a) > eps:\n",
    "        return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_flops_params(xapi):\n",
    "    print(f\"Check {xapi}\")\n",
    "    datasets = (\"cifar10-valid\", \"cifar10\", \"cifar100\", \"ImageNet16-120\")\n",
    "    counts = 0\n",
    "    for index in tqdm(range(len(xapi))):\n",
    "        for dataset in datasets:\n",
    "            info_12 = xapi.get_cost_info(index, dataset, hp=\"12\")\n",
    "            info_full = xapi.get_cost_info(index, dataset, hp=xapi.full_epochs_in_paper)\n",
    "            assert close_to(info_12['flops'], info_full['flops']), f\"The {index}-th \" \\\n",
    "            f\"architecture has issues on {dataset} \" \\\n",
    "            f\"-- {info_12['flops']} vs {info_full['flops']}.\"  # check the FLOPs\n",
    "            assert close_to(info_12['params'], info_full['params']), f\"The {index}-th \" \\\n",
    "            f\"architecture has issues on {dataset} \" \\\n",
    "            f\"-- {info_12['params']} vs {info_full['params']}.\"  # check the number of parameters\n",
    "            counts += 1\n",
    "    print(f\"Check {xapi} completed -- {counts} arch-dataset pairs.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the API for size search space\n",
    "api = create(None, 'sss', fast_mode=True, verbose=False)\n",
    "print(f'There are {len(api)} architectures in the size search space -- {api}')\n",
    "check_flops_params(api)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/15625 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2022-01-20 01:03:30] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n",
      "There are 15625 architectures in the topology search space -- NATStopology(0/15625 architectures, fast_mode=True, file=None)\n",
      "Check NATStopology(0/15625 architectures, fast_mode=True, file=None)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 15625/15625 [20:06<00:00, 12.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Check NATStopology(15625/15625 architectures, fast_mode=True, file=None) completed -- 62500 arch-dataset pairs.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Create the API for topology search space\n",
    "api = create(None, 'tss', fast_mode=True, verbose=False)\n",
    "print(f'There are {len(api)} architectures in the topology search space -- {api}')\n",
    "check_flops_params(api)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # This code block is to figure out the real reason of issue#16\n",
    "# from xautodl.models import get_cell_based_tiny_net\n",
    "# from xautodl.utils import count_parameters_in_MB\n",
    "# from xautodl.utils import get_model_infos\n",
    "\n",
    "# api = create(None, 'tss', fast_mode=True, verbose=False)\n",
    "# print(api)\n",
    "\n",
    "# index, dataset = 296, \"cifar10\"\n",
    "# arch = \"|skip_connect~0|+|none~0|nor_conv_3x3~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|\"\n",
    "# index = api.query_index_by_arch(arch)\n",
    "\n",
    "\n",
    "# info_12 = api.get_cost_info(index, dataset, hp=\"12\")\n",
    "# info_full = api.get_cost_info(index, dataset, hp=api.full_epochs_in_paper)\n",
    "# print(info_12)\n",
    "# print(info_full)\n",
    "\n",
    "# config_12 = api.get_net_config(index, dataset)\n",
    "# print(config_12)\n",
    "# config_full = api.get_net_config(index, dataset)\n",
    "# print(config_full)\n",
    "\n",
    "# # create the network, which is the sub-class of torch.nn.Module\n",
    "# network = get_cell_based_tiny_net(config_full)\n",
    "\n",
    "# flop, param = get_model_infos(network, (1, 3, 32, 32))\n",
    "# print(f\"FLOPs={flop}, param={param}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# results = api.query_meta_info_by_index(index, hp=api.full_epochs_in_paper)\n",
    "# print(results.all_results.keys())\n",
    "# print(\"\")\n",
    "# print(results.dataset_seed[dataset])\n",
    "# print(results.get_compute_costs(dataset))\n",
    "# print(\"\")\n",
    "# print(results.all_results[(dataset, 777)].flop)\n",
    "# print(results.all_results[(dataset, 888)].flop)\n",
    "# print(results.all_results[(dataset, 999)].flop)\n",
    "# print(results.all_results[('cifar100', 777)])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/issue-7.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NATS-Bench version: v1.3\n",
      "[2021-04-08 03:54:14] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n",
      "[2021-04-08 03:54:14] Create NATS-Bench (topology) done with 0/15625 architectures avaliable.\n"
     ]
    }
   ],
   "source": [
    "#\n",
    "# Aims to solve the issue mentioned in https://github.com/D-X-Y/NATS-Bench/issues/7\n",
    "#\n",
    "import nats_bench\n",
    "\n",
    "print('NATS-Bench version: {:}'.format(nats_bench.version()))\n",
    "# Create the API for tologoy search space\n",
    "api = nats_bench.create(None, 'tss', fast_mode=True, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the best architecture on CIFAR-10 validation set\n",
    "api.verbose = False\n",
    "best_arch_index, highest_valid_accuracy = api.find_best(dataset='cifar10-valid', metric_on_set='x-valid', hp='12')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-04-08 04:48:19] The best architecture on CIFAR-10 validation set with 12-epoch training is: [13714] |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_1x1~1|nor_conv_3x3~2|\n"
     ]
    }
   ],
   "source": [
    "print('{:} The best architecture on CIFAR-10 validation set with 12-epoch training is: [{:}] {:}'.format(\n",
    "    nats_bench.api_utils.time_string(), best_arch_index, api.arch(best_arch_index)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: notebooks/random-search.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:24:45] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "from nats_bench import create\n",
    "from pprint import pprint\n",
    "# Create the API for tologoy search space\n",
    "api_tss = create(None, 'tss', fast_mode=True, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_search(api, trials=20, dataset='ImageNet16-120'):\n",
    "    api.reset_time()\n",
    "    highest_accuracy, best_arch = -1, -1\n",
    "    for i in range(trials):\n",
    "        arch_index = random.randint(0, len(api)-1)\n",
    "        accuracy, _, _, total_cost = api.simulate_train_eval(\n",
    "            arch_index, dataset, hp=\"12\"\n",
    "        )\n",
    "        if accuracy > highest_accuracy:\n",
    "            highest_accuracy = accuracy\n",
    "            best_arch = arch_index\n",
    "    return arch_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random Search on ImageNet16-120 : 33.9463 $\\pm$ 9.4213\n"
     ]
    }
   ],
   "source": [
    "# Just a small example, not the full experiment in the paper\n",
    "dataset = 'ImageNet16-120'\n",
    "rs_times, accuracies = 100, []\n",
    "for i in range(rs_times):\n",
    "    arch_index = random_search(api_tss, dataset=dataset)\n",
    "    info = api_tss.get_more_info(arch_index, dataset, hp='200', is_random=False)\n",
    "    accuracies.append(info['test-accuracy'])\n",
    "print('Random Search on {:} : {:.4f} $\\pm$ {:.4f}'.format(dataset, np.mean(accuracies), np.std(accuracies)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: setup.py
================================================
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
#####################################################
"""The setup function for pypi."""
# The following is to make nats_bench avaliable on Python Package Index (PyPI)
#
# conda install -c conda-forge twine  # Use twine to upload nats_bench to pypi
#
# python setup.py sdist bdist_wheel
# python setup.py --help-commands
# twine check dist/*
#
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
# twine upload dist/*
# https://pypi.org/project/nats-bench
#
# NOTE(xuanyidong):
# local install = `pip install . --force`
# local test = `pytest . -s`
#
# TODO(xuanyidong): upload it to conda
#
import os
from setuptools import setup
from nats_bench import version

NAME = "nats_bench"
REQUIRES_PYTHON = ">=3.6"
DESCRIPTION = "API for NATS-Bench (a dataset/benchmark for neural architecture topology and size)."

VERSION = version()


def read(fname="README.md"):
    with open(
        os.path.join(os.path.dirname(__file__), fname), encoding="utf-8"
    ) as cfile:
        return cfile.read()


# What packages are required for this module to be executed?
REQUIRED = ["numpy>=1.16.5"]

setup(
    name=NAME,
    version=VERSION,
    author="Xuanyi Dong",
    author_email="dongxuanyi888@gmail.com",
    description=DESCRIPTION,
    license="MIT Licence",
    keywords="NAS Dataset API DeepLearning",
    url="https://github.com/D-X-Y/NATS-Bench",
    packages=["nats_bench"],
    install_requires=REQUIRED,
    python_requires=REQUIRES_PYTHON,
    long_description=read("README.md"),
    long_description_content_type="text/markdown",
    classifiers=[
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Topic :: Database",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
    ],
)


================================================
FILE: tests/api_test.py
================================================
###############################################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                                           #
###############################################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size, IEEE TPAMI 2021 #
###############################################################################################
# pytest --capture=tee-sys                                                                    #
###############################################################################################
"""This file is used to quickly test the API."""
import os
import pytest
import random

from nats_bench.genotype_utils import topology_str2structure
from nats_bench.api_size import NATSsize
from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names
from nats_bench.api_topology import NATStopology
from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names


def get_fake_torch_home_dir():
    print("This file is {:}".format(os.path.abspath(__file__)))
    print("The current directory is {:}".format(os.path.abspath(os.getcwd())))
    xname = "FAKE_TORCH_HOME"
    if xname in os.environ:
        return os.environ["FAKE_TORCH_HOME"]
    else:
        return os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "..", "fake_torch_dir"
        )


def close_to(a, b, eps=1e-4):
    if b != 0 and abs(a - b) / abs(b) > eps:
        return False
    if a != 0 and abs(a - b) / abs(a) > eps:
        return False
    return True


class TestNATSBench(object):
    """A class to test different functions of NATS-Bench API."""

    def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True):
        if benchmark_dir is None:
            benchmark_dir = os.path.join(
                get_fake_torch_home_dir(), tss_base_names[-1] + "-simple"
            )
        return _test_nats_bench(benchmark_dir, True, fake_random)

    def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True):
        if benchmark_dir is None:
            benchmark_dir = os.path.join(
                get_fake_torch_home_dir(), sss_base_names[-1] + "-simple"
            )
        return _test_nats_bench(benchmark_dir, False, fake_random)

    def prepare_fake_tss(self):
        tss_benchmark_dir = os.path.join(
            get_fake_torch_home_dir(), tss_base_names[-1] + "-simple"
        )
        api = NATStopology(tss_benchmark_dir, True, False)
        return api

    def prepare_fake_sss(self):
        sss_benchmark_dir = os.path.join(
            get_fake_torch_home_dir(), sss_base_names[-1] + "-simple"
        )
        api = NATSsize(sss_benchmark_dir, True, False)
        return api

    def test_01_th_issue(self):
        # Link: https://github.com/D-X-Y/NATS-Bench/issues/1
        api = self.prepare_fake_tss()
        # The performance of 0-th architecture on CIFAR-10 (trained by 12 epochs)
        info = api.get_more_info(0, "cifar10", hp=12)
        # First of all, the data split in NATS-Bench is different from that in the official CIFAR paper.
        # In NATS-Bench, we split the original CIFAR-10 training set into two parts, i.e., a training set and a validation set.
        # 
Download .txt
gitextract_fmzpgwc0/

├── .github/
│   ├── CODE-OF-CONDUCT.md
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug-report.md
│   │   └── question.md
│   └── workflows/
│       └── ci.yml
├── .gitignore
├── LICENSE.md
├── README.md
├── fake_torch_dir/
│   ├── NATS-sss-v1_0-50262-simple/
│   │   ├── 000000.pickle.pbz2
│   │   ├── 000011.pickle.pbz2
│   │   ├── 000284.pickle.pbz2
│   │   └── meta.pickle.pbz2
│   └── NATS-tss-v1_0-3ffb9-simple/
│       ├── 000000.pickle.pbz2
│       ├── 000011.pickle.pbz2
│       ├── 000284.pickle.pbz2
│       └── meta.pickle.pbz2
├── nats_bench/
│   ├── __init__.py
│   ├── api_size.py
│   ├── api_topology.py
│   ├── api_utils.py
│   └── genotype_utils.py
├── notebooks/
│   ├── README.md
│   ├── create-query-sss.ipynb
│   ├── find-largest.ipynb
│   ├── issue-11.ipynb
│   ├── issue-12.ipynb
│   ├── issue-21.ipynb
│   ├── issue-27.ipynb
│   ├── issue-30.ipynb
│   ├── issue-33.ipynb
│   ├── issue-36.ipynb
│   ├── issue-7.ipynb
│   └── random-search.ipynb
├── setup.py
└── tests/
    └── api_test.py
Download .txt
SYMBOL INDEX (142 symbols across 7 files)

FILE: nats_bench/__init__.py
  function version (line 32) | def version():
  function create (line 36) | def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
  function search_space_info (line 61) | def search_space_info(main_tag: Text, aux_tag: Optional[Text]):

FILE: nats_bench/api_size.py
  function print_information (line 31) | def print_information(information, extra_info=None, show=False):
  class NATSsize (line 85) | class NATSsize(NASBenchMetaAPI):
    method __init__ (line 88) | def __init__(
    method is_size (line 202) | def is_size(self):
    method is_topology (line 206) | def is_topology(self):
    method full_epochs_in_paper (line 210) | def full_epochs_in_paper(self):
    method query_info_str_by_arch (line 213) | def query_info_str_by_arch(self, arch, hp: Text = "12"):
    method get_more_info (line 232) | def get_more_info(
    method show (line 354) | def show(self, index: int = -1) -> None:

FILE: nats_bench/api_topology.py
  function print_information (line 33) | def print_information(information, extra_info=None, show=False):
  class NATStopology (line 85) | class NATStopology(NASBenchMetaAPI):
    method __init__ (line 88) | def __init__(
    method is_size (line 202) | def is_size(self):
    method is_topology (line 206) | def is_topology(self):
    method full_epochs_in_paper (line 210) | def full_epochs_in_paper(self):
    method get_unique_str (line 213) | def get_unique_str(self, arch):
    method query_info_str_by_arch (line 229) | def query_info_str_by_arch(self, arch, hp: Text = "12"):
    method get_more_info (line 248) | def get_more_info(
    method show (line 362) | def show(self, index: int = -1) -> None:
    method str2lists (line 367) | def str2lists(arch_str: Text) -> List[Any]:
    method str2matrix (line 402) | def str2matrix(

FILE: nats_bench/api_utils.py
  function mean (line 30) | def mean(xlist):
  function time_string (line 34) | def time_string():
  function reset_file_system (line 40) | def reset_file_system(lib: Text = "default"):
  function get_file_system (line 45) | def get_file_system():
  function get_torch_home (line 49) | def get_torch_home():
  function nats_is_dir (line 62) | def nats_is_dir(file_path):
  function nats_is_file (line 73) | def nats_is_file(file_path):
  function pickle_save (line 84) | def pickle_save(obj, file_path, ext=".pbz2", protocol=4):
  function pickle_load (line 107) | def pickle_load(file_path, ext=".pbz2"):
  function remap_dataset_set_names (line 127) | def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
  class NASBenchMetaAPI (line 155) | class NASBenchMetaAPI(metaclass=abc.ABCMeta):
    method __init__ (line 159) | def __init__(
    method __getitem__ (line 177) | def __getitem__(self, index: int):
    method arch (line 180) | def arch(self, index: int):
    method __len__ (line 190) | def __len__(self):
    method __repr__ (line 193) | def __repr__(self):
    method avaliable_hps (line 206) | def avaliable_hps(self):
    method used_time (line 210) | def used_time(self):
    method search_space_name (line 214) | def search_space_name(self):
    method fast_mode (line 218) | def fast_mode(self):
    method archive_dir (line 222) | def archive_dir(self):
    method full_train_epochs (line 226) | def full_train_epochs(self):
    method reset_archive_dir (line 229) | def reset_archive_dir(self, archive_dir):
    method reset_fast_mode (line 232) | def reset_fast_mode(self, fast_mode):
    method reset_time (line 235) | def reset_time(self):
    method get_more_info (line 239) | def get_more_info(
    method simulate_train_eval (line 244) | def simulate_train_eval(
    method random (line 278) | def random(self):
    method reload (line 282) | def reload(self, archive_root: Text = None, index: int = None):
    method query_index_by_arch (line 339) | def query_index_by_arch(self, arch):
    method query_by_arch (line 387) | def query_by_arch(self, arch, hp):
    method _prepare_info (line 391) | def _prepare_info(self, index):
    method clear_params (line 419) | def clear_params(self, index: int, hp: Optional[Text] = None):
    method query_info_str_by_arch (line 455) | def query_info_str_by_arch(self, arch, hp: Text = "12"):
    method _query_info_str_by_arch (line 458) | def _query_info_str_by_arch(self, arch, hp: Text = "12", print_informa...
    method query_meta_info_by_index (line 480) | def query_meta_info_by_index(self, arch_index, hp: Text = "12"):
    method query_by_index (line 504) | def query_by_index(
    method find_best (line 551) | def find_best(
    method get_net_param (line 616) | def get_net_param(self, index, dataset, seed: Optional[int], hp: Text ...
    method get_net_config (line 644) | def get_net_config(self, index: int, dataset: Text):
    method get_cost_info (line 677) | def get_cost_info(
    method get_latency (line 690) | def get_latency(self, index: int, dataset: Text, hp: Text = "12") -> f...
    method show (line 712) | def show(self, index=-1):
    method _show (line 715) | def _show(self, index=-1, print_information=None) -> None:
    method statistics (line 772) | def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int,...
  class ArchResults (line 795) | class ArchResults(object):
    method __init__ (line 798) | def __init__(self, arch_index, arch_str):
    method get_compute_costs (line 805) | def get_compute_costs(self, dataset):
    method get_metrics (line 845) | def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
    method get_dataset_names (line 921) | def get_dataset_names(self):
    method get_dataset_seeds (line 924) | def get_dataset_seeds(self, dataset):
    method get_net_param (line 927) | def get_net_param(self, dataset: Text, seed: Union[None, int] = None):
    method reset_latency (line 953) | def reset_latency(
    method reset_pseudo_train_times (line 963) | def reset_pseudo_train_times(
    method reset_pseudo_eval_times (line 977) | def reset_pseudo_eval_times(
    method get_latency (line 995) | def get_latency(self, dataset: Text) -> float:
    method get_total_epoch (line 1009) | def get_total_epoch(self, dataset=None):
    method query (line 1033) | def query(self, dataset, seed=None):
    method arch_idx_str (line 1041) | def arch_idx_str(self):
    method update (line 1044) | def update(self, dataset_name, seed, result):
    method state_dict (line 1060) | def state_dict(self):
    method load_state_dict (line 1083) | def load_state_dict(self, state_dict):
    method create_from_state_dict (line 1101) | def create_from_state_dict(state_dict_or_file):
    method clear_params (line 1117) | def clear_params(self):
    method debug_test (line 1125) | def debug_test(self):
    method __repr__ (line 1138) | def __repr__(self):
  class ResultsCount (line 1151) | class ResultsCount(object):
    method __init__ (line 1154) | def __init__(
    method update_train_info (line 1182) | def update_train_info(
    method reset_pseudo_train_times (line 1190) | def reset_pseudo_train_times(self, estimated_per_epoch_time: float) ->...
    method reset_pseudo_eval_times (line 1197) | def reset_pseudo_eval_times(
    method reset_eval (line 1206) | def reset_eval(self):
    method update_latency (line 1212) | def update_latency(self, latency):
    method get_latency (line 1215) | def get_latency(self) -> float:
    method update_eval (line 1224) | def update_eval(self, accs, losses, times):
    method update_OLD_eval (line 1239) | def update_OLD_eval(self, name, accs, losses):  # pylint: disable=inva...
    method __repr__ (line 1248) | def __repr__(self):
    method get_total_epoch (line 1266) | def get_total_epoch(self):
    method get_times (line 1269) | def get_times(self):
    method get_eval_set (line 1292) | def get_eval_set(self):
    method judge_valid (line 1295) | def judge_valid(self, iepoch: Optional[int]):
    method get_train (line 1299) | def get_train(self, iepoch: Optional[int] = None):
    method get_eval (line 1317) | def get_eval(self, name, iepoch: Optional[int] = None):
    method get_net_param (line 1353) | def get_net_param(self, clone: bool = False):
    method get_config (line 1359) | def get_config(self, str2structure):
    method state_dict (line 1401) | def state_dict(self):
    method load_state_dict (line 1405) | def load_state_dict(self, state_dict):
    method create_from_state_dict (line 1409) | def create_from_state_dict(state_dict):

FILE: nats_bench/genotype_utils.py
  function topology_str2structure (line 9) | def topology_str2structure(xstr):
  class TopologyStructure (line 13) | class TopologyStructure:
    method __init__ (line 16) | def __init__(self, genotype):
    method tolist (line 38) | def tolist(self, remove_str):
    method node (line 52) | def node(self, index):
    method tostr (line 58) | def tostr(self):
    method check_valid (line 66) | def check_valid(self):
    method to_unique_str (line 79) | def to_unique_str(self, consider_zero=False):
    method check_valid_op (line 104) | def check_valid_op(self, op_names):
    method __repr__ (line 112) | def __repr__(self):
    method __len__ (line 117) | def __len__(self):
    method __getitem__ (line 120) | def __getitem__(self, index):
    method str2structure (line 124) | def str2structure(xstr):
    method str2fullstructure (line 144) | def str2fullstructure(xstr, default_name="none"):
    method gen_all (line 167) | def gen_all(search_space, num, return_ori):

FILE: setup.py
  function read (line 34) | def read(fname="README.md"):

FILE: tests/api_test.py
  function get_fake_torch_home_dir (line 20) | def get_fake_torch_home_dir():
  function close_to (line 32) | def close_to(a, b, eps=1e-4):
  class TestNATSBench (line 40) | class TestNATSBench(object):
    method test_nats_bench_tss (line 43) | def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True):
    method test_nats_bench_sss (line 50) | def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True):
    method prepare_fake_tss (line 57) | def prepare_fake_tss(self):
    method prepare_fake_sss (line 64) | def prepare_fake_sss(self):
    method test_01_th_issue (line 71) | def test_01_th_issue(self):
    method test_02_th_issue (line 119) | def test_02_th_issue(self):
    method test_07_th_issue (line 142) | def test_07_th_issue(self):
    method test_12_th_issue (line 158) | def test_12_th_issue(self):
    method test_36_th_issue (line 170) | def test_36_th_issue(self):
    method test_44_th_issue (line 193) | def test_44_th_issue(self):
  function _test_nats_bench (line 201) | def _test_nats_bench(benchmark_dir, is_tss, fake_random, hp="12", verbos...
Condensed preview — 34 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (216K chars).
[
  {
    "path": ".github/CODE-OF-CONDUCT.md",
    "chars": 3355,
    "preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, w"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.md",
    "chars": 513,
    "preview": "---\nname: Bug Report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the b"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/question.md",
    "chars": 370,
    "preview": "---\nname: Questions about NATS-Bench\nabout: Ask questions about or discuss on NATS-Bench\ntitle: ''\nlabels: ''\nassignees:"
  },
  {
    "path": ".github/workflows/ci.yml",
    "chars": 1690,
    "preview": "name: Run Python Tests\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\n\njobs:\n  build:"
  },
  {
    "path": ".gitignore",
    "chars": 1162,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE.md",
    "chars": 1109,
    "preview": "MIT License\n\nCopyright (c) since 2020 Xuanyi Dong (GitHub: https://github.com/D-X-Y)\n\nPermission is hereby granted, free"
  },
  {
    "path": "README.md",
    "chars": 21098,
    "preview": "# [NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size](https://arxiv.org/abs/2009.00437)\n\nXuanyi"
  },
  {
    "path": "nats_bench/__init__.py",
    "chars": 3448,
    "preview": "##############################################################################\n# Copyright (c) Xuanyi Dong [GitHub D-X-Y"
  },
  {
    "path": "nats_bench/api_size.py",
    "chars": 16420,
    "preview": "#####################################################\n# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #\n############"
  },
  {
    "path": "nats_bench/api_topology.py",
    "chars": 20399,
    "preview": "#####################################################\n# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #\n############"
  },
  {
    "path": "nats_bench/api_utils.py",
    "chars": 57006,
    "preview": "#####################################################\n# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #\n############"
  },
  {
    "path": "nats_bench/genotype_utils.py",
    "chars": 7681,
    "preview": "#####################################################\n# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.09 #\n############"
  },
  {
    "path": "notebooks/README.md",
    "chars": 950,
    "preview": "# Notebooks for NATS-Bench\n\nWe provide some examples on how to use NATS-Bench in the following notebooks:\n\n- [`create-qu"
  },
  {
    "path": "notebooks/create-query-sss.ipynb",
    "chars": 7842,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/find-largest.ipynb",
    "chars": 7164,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/issue-11.ipynb",
    "chars": 1841,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/issue-12.ipynb",
    "chars": 1825,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/issue-21.ipynb",
    "chars": 12565,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/issue-27.ipynb",
    "chars": 4429,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/issue-30.ipynb",
    "chars": 7027,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/issue-33.ipynb",
    "chars": 3005,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/issue-36.ipynb",
    "chars": 5826,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n "
  },
  {
    "path": "notebooks/issue-7.ipynb",
    "chars": 2206,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "notebooks/random-search.ipynb",
    "chars": 2460,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\":"
  },
  {
    "path": "setup.py",
    "chars": 1925,
    "preview": "#####################################################\n# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #\n############"
  },
  {
    "path": "tests/api_test.py",
    "chars": 10530,
    "preview": "###############################################################################################\n# Copyright (c) Xuanyi D"
  }
]

// ... and 8 more files (download for full content)

About this extraction

This page contains the full source code of the D-X-Y/NATS-Bench GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 34 files (199.1 KB), approximately 55.3k tokens, and a symbol index with 142 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!