Repository: thu-ml/tianshou Branch: master Commit: 1bbe05b3365f Files: 303 Total size: 4.9 MB Directory structure: gitextract_z9q4ijrb/ ├── .devcontainer/ │ └── devcontainer.json ├── .dockerignore ├── .github/ │ ├── ISSUE_TEMPLATE.md │ ├── PULL_REQUEST_TEMPLATE.md │ └── workflows/ │ ├── extra_sys.yml │ ├── gputest.yml │ ├── lint_and_docs.yml │ ├── publish.yaml │ └── pytest.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmark/ │ └── run_benchmark.py ├── docs/ │ ├── .gitignore │ ├── 01_user_guide/ │ │ ├── 00_training_process.md │ │ ├── 01_apis.md │ │ ├── 02_core_abstractions.md │ │ └── index.rst │ ├── 02_deep_dives/ │ │ ├── 0_intro.md │ │ ├── L1_Batch.ipynb │ │ ├── L2_Buffer.ipynb │ │ ├── L3_Environments.ipynb │ │ ├── L4_GAE.ipynb │ │ ├── L5_Collector.ipynb │ │ └── L6_MARL.ipynb │ ├── 04_benchmarks/ │ │ └── benchmarks.rst │ ├── 05_developer_guide/ │ │ └── developer_guide.md │ ├── 06_contributors/ │ │ └── contributors.rst │ ├── _config.yml │ ├── _static/ │ │ ├── css/ │ │ │ └── style.css │ │ └── js/ │ │ ├── benchmark.js │ │ ├── copybutton.js │ │ ├── mujoco/ │ │ │ └── benchmark/ │ │ │ └── Ant-v4/ │ │ │ └── results.json │ │ ├── v5.json │ │ ├── vega-embed@5.js │ │ ├── vega-lite@5.js │ │ └── vega@5.js │ ├── autogen_rst.py │ ├── bibtex.json │ ├── create_toc.py │ ├── index.rst │ ├── nbstripout.py │ └── refs.bib ├── examples/ │ ├── __init__.py │ ├── atari/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── atari_c51.py │ │ ├── atari_dqn.py │ │ ├── atari_dqn_hl.py │ │ ├── atari_fqf.py │ │ ├── atari_iqn.py │ │ ├── atari_iqn_hl.py │ │ ├── atari_ppo.py │ │ ├── atari_ppo_hl.py │ │ ├── atari_qrdqn.py │ │ ├── atari_rainbow.py │ │ ├── atari_sac.py │ │ └── atari_sac_hl.py │ ├── box2d/ │ │ ├── README.md │ │ ├── acrobot_dualdqn.py │ │ ├── bipedal_bdq.py │ │ ├── bipedal_hardcore_sac.py │ │ ├── lunarlander_dqn.py │ │ └── mcc_sac.py │ ├── discrete/ │ │ ├── discrete_dqn.py 
│ │ └── discrete_dqn_hl.py │ ├── inverse/ │ │ ├── README.md │ │ └── irl_gail.py │ ├── modelbased/ │ │ └── README.md │ ├── mujoco/ │ │ ├── README.md │ │ ├── analysis.py │ │ ├── fetch_her_ddpg.py │ │ ├── mujoco_a2c.py │ │ ├── mujoco_a2c_hl.py │ │ ├── mujoco_ddpg.py │ │ ├── mujoco_ddpg_hl.py │ │ ├── mujoco_env.py │ │ ├── mujoco_npg.py │ │ ├── mujoco_npg_hl.py │ │ ├── mujoco_ppo.py │ │ ├── mujoco_ppo_hl.py │ │ ├── mujoco_redq.py │ │ ├── mujoco_redq_hl.py │ │ ├── mujoco_reinforce.py │ │ ├── mujoco_reinforce_hl.py │ │ ├── mujoco_sac.py │ │ ├── mujoco_sac_hl.py │ │ ├── mujoco_td3.py │ │ ├── mujoco_td3_hl.py │ │ ├── mujoco_trpo.py │ │ ├── mujoco_trpo_hl.py │ │ ├── plotter.py │ │ └── tools.py │ ├── offline/ │ │ ├── README.md │ │ ├── atari_bcq.py │ │ ├── atari_cql.py │ │ ├── atari_crr.py │ │ ├── atari_il.py │ │ ├── convert_rl_unplugged_atari.py │ │ ├── d4rl_bcq.py │ │ ├── d4rl_cql.py │ │ ├── d4rl_il.py │ │ ├── d4rl_td3_bc.py │ │ └── utils.py │ └── vizdoom/ │ ├── .gitignore │ ├── README.md │ ├── env.py │ ├── maps/ │ │ ├── D1_basic.cfg │ │ ├── D1_basic.wad │ │ ├── D2_navigation.cfg │ │ ├── D2_navigation.wad │ │ ├── D3_battle.cfg │ │ ├── D3_battle.wad │ │ ├── D4_battle2.cfg │ │ ├── D4_battle2.wad │ │ ├── README.md │ │ └── spectator.py │ ├── replay.py │ ├── vizdoom_c51.py │ └── vizdoom_ppo.py ├── pyproject.toml ├── test/ │ ├── __init__.py │ ├── base/ │ │ ├── __init__.py │ │ ├── env.py │ │ ├── test_action_space_sampling.py │ │ ├── test_batch.py │ │ ├── test_buffer.py │ │ ├── test_collector.py │ │ ├── test_env.py │ │ ├── test_env_finite.py │ │ ├── test_logger.py │ │ ├── test_policy.py │ │ ├── test_returns.py │ │ ├── test_stats.py │ │ └── test_utils.py │ ├── continuous/ │ │ ├── __init__.py │ │ ├── test_ddpg.py │ │ ├── test_npg.py │ │ ├── test_ppo.py │ │ ├── test_redq.py │ │ ├── test_sac_with_il.py │ │ ├── test_td3.py │ │ └── test_trpo.py │ ├── determinism_test.py │ ├── discrete/ │ │ ├── __init__.py │ │ ├── test_a2c_with_il.py │ │ ├── test_bdqn.py │ │ ├── test_c51.py │ │ ├── 
test_discrete_sac.py │ │ ├── test_dqn.py │ │ ├── test_drqn.py │ │ ├── test_fqf.py │ │ ├── test_iqn.py │ │ ├── test_ppo_discrete.py │ │ ├── test_qrdqn.py │ │ ├── test_rainbow.py │ │ └── test_reinforce.py │ ├── highlevel/ │ │ ├── __init__.py │ │ ├── env_factory.py │ │ └── test_experiment_builder.py │ ├── modelbased/ │ │ ├── __init__.py │ │ ├── test_dqn_icm.py │ │ ├── test_ppo_icm.py │ │ └── test_psrl.py │ ├── offline/ │ │ ├── __init__.py │ │ ├── gather_cartpole_data.py │ │ ├── gather_pendulum_data.py │ │ ├── test_bcq.py │ │ ├── test_cql.py │ │ ├── test_discrete_bcq.py │ │ ├── test_discrete_cql.py │ │ ├── test_discrete_crr.py │ │ ├── test_gail.py │ │ └── test_td3_bc.py │ └── pettingzoo/ │ ├── pistonball.py │ ├── pistonball_continuous.py │ ├── test_pistonball.py │ ├── test_pistonball_continuous.py │ ├── test_tic_tac_toe.py │ └── tic_tac_toe.py └── tianshou/ ├── __init__.py ├── algorithm/ │ ├── __init__.py │ ├── algorithm_base.py │ ├── imitation/ │ │ ├── __init__.py │ │ ├── bcq.py │ │ ├── cql.py │ │ ├── discrete_bcq.py │ │ ├── discrete_cql.py │ │ ├── discrete_crr.py │ │ ├── gail.py │ │ ├── imitation_base.py │ │ └── td3_bc.py │ ├── modelbased/ │ │ ├── __init__.py │ │ ├── icm.py │ │ └── psrl.py │ ├── modelfree/ │ │ ├── __init__.py │ │ ├── a2c.py │ │ ├── bdqn.py │ │ ├── c51.py │ │ ├── ddpg.py │ │ ├── discrete_sac.py │ │ ├── dqn.py │ │ ├── fqf.py │ │ ├── iqn.py │ │ ├── npg.py │ │ ├── ppo.py │ │ ├── qrdqn.py │ │ ├── rainbow.py │ │ ├── redq.py │ │ ├── reinforce.py │ │ ├── sac.py │ │ ├── td3.py │ │ └── trpo.py │ ├── multiagent/ │ │ ├── __init__.py │ │ └── marl.py │ ├── optim.py │ └── random.py ├── config.py ├── data/ │ ├── __init__.py │ ├── batch.py │ ├── buffer/ │ │ ├── __init__.py │ │ ├── buffer_base.py │ │ ├── cached.py │ │ ├── her.py │ │ ├── manager.py │ │ ├── prio.py │ │ └── vecbuf.py │ ├── collector.py │ ├── stats.py │ ├── types.py │ └── utils/ │ ├── __init__.py │ ├── converter.py │ └── segtree.py ├── env/ │ ├── __init__.py │ ├── atari/ │ │ ├── atari_network.py │ │ └── 
atari_wrapper.py │ ├── gym_wrappers.py │ ├── pettingzoo_env.py │ ├── utils.py │ ├── venv_wrappers.py │ ├── venvs.py │ └── worker/ │ ├── __init__.py │ ├── dummy.py │ ├── ray.py │ ├── subproc.py │ └── worker_base.py ├── evaluation/ │ ├── __init__.py │ ├── launcher.py │ └── rliable_evaluation.py ├── exploration/ │ ├── __init__.py │ └── random.py ├── highlevel/ │ ├── __init__.py │ ├── algorithm.py │ ├── config.py │ ├── env.py │ ├── experiment.py │ ├── logger.py │ ├── module/ │ │ ├── __init__.py │ │ ├── actor.py │ │ ├── core.py │ │ ├── critic.py │ │ ├── intermediate.py │ │ └── special.py │ ├── params/ │ │ ├── __init__.py │ │ ├── algorithm_params.py │ │ ├── algorithm_wrapper.py │ │ ├── alpha.py │ │ ├── collector.py │ │ ├── dist_fn.py │ │ ├── env_param.py │ │ ├── lr_scheduler.py │ │ ├── noise.py │ │ └── optim.py │ ├── persistence.py │ ├── trainer.py │ └── world.py ├── py.typed ├── trainer.py └── utils/ ├── __init__.py ├── conversion.py ├── determinism.py ├── lagged_network.py ├── logger/ │ ├── __init__.py │ ├── logger_base.py │ ├── tensorboard.py │ └── wandb.py ├── logging.py ├── net/ │ ├── __init__.py │ ├── common.py │ ├── continuous.py │ └── discrete.py ├── print.py ├── progress_bar.py ├── space_info.py ├── statistics.py ├── torch_utils.py └── warning.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .devcontainer/devcontainer.json ================================================ { "name": "Tianshou", "dockerFile": "../Dockerfile", "workspaceFolder": "/workspaces/tianshou", "runArgs": ["--shm-size=1g"], "customizations": { "vscode": { "settings": { "terminal.integrated.shell.linux": "/bin/bash", "python.pythonPath": "/usr/local/bin/python" }, "extensions": [ "ms-python.python", "ms-toolsai.jupyter", "ms-python.vscode-pylance" ] } }, "forwardPorts": [], "postCreateCommand": "poetry install --with dev", "remoteUser": "root" } 
================================================ FILE: .dockerignore ================================================ data logs test/log docs/jupyter_execute docs/.jupyter_cache .lsp .clj-kondo docs/_build coverage* __pycache__ *.egg-info *.egg .*cache dist ================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ - [ ] I have marked all applicable categories: + [ ] exception-raising bug + [ ] RL algorithm bug + [ ] documentation request (i.e. "X is missing from the documentation.") + [ ] new feature request + [ ] design request (i.e. "X should be changed to Y.") - [ ] I have visited the [source website](https://github.com/thu-ml/tianshou/) - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates - [ ] I have mentioned version numbers, operating system and environment, where applicable: ```python import tianshou, gymnasium as gym, torch, numpy, sys print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform) ``` ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ - [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s) - [ ] I have provided a description of the changes in this Pull Request - [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md - [ ] If applicable, I have added tests to cover my changes. - [ ] If applicable, I have made sure that the determinism tests run through, meaning that my changes haven't influenced any aspect of training. See info in the contributing documentation. 
- [ ] I have reformatted the code using `poe format` - [ ] I have checked style and types with `poe lint` and `poe type-check` - [ ] (Optional) I ran tests locally with `poe test` (or a subset of them with `poe test-reduced`), and they pass - [ ] (Optional) I have tested that documentation builds correctly with `poe doc-build` ================================================ FILE: .github/workflows/extra_sys.yml ================================================ name: Windows/MacOS on: pull_request: branches: - master push: branches: - master workflow_dispatch: inputs: debug_enabled: type: boolean description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)' required: false default: false jobs: cpu-extra: runs-on: ${{ matrix.os }} if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: os: [macos-latest, windows-latest] python-version: [3.11] steps: - name: Setup tmate session uses: mxschmitt/action-tmate@v3 if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }} - name: Cancel previous run uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} # use poetry and cache installed packages, see https://github.com/marketplace/actions/python-poetry-action - name: Install poetry uses: abatilo/actions-poetry@v2 - name: Setup a local virtual environment (if no poetry.toml file) run: | poetry config virtualenvs.create true --local poetry config virtualenvs.in-project true --local - uses: actions/cache@v3 name: Define a cache for the virtual environment based on the dependencies lock file with: path: ./.venv key: venv-${{ hashFiles('poetry.lock') }} - name: Install the project dependencies # ugly as hell, but well... 
# see https://github.com/python-poetry/poetry/issues/7611 run: poetry install --with dev || poetry install --with dev || poetry install --with dev - name: wandb login run: poetry run wandb login e2366d661b89f2bee877c40bee15502d67b7abef - name: Test with pytest run: poetry run poe test-reduced ================================================ FILE: .github/workflows/gputest.yml ================================================ name: Ubuntu GPU on: pull_request: branches: - master push: branches: - master workflow_dispatch: inputs: debug_enabled: type: boolean description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)' required: false default: false jobs: gpu: runs-on: [self-hosted, Linux, X64] if: "!contains(github.event.head_commit.message, 'ci skip')" steps: - name: Setup tmate session uses: mxschmitt/action-tmate@v3 if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }} - name: Cancel previous run uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: "3.11" # use poetry and cache installed packages, see https://github.com/marketplace/actions/python-poetry-action - name: Install poetry uses: abatilo/actions-poetry@v2 - name: Setup a local virtual environment (if no poetry.toml file) run: | poetry config virtualenvs.create true --local poetry config virtualenvs.in-project true --local - uses: actions/cache@v3 name: Define a cache for the virtual environment based on the dependencies lock file with: path: ./.venv key: venv-${{ hashFiles('poetry.lock') }} - name: Install the project dependencies run: | poetry install --with dev --extras "envpool" - name: wandb login run: | poetry run wandb login e2366d661b89f2bee877c40bee15502d67b7abef - name: Test with pytest run: | poetry run poe test ================================================ FILE: 
.github/workflows/lint_and_docs.yml ================================================ name: Check Formatting/Typing and Build Docs on: pull_request: branches: - master push: branches: - master workflow_dispatch: inputs: debug_enabled: type: boolean description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)' required: false default: false jobs: check: runs-on: ubuntu-latest steps: - name: Setup tmate session uses: mxschmitt/action-tmate@v3 if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }} - name: Cancel previous run uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 # use poetry and cache installed packages, see https://github.com/marketplace/actions/python-poetry-action - name: Install poetry uses: abatilo/actions-poetry@v2 - name: Setup a local virtual environment (if no poetry.toml file) run: | poetry config virtualenvs.create true --local poetry config virtualenvs.in-project true --local - uses: actions/cache@v3 name: Define a cache for the virtual environment based on the dependencies lock file with: path: ./.venv key: venv-${{ hashFiles('poetry.lock') }} - name: Install the project dependencies run: | poetry install --with dev --extras "eval" - name: Check formatting run: poetry run poe lint - name: Check typing run: poetry run poe type-check - name: Build docs run: MYSTNB_DEBUG=1 poetry run poe doc-build - name: Show errors (if any) if: failure() run: find docs/_build/reports -name "*.err.log" -exec echo "--- {} ---" \; -exec cat {} \; ================================================ FILE: .github/workflows/publish.yaml ================================================ name: Upload Python Package on: release: types: [created] jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Set up Python uses: 
actions/setup-python@v1 with: python-version: 3.11 # use poetry and cache installed packages, see https://github.com/marketplace/actions/python-poetry-action - name: Install poetry uses: abatilo/actions-poetry@v2 - name: Setup a local virtual environment (if no poetry.toml file) run: | poetry config virtualenvs.create true --local poetry config virtualenvs.in-project true --local - name: Build and publish env: POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }} run: | if [ -z "${POETRY_PYPI_TOKEN_PYPI}" ]; then echo "Set the PYPI_TOKEN variable in your repository secrets"; exit 1; fi poetry publish --build ================================================ FILE: .github/workflows/pytest.yml ================================================ name: Ubuntu on: pull_request: branches: - master push: branches: - master workflow_dispatch: inputs: debug_enabled: type: boolean description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)' required: false default: false # This job runs the test suite in two environments: # - py_pinned: uses Python 3.11 with the existing poetry.lock file (our stable, pinned dev environment) # - py_latest: latest Python version we want to support, without the lock file to furthermore install the newest dependency versions # # This ensures compatibility with both our controlled dev setup and the latest upstream packages, # helping catch issues introduced by dependency updates or newer Python versions. 
jobs: cpu: runs-on: ubuntu-latest if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: include: - env_name: py_pinned python-version: "3.11" use_lock: true - env_name: py_latest python-version: "3.13" use_lock: false steps: - name: Setup tmate session uses: mxschmitt/action-tmate@v3 if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }} - name: Cancel previous run uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install poetry uses: abatilo/actions-poetry@v2 - name: Setup a local virtual environment (if no poetry.toml file) run: | poetry config virtualenvs.create true --local poetry config virtualenvs.in-project true --local - name: Remove poetry.lock for latest dependency test if: ${{ !matrix.use_lock }} run: rm -f poetry.lock - name: Define a cache for the virtual environment based on the dependencies lock file if: matrix.use_lock uses: actions/cache@v3 with: path: ./.venv key: venv-${{ matrix.env_name }}-${{ hashFiles('poetry.lock') }} restore-keys: | venv-${{ matrix.env_name }}- - name: Install the project dependencies run: | if [ "${{ matrix.env_name }}" = "py_latest" ]; then poetry install --with dev --extras "eval" else poetry install --with dev --extras "envpool eval" fi - name: List installed packages run: | poetry run pip list - name: wandb login run: | poetry run wandb login e2366d661b89f2bee877c40bee15502d67b7abef - name: Test with pytest run: | if [ "${{ matrix.env_name }}" = "py_pinned" ]; then poetry run poe test else poetry run poe test-nocov fi - name: Upload coverage to Codecov if: matrix.env_name == 'py_pinned' uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV }} file: ./coverage.xml flags: ${{ matrix.env_name }} name: codecov-${{ matrix.env_name }} fail_ci_if_error: false 
================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # .idea folder .idea/ # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv venv/ /ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # customize log/ MUJOCO_LOG.TXT *.pth .vscode/ .DS_Store *.zip *.pstats *.swp *.pkl *.hdf5 wandb/ videos/ # might be needed for IDE plugins that can't read ruff config .flake8 docs/notebooks/_build/ docs/conf.py # temporary scripts (for ad-hoc testing), temp folder /temp /temp*.py # Serena /.serena # determinism test snapshots /test/resources/determinism/ ================================================ FILE: .pre-commit-config.yaml ================================================ default_install_hook_types: [commit-msg, pre-commit] default_stages: [commit, manual] fail_fast: false repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - id: check-added-large-files - repo: local hooks: - id: ruff name: ruff entry: poetry run ruff require_serial: true language: system types: [python] - id: ruff-nb name: ruff-nb entry: poetry run nbqa ruff . 
require_serial: true language: system pass_filenames: false types: [python] - id: black name: black entry: poetry run black require_serial: true language: system types: [python] - id: poetry-check name: poetry check entry: poetry check language: system files: pyproject.toml pass_filenames: false - id: poetry-lock-check name: poetry lock check entry: poetry check args: [--lock] language: system pass_filenames: false - id: mypy name: mypy entry: poetry run mypy tianshou examples test # filenames should not be passed as they would collide with the config in pyproject.toml pass_filenames: false files: '^tianshou(/[^/]*)*/[^/]*\.py$' language: system - id: mypy-nb name: mypy-nb entry: poetry run nbqa mypy language: system ================================================ FILE: .readthedocs.yaml ================================================ # .readthedocs.yaml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Set the version of Python and other tools you might need build: os: ubuntu-22.04 tools: python: "3.11" commands: - mkdir -p $READTHEDOCS_OUTPUT/html - curl -sSL https://install.python-poetry.org | python - # - ~/.local/bin/poetry config virtualenvs.create false - ~/.local/bin/poetry install --with dev -E eval ## Same as poe tasks, but unfortunately poe doesn't work with poetry not creating virtualenvs - ~/.local/bin/poetry run python docs/autogen_rst.py - ~/.local/bin/poetry run which jupyter-book - ~/.local/bin/poetry run python docs/create_toc.py - ~/.local/bin/poetry run jupyter-book config sphinx docs/ - ~/.local/bin/poetry run sphinx-build -W -b html docs $READTHEDOCS_OUTPUT/html ================================================ FILE: CHANGELOG.md ================================================ # Release 2.0.0 (2025-12-01) This major release of Tianshou is a big step towards cleaner design and improved usability. 
Given the large extent of the changes, it was not possible to maintain compatibility with the previous version. * Persisted agents that were created with earlier versions cannot be loaded in v2. * Source code from v1 can, however, be migrated to v2 with minimal effort. See migration information below. For concrete examples, you may use git to diff individual example scripts with the corresponding ones in `v1.2.0`. This release is brought to you by [Applied AI Institute gGmbH](https://www.appliedai-institute.de). Developers: * Dr. Dominik Jain (@opcode81) * Michael Panchenko (@MischaPanch) ## Runtime Environment Compatibility Tianshou v2 is now compatible with * Python 3.12 and Python 3.13 #1274 * newer versions of gymnasium (v1+) and numpy (v2+) Our main test environment remains Python 3.11-based for the time being (see `poetry.lock`). ## Trainer Abstraction * The trainer logic and configuration is now properly separated between the three cases of on-policy, off-policy and offline learning: The base class is no longer a "God" class (formerly `BaseTrainer`) which does it all; logic and functionality has moved to the respective subclasses (`OnPolicyTrainer`, `OffPolicyTrainer` and `OfflineTrainer`, with `OnlineTrainer` being introduced as a base class for the two former specialisations). * The trainers now use configuration objects with central documentation (which has been greatly improved to enhance clarity and usability in general); every type of trainer now has a dedicated configuration class which provides precisely the options that are applicable. * The interface has been streamlined with improved naming of functions/parameters and limiting the public interface to purely the methods and attributes a user should reasonably access. * Further changes potentially affecting usage: * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. #913 * We no longer report outdated statistics (e.g. 
on rewards/returns when a training step does not collect any full episodes) * See also "Issues resolved" below (as issue resolution can result in usage changes) * The default value for `test_in_train` was changed from True to False (updating all usage sites to explicitly set the parameter), because False is the more natural default, which does not make assumptions about returns/score values computed for the data from a collection step being at all meaningful for early stopping * The management of epsilon-greedy exploration for discrete Q-learning algorithms has been simplified: * All respective Policy implementations (e.g. `DQNPolicy`, `C51Policy`, etc.) now accept two parameters `eps_training` and `eps_inference`, which allows the training and test collection cases to be sufficiently differentiated and makes the use of callback functions (`train_fn`, `test_fn`) unnecessary if only constants are to be set. * The setter method `set_eps` has been replaced with `set_eps_training` and `set_eps_inference` accordingly. * Further internal changes unlikely to affect usage: * Module `trainer.utils` was removed and the functions therein were moved to class `Trainer` * The two places that collected and evaluated test episodes (`_test_in_train` and `_reset`) in addition to `_test_step` were unified to use `_test_step` (with some minor parametrisation) and now log the results of the test step accordingly. * Issues resolved: * Methods `run` and `reset`: Parameter `reset_prior_to_run` of `run` was never respected if it was set to `False`, because the implementation of `__iter__` (now removed) would call `reset` regardless - and calling `reset` is indeed necessary, because it initializes the training. The parameter was removed and replaced by `reset_collectors` (such that `run` now replicates the parameters of `reset`). 
* Inconsistent configuration options now raise exceptions rather than silently ignoring the issue in the hope that default behaviour will achieve what the user intended. One condition where `test_in_train` was silently set to `False` was removed and replaced by a warning. * The stop criterion `stop_fn` did not consider scores as computed by `compute_score_fn` but instead always used mean returns (i.e. it was assumed that the default implementation of `compute_score_fn` applies). This is an inconsistency which has been resolved. * The `gradient_step` counter was flawed (as it made assumptions about the underlying algorithms, which were not valid). It has been replaced with an update step counter. Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed accordingly. * Migration information at a glance: * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: `OnPolicyTrainerParams`, `OffPolicyTrainerParams`, `OfflineTrainerParams`. * Changed parameter default: Default for `test_in_train` was changed from True to False. * Changed parameter names to improve clarity: * `max_epoch` (`num_epochs` in high-level API) -> `max_epochs` * `step_per_epoch` -> `epoch_num_steps` * `episode_per_test` (`num_test_episodes` in high-level API) -> `test_step_num_episodes` * `step_per_collect` -> `collection_step_num_env_steps` * `episode_per_collect` -> `collection_step_num_episodes` * `update_per_step` -> `update_step_num_gradient_steps_per_sample` * `repeat_per_collect` -> `update_step_num_repetitions` * Trainer classes have been renamed: * `OnpolicyTrainer` -> `OnPolicyTrainer` * `OffpolicyTrainer` -> `OffPolicyTrainer` * Method `run`: The parameter `reset_prior_to_run` was removed and replaced by `reset_collectors` (see above). 
* Methods `run` and `reset`: The parameter `reset_buffer` was renamed to `reset_collector_buffers` for clarity. * Trainers are no longer iterators; manual usage (not using `run`) should simply call `reset` followed by calls of `execute_epoch`. ## Algorithms and Policies * We now conceptually differentiate between the learning algorithm and the policy being optimised: * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`, and the package was renamed from `tianshou.policy` to `tianshou.algorithm`. * Migration information: The instantiation of a policy is replaced by the instantiation of an `Algorithm`, which is passed a `Policy`. In most cases, the former policy class name `<Name>Policy` is replaced by algorithm class `<Name>`; exceptions are noted below. * `ImitationPolicy` -> `OffPolicyImitationLearning`, `OfflineImitationLearning` * `PGPolicy` -> `Reinforce` * `MultiAgentPolicyManager` -> `MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm` * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` For the respective subtype of `Policy` to use, see the respective algorithm class' constructor. * Interface changes/improvements: * Core methods have been renamed (and removed from the public interface; #898): * `process_fn` -> `_preprocess_batch` * `post_process_fn` -> `_postprocess_batch` * `learn` -> `_update_with_batch` * The updating interface has been cleaned up (#949): * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated. * New method `run_training`: The `Algorithm` abstraction can now directly initiate the learning process via this method. * `Algorithms` no longer require `torch.optim.Optimizer` instances and instead require `OptimizerFactory` instances, which create the actual optimizers internally. 
#959 The new `OptimizerFactory` abstraction simultaneously handles the creation of learning rate schedulers for the optimizers created (via method `with_lr_scheduler_factory` and accompanying factory abstraction `LRSchedulerFactory`). The parameter `lr_scheduler` has thus been removed from all algorithm constructors. * The flag `updating` has been removed (no internal usage, general usefulness questionable). * Removed `max_action_num`; it is instead read off from `action_space` * Parameter changes: * `actor_step_size` -> `trust_region_size` in NPG * `discount_factor` -> `gamma` (was already used internally almost everywhere) * `reward_normalization` -> `return_standardization` or `return_scaling` (more precise naming) or removed (was actually unsupported by Q-learning algorithms) * `return_standardization` in `Reinforce` and `DiscreteCRR` (as it applies standardization of returns) * `return_scaling` in actor-critic on-policy algorithms (A2C, PPO, GAIL, NPG, TRPO) * removed from Q-learning algorithms, where it was actually unsupported (DQN, C51, etc.) * `clip_grad` -> `max_grad_norm` (for consistency) * `clip_loss_grad` -> `huber_loss_delta` (allowing to control not only the use of the Huber loss but also its essential parameter) * `estimation_step` -> `n_step_return_horizon` (more precise naming) * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. * Class hierarchy: * Abstract base class `Alpha` with value property and update method * `FixedAlpha` for constant entropy coefficients * `AutoAlpha` for automatic entropy tuning (replaces the old tuple-based representation) * The (auto-)updating logic is now completely encapsulated, reducing the complexity of the algorithms. * Implementations for continuous and discrete cases now share the same abstraction, making the codebase more consistent while preserving the original functionality.
* Introduced a policy base class `ContinuousPolicyWithExplorationNoise` which encapsulates noise generation for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). * Multi-agent RL methods are now differentiated by the type of the sub-algorithms being employed (`MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm`), which renders all interfaces clean. Helper class `MARLDispatcher` has been factored out to manage the dispatching of data to the respective agents. * Algorithms now internally use a wrapper (`Algorithm.Optimizer`) around the optimizers; creation is handled by method `_create_optimizer`. * This facilitates backpropagation steps with gradient clipping. * The optimizers of an Algorithm instance are now centrally tracked, such that we can ensure that the optimizers' states are handled alongside the model parameters when calling `state_dict` or `load_state_dict` on the `Algorithm` instance. Special handling of the restoration of optimizers' state dicts was thus removed from examples and tests. * Lagged networks (target networks) are now conveniently handled via the new algorithm mixins `LaggedNetworkPolyakUpdateAlgorithmMixin` and `LaggedNetworkFullUpdateAlgorithmMixin`. Using these mixins, * a lagged network can simply be added by calling `_add_lagged_network` * the torch method `train` must no longer be overridden to ensure that the target networks are never set to train mode/remain in eval mode (which was prone to errors), * a method which updates all target networks with their source networks is automatically provided and does not need to be implemented specifically for every algorithm (`_update_lagged_network_weights`). All classes which make use of lagged networks were updated to use these mixins, simplifying the implementations and reducing the potential for implementation errors. 
(In the BCQ implementation, the VAE network was not correctly handled, but due to the way in which examples were structured, it did not result in an error.) * Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle): * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOnPolicyAlgorithm` * `ActorCriticOffPolicyAlgorithm` * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) * `QLearningOffPolicyAlgorithm` * `A2C`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `Reinforce` * `BDQN`: * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` * Remove parameter `clip_loss_grad` (unused; only passed on to former base class) * Remove parameter `estimation_step`, for which only one option was valid * `C51`: * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) * `CQL`: * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy). * Remove parameter `estimation_step` (now `n_step_return_horizon`), which was not actually used (it was only passed on to its superclass). * `DiscreteBCQ`: * Inherit directly from `OfflineAlgorithm` instead of `DQN` * Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to the former base class (and unused by it). * `DiscreteCQL`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to base class `QRDQN` (and unused by it). * `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy) * `FQF`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to base class `QRDQN` (and unused by it). * `IQN`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to base class `QRDQN` (and unused by it).
* `NPG`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `A2C` * `QRDQN`: * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` ## High-Level API * Detailed optimizer configuration (analogous to the procedural API) is now possible: * All optimizers can be configured in the respective algorithm-specific `Params` object by using `OptimizerFactoryFactory` instances as parameter values (e.g. `optim`, `actor_optim`, `critic_optim`, etc.). * Learning rate schedulers remain separate parameters and now use `LRSchedulerFactoryFactory` instances. The respective parameter names now use the suffix `lr_scheduler` instead of `lr_scheduler_factory` (as the precise nature need not be reflected in the name; brevity is preferable). * `SamplingConfig` is replaced by `TrainingConfig` and subclasses differentiating off-policy and on-policy cases appropriately (`OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`). * The `test_in_train` parameter is now exposed (default False). * Inapplicable arguments can no longer be set in the respective subclass (e.g. `OffPolicyTrainingConfig` does not contain parameter `repeat_per_collect`). * All parameter names have been aligned with the new names used by `TrainerParams` (see above). * Add option to customize the factory for the collector (`ExperimentBuilder.with_collector_factory`), adding the abstraction `CollectorFactory`. 
#1256 ## Peripheral Changes * The `Actor` classes have been renamed for clarity (#1091): * `BaseActor` -> `Actor` * `continuous.ActorProb` -> `ContinuousActorProbabilistic` * `continuous.Actor` -> `ContinuousActorDeterministic` * `discrete.Actor` -> `DiscreteActor` * The `Critic` classes have been renamed for clarity (#1091): * `continuous.Critic` -> `ContinuousCritic` * `discrete.Critic` -> `DiscreteCritic` * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. * Fix issues pertaining to the torch device assignment of network components (#810): * Remove `device` member (and the corresponding constructor argument) from the following classes: `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProbabilistic`, `ContinuousCritic`, `DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`, `IntrinsicCuriosityModule`, `MLPActor`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, `RecurrentActorProb`, `RecurrentCritic`, `VAE` * (Peripheral change:) Require the use of keyword arguments for the constructors of all of these classes * Clean up handling of modules that define attribute `output_dim`, introducing the explicit base class `ModuleWithVectorOutput` * Interfaces where one could specify either a module with `output_dim` or additionally provide the output dimension as an argument were changed to use `ModuleWithVectorOutput`. * The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance (via adaptation if necessary). * The class hierarchy of supporting `nn.Module` implementations was cleaned up (#1091): * With the fundamental base classes `ActionReprNet` and `ActionReprNetWithVectorOutput`, we established a well-defined interface for the most commonly used `forward` interface in Tianshou's algorithms & policies.
#948 * Some network classes were renamed: * `ScaledObsInputModule` -> `ScaledObsInputActionReprNet` * `Rainbow` -> `RainbowNet` * All modules containing base classes were renamed from `base` to a more descriptive name, rendering file names unique. # Release 1.2.0 (2025-06-23) This is the final release in the 1.x series before Tianshou v2.0.0. It resolves performance regressions introduced in v1.1.0 and resolves several issues, partly by backporting improvements from the upcoming v2.0.0 release. This release is brought to you by [Applied AI Institute gGmbH](https://www.appliedai-institute.de). Core developers: * Dr. Dominik Jain (@opcode81) * Michael Panchenko (@MischaPanch) ## Changes/Improvements - `trainer`: - Custom scoring now supported for selecting the best model. #1202 - `highlevel`: - `DiscreteSACExperimentBuilder`: Expose method `with_actor_factory_default` #1248 #1250 - `ActorFactoryDefault`: Fix parameters for hidden sizes and activation not being passed on in the discrete case (affects `with_actor_factory_default` method of experiment builders) - `ExperimentConfig`: Do not inherit from other classes, as this breaks automatic handling by `jsonargparse` when the class is used to define interfaces (as in high-level API examples) - `AutoAlphaFactoryDefault`: Differentiate discrete and continuous action spaces and allow coefficient to be modified, adding an informative docstring (previous implementation was reasonable only for continuous action spaces) - Adjust usage in `atari_sac_hl` example accordingly. 
- `NPGAgentFactory`, `TRPOAgentFactory`: Fix optimizer instantiation including the actor parameters (which was misleadingly suggested in the docstring in the respective policy classes; docstrings were fixed), as the actor parameters are intended to be handled via natural gradients internally - `data`: - `ReplayBuffer`: Fix collection of empty episodes being disallowed - Collection was slow due to `isinstance` checks on Protocols and due to Buffer integrity validation. This was solved by no longer performing `isinstance` on Protocols and by making the integrity validation disabled by default. - Tests: - We have introduced extensive **determinism tests** which allow to validate whether training processes deterministically compute the same results across different development branches. This is an important step towards ensuring reproducibility and consistency, which will be instrumental in supporting Tianshou developers in their work, especially in the context of algorithm development and evaluation. ## Breaking Changes - `trainer`: - `BaseTrainer.run` and `__iter__`: Resetting was never optional prior to running the trainer, yet the recently introduced parameter `reset_prior_to_run` of `run` suggested that it _was_ optional. Yet the parameter was ultimately not respected, because `__iter__` would always call `reset(reset_collectors=True, reset_buffer=False)` regardless. The parameter was removed; instead, the parameters of `run` now mirror the parameters of `reset`, and the implicit `reset` call in `__iter__` was removed. This aligns with upcoming changes in Tianshou v2.0.0. * NOTE: If you have been using a trainer without calling `run` but by directly iterating over it, you will need to call `reset` on the trainer explicitly before iterating over the trainer. * Using a trainer as an iterator is considered deprecated and support for this will be removed in Tianshou v2.0.0. 
- `data`: - `InfoStats` has a new non-optional field `best_score` which is used for selecting the best model. #1202 - `highlevel`: - Change the way in which seeding is handled: The mechanism introduced in v1.1.0 was completely revised: - The `training_seed` and `test_seed` attributes were removed from `SamplingConfig`. Instead, the seeds are derived from the seed defined in `ExperimentConfig`. - Seed attributes of `EnvFactory` classes were removed. Instead, seeds are passed to methods of `EnvFactory`. # Release 1.1.0 (2024-08-10) **NOTE**: This release introduced (potentially severe) performance regressions in data collection; please switch to a newer release for better performance. ## Highlights ### Evaluation Package This release introduces a new package `evaluation` that integrates best practices for running experiments (seeding test and train environments) and for evaluating them using the [rliable](https://github.com/google-research/rliable) library. This should be especially useful for algorithm developers for comparing performances and creating meaningful visualizations. **This functionality is currently in alpha state** and will be further improved in the next releases. You will need to install tianshou with the extra `eval` to use it. The creation of multiple experiments with varying random seeds has been greatly facilitated. Moreover, the `ExpLauncher` interface has been introduced and implemented with several backends to support the execution of multiple experiments in parallel. An example for this using the high-level interfaces can be found [here](examples/mujoco/mujoco_ppo_hl_multi.py); examples that use low-level interfaces will follow soon. ### Improvements in Batch Apart from that, several important extensions have been added to internal data structures, most notably to `Batch`. Batches now implement `__eq__` and can be meaningfully compared.
Applying operations in a nested fashion has been significantly simplified, and checking for NaNs and dropping them is now possible. One more notable change is that torch `Distribution` objects are now sliced when slicing a batch. Previously, when a Batch with say 10 actions and a dist corresponding to them was sliced to `[:3]`, the `dist` in the result would still correspond to all 10 actions. Now, the dist is also "sliced" to be the distribution of the first 3 actions. A detailed list of changes can be found below. ## Changes/Improvements - `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141 #1183 - `data`: - `Batch`: - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098 - Add methods `to_numpy_` and `to_torch_`. #1098, #1117 - Add `__eq__` (semantic equality check). #1098 - `keys()` deprecated in favor of `get_keys()` (needed to make iteration consistent with naming) #1105. - Major: new methods for applying functions to values, to check for NaNs and drop them, and to set values. #1181 - Slicing a batch with a torch distribution now also slices the distribution. #1181 - `data.collector`: - `Collector`: - Introduced `BaseCollector` as a base class for all collectors. #1123 - Add method `close` #1063 - Method `reset` is now more granular (new flags controlling behavior). #1063 - `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063 - `trainer`: - Trainers can now control whether collectors should be reset prior to training. #1063 - `policy`: - introduced attribute `in_training_step` that is controlled by the trainer. #1123 - policy automatically set to `eval` mode when collecting and to `train` mode when updating. #1123 - Extended interface of `compute_action` to also support array-like inputs #1169 - `highlevel`: - `SamplingConfig`: - Add support for `batch_size=None`. 
#1077 - Add `training_seed` for explicit seeding of training and test environments; the `test_seed` is inferred from `training_seed`. #1074 - `experiment`: - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and which determines the default run name and therefore the persistence subdirectory. It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than `experiment_name` (although the latter will still be interpreted correctly). #1074 #1131 - Add class `ExperimentCollection` for the convenient execution of multiple experiment runs #1131 - The `World` object, containing all low-level objects needed for experimentation, can now be extracted from an `Experiment` instance. This enables customizing the experiment prior to its execution, bridging the low and high-level interfaces. #1187 - `ExperimentBuilder`: - Add method `build_seeded_collection` for the sound creation of multiple experiments with varying random seeds #1131 - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131 - `env`: - Added new `VectorEnvType` called `SUBPROC_SHARED_MEM_AUTO` and used it for Atari and Mujoco venv creation. #1141 - `utils`: - `logger`: - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074 - Wandb logger extended #1183 - `net.continuous.Critic`: - Add flag `apply_preprocess_net_to_obs_only` to allow the preprocessing network to be applied to the observations only (without the actions concatenated), which is essential for the case where we want to reuse the actor's preprocessing network #1128 - `torch_utils` (new module) - Added context managers `torch_train_mode` and `policy_within_training_step` #1123 - `print` - `DataclassPPrintMixin` now supports outputting a string, not just printing the pretty repr.
#1141 ## Fixes - `highlevel`: - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics, fixing the case where we want to reuse an actor's preprocessing network for the critic (affects usages of the experiment builder method `with_critic_factory_use_actor` with continuous environments) #1128 - Policy parameter `action_scaling` value `"default"` was not correctly transformed to a Boolean value for algorithms SAC, DDPG, TD3 and REDQ. The value `"default"` being truthy caused action scaling to be enabled even for discrete action spaces. #1191 - `atari_network.DQN`: - Fix constructor input validation #1128 - Fix `output_dim` not being set if `features_only`=True and `output_dim_added_layer` is not None #1128 - `PPOPolicy`: - Fix `max_batchsize` not being used in `logp_old` computation inside `process_fn` #1168 - Fix `Batch.__eq__` to allow comparing Batches with scalar array values #1185 ## Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 - Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 - Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 - Improved typing for `exploration_noise` and within Collector. #1063 - Better variable names related to model outputs (logits, dist input etc.). #1032 - Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc., instead of just `nn.Module`. #1032 - Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032 - Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032 - Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032 - Exception no longer raised on `len` of empty `Batch`. #1084 - tests and examples are covered by `mypy`. 
#1077 - `Actor` is used more widely and is more strictly typed (it is now generic). #1077 - Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102 ## Breaking Changes - `data`: - `Collector`: - Removed `.data` attribute. #1063 - Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset` explicitly or pass `reset_before_collect=True`. #1063 - Removed `no_grad` argument from `collect` method (was unused in tianshou). #1123 - `Batch`: - Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 - The methods `to_numpy` and `to_torch` are not in-place anymore (use `to_numpy_` or `to_torch_` instead). #1098, #1117 - The method `Batch.is_empty` has been removed. Instead, the user can simply check a Batch for emptiness by using `len` (as for dicts). #1144 - Stricter `cat_`: only concatenation of batches with the same structure is allowed. #1181 - `to_torch` and `to_numpy` are no longer static methods. So `Batch.to_numpy(batch)` should be replaced by `batch.to_numpy()`. #1200 - `utils`: - `logger`: - `BaseLogger.prepare_dict_for_logging` is now abstract. #1074 - Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074 - `utils.net`: - `Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 - Modules with code that was copied from sensAI have been replaced by imports from new dependency sensAI-utils: - `tianshou.utils.logging` is replaced with `sensai.util.logging` - `tianshou.utils.string` is replaced with `sensai.util.string` - `tianshou.utils.pickle` is replaced with `sensai.util.pickle` - `env`: - All VectorEnvs now return a numpy array of info-dicts on reset instead of a list. #1063 - `policy`: - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases.
#1032 - `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074 - `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074 - `highlevel`: - `params`: The parameter `dist_fn` has been removed from the parameter objects (`PGParams`, `A2CParams`, `PPOParams`, `NPGParams`, `TRPOParams`). The correct distribution is now determined automatically based on the actor factory being used, avoiding the possibility of misspecification. Persisted configurations/policies continue to work as expected, but code must not specify the `dist_fn` parameter. #1194 #1195 - `env`: - `EnvFactoryRegistered`: parameter `seed` has been replaced by the pair of parameters `training_seed` and `test_seed`. Persisted instances will continue to work correctly. Subclasses such as `AtariEnvFactory` are also affected and require explicit train and test seeds. #1074 - `VectorEnvType`: `SUBPROC_SHARED_MEM` has been replaced by `SUBPROC_SHARED_MEM_DEFAULT`. It is recommended to use `SUBPROC_SHARED_MEM_AUTO` instead. However, persisted configs will continue working. #1141 ## Tests - Fixed env seeding in `test_sac_with_il.py` so that the test doesn't fail randomly. #1081 - Improved CI triggers and added telemetry (if requested by user) #1177 - Improved environment used in tests. - Improved tests of batch equality to check with scalar values #1185 ## Dependencies - [DeepDiff](https://github.com/seperman/deepdiff) added to help with diffs of batches in tests. #1098 - Bumped black, idna, pillow - New extra "eval" - Bumped numba to >=0.60.0, permitting installation on python 3.12 #1177 - New dependency sensai-utils # Release 1.0.0 (2024-03-20) This release focuses on updating and improving Tianshou internals (in particular, code quality) while creating relatively few breaking changes (apart from things like the python and dependencies' versions).
We view it as a significant step for transforming Tianshou into the go-to place both for RL researchers, as well as for RL practitioners working on industry projects.   This is the first release after the [appliedAI Institute](https://www.appliedai-institute.de/en/) (the [TransferLab](https://transferlab.ai/) division) has decided to further develop Tianshou and provide long-term support.  ## Breaking Changes - dropped support of python<3.11 - dropped support of gym, from now on only Gymnasium envs are supported - removed functions like `offpolicy_trainer` in favor of `OffpolicyTrainer(...).run()` (this affects all example scripts) - several breaking changes related to removing `**kwargs` from signatures, renamings of internal attributes (like `critic1` -> `critic`) - Outputs of training methods are now dataclasses instead of dicts ## Functionality Extensions ### Major - High level interfaces for experiments, demonstrated by the new example scripts with names ending in `_hl.py` ### Minor - Method to compute action directly from a policy's observation, can be used for unrolling - Support for custom keys in ReplayBuffer - Support for CalQL as part of CQL - Support for explicit setting of multiprocessing context for SubprocEnvWorker - `critic2` no longer has to be explicitly constructed and passed if it is supposed to be the same network as `critic` (formerly `critic1`) ## Internal Improvements ### Build and Docs - Completely changed the build pipeline. Tianshou now uses poetry, black, ruff, poethepoet, nbqa and other niceties. - Notebook tutorials are now part of the repository (previously they were in a drive). They were fixed and are executed during the build as integration tests, in addition to serving as documentation. Parts of the content have been improved. - Documentation is now built with jupyter book. JavaScript code has been slightly improved, JS dependencies are included as part of the repository. 
- Many improvements in docstrings ### Typing - Adding `BatchPrototypes` to cover the fields needed and returned by methods relying on batches in a backwards compatible way - Removing `**kwargs` from policies' constructors - Overall, much stricter and more correct typing. Removing `kwargs` and replacing dicts by dataclasses in several places. - Making use of `Generic` to express different kinds of stats that can be returned by `learn` and `update` - Improved typing in `tests` and `examples`, close to passing mypy ### General - Reduced duplication, improved readability and simplified code in several places - Use `dist.mode` instead of inferring `loc` or `argmax` from the `dist_fn` input ## Contributions ### The OG creators - @Trinkle23897 participated in almost all aspects of the coordination and reviewed most of the merged PRs - @nuance1979 participated in several discussions ### From appliedAI The team working on this release of Tianshou consisted of @opcode81 @MischaPanch @maxhuettenrauch @carlocagnetta @bordeauxred ### External contributions - @BFAnas participated in several discussions and contributed the CalQL implementation, extending the pre-processing logic. - @dantp-ai fixed many mypy issues and improved the tests - @arnaujc91 improved the logic of computing deterministic actions - Many other contributors, among them many new ones participated in this release. The Tianshou team is very grateful for your contributions! # Older Releases See [releases on GitHub](https://github.com/thu-ml/tianshou/releases) ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Tianshou Please refer to the ['Developer Guide' on tianshou.org](https://tianshou.org/en/latest/04_developer_guide/developer_guide.html). ================================================ FILE: Dockerfile ================================================ # Use the official Python image for the base image. 
FROM --platform=linux/amd64 python:3.11-slim # Set environment variables to make Python print directly to the terminal and avoid .pyc files. ENV PYTHONUNBUFFERED=1 ENV PYTHONDONTWRITEBYTECODE=1 # Install system dependencies required for the project. RUN apt-get update && apt-get install -y --no-install-recommends \ curl \ build-essential \ git \ wget \ unzip \ libvips-dev \ gnupg2 \ && rm -rf /var/lib/apt/lists/* # Install pipx. RUN python3 -m pip install --no-cache-dir pipx \ && pipx ensurepath # Add poetry to the path ENV PATH="${PATH}:/root/.local/bin" # Install the latest version of Poetry using pipx. RUN pipx install poetry # Set the working directory. IMPORTANT: can't be changed as needs to be in sync to the dir where the project is cloned # to in the codespace WORKDIR /workspaces/tianshou # Copy the pyproject.toml and poetry.lock files (if available) into the image. COPY pyproject.toml poetry.lock* README.md /workspaces/tianshou/ RUN poetry config virtualenvs.create false RUN poetry install --no-root --with dev # The entrypoint will perform an editable install, it is expected that the code is mounted in the container then # If you don't want to mount the code, you should override the entrypoint ENTRYPOINT ["/bin/bash", "-c", "poetry install --with dev && poetry run jupyter trust notebooks/*.ipynb docs/02_notebooks/*.ipynb && $0 $@"] ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 Tianshou contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this 
permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ include LICENSE ================================================ FILE: README.md ================================================
--- [![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock) [![Read the Docs](https://readthedocs.org/projects/tianshou/badge/?version=master)](https://tianshou.org/en/master/) [![Pytest](https://github.com/thu-ml/tianshou/actions/workflows/pytest.yml/badge.svg)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues) [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE) > [!NOTE] > **Tianshou version 2 is here!** > > We have released the new major version of Tianshou on PyPI. > Version 2 is a complete overhaul of the software design of the procedural API, in which > * we establish a clear separation between learning algorithms and policies (via the separate abstractions `Algorithm` and `Policy`). > * we provide more well-defined, more usable interfaces with extensive documentation of all algorithm and trainer parameters, > renaming some parameters to make their names more consistent and intuitive. > * the class hierarchy is fully revised, establishing a clear separation between on-policy, off-policy and offline algorithms > at the type level and ensuring that all inheritance relationships are meaningful. > > Because of the extent of the changes, this version is not backwards compatible with previous versions of Tianshou. > For migration information, please see the [change log](CHANGELOG.md). 
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning (RL) library based on pure PyTorch and [Gymnasium](http://github.com/Farama-Foundation/Gymnasium). Tianshou's main features at a glance are: 1. Modular low-level interfaces for algorithm developers (RL researchers) that are flexible, hackable and type-safe. 1. Convenient high-level interfaces for applications of RL (training an implemented algorithm on a custom environment). 1. Large scope: online (on- and off-policy) and offline RL, experimental support for multi-agent RL (MARL), experimental support for model-based RL, and more. Unlike other reinforcement learning libraries, which may have complex codebases, unfriendly high-level APIs, or lack optimization for speed, Tianshou provides a high-performance, modularized framework and user-friendly interfaces for building deep reinforcement learning agents. One more aspect that sets Tianshou apart is its generality: it supports online and offline RL, multi-agent RL, and model-based algorithms. Tianshou aims to enable concise implementations, both for researchers and practitioners, without sacrificing flexibility.
Supported algorithms include: - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) - [Double DQN](https://arxiv.org/pdf/1509.06461.pdf) - [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf) - [Branching DQN](https://arxiv.org/pdf/1711.08946.pdf) - [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) - [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf) - [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf) - [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf) - [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf) - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf) - [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf) - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf) - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) - [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf) - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) - [Vanilla Imitation Learning](https://en.wikipedia.org/wiki/Apprenticeship_learning) - [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf) - [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf) - [Twin Delayed DDPG with Behavior Cloning (TD3+BC)](https://arxiv.org/pdf/2106.06860.pdf) - [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf) - [Discrete Conservative 
Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf) - [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf) - [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/pdf/1606.03476.pdf) - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) - [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf) - [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf) Other noteworthy features: - Elegant framework with dual APIs: - Tianshou's high-level API maximizes ease of use for application development while still retaining a high degree of flexibility. - The fundamental procedural API provides a maximum of flexibility for algorithm development without being overly verbose. - State-of-the-art results in [MuJoCo benchmarks](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms - Support for vectorized environments (synchronous or asynchronous) for all algorithms (see [usage](https://tianshou.readthedocs.io/en/master/01_tutorials/07_cheatsheet.html#parallel-sampling)) - Support for super-fast vectorized environments based on [EnvPool](https://github.com/sail-sg/envpool/) for all algorithms (see [usage](https://tianshou.readthedocs.io/en/master/01_tutorials/07_cheatsheet.html#envpool-integration)) - Support for recurrent state representations in actor networks and critic networks (RNN-style training for POMDPs) (see [usage](https://tianshou.readthedocs.io/en/master/01_tutorials/07_cheatsheet.html#rnn-style-training)) - Support any type of environment state/action (e.g. a dict, a self-defined class, ...) 
[Usage](https://tianshou.readthedocs.io/en/master/01_tutorials/07_cheatsheet.html#user-defined-environment-and-different-state-representation) - Support for customized training processes (see [usage](https://tianshou.readthedocs.io/en/master/01_tutorials/07_cheatsheet.html#customize-training-process)) - Support for n-step return estimation and prioritized experience replay for all Q-learning based algorithms; GAE, n-step returns and PER are highly optimized thanks to numba's just-in-time compilation and vectorized numpy operations - Support for multi-agent RL (see [usage](https://tianshou.readthedocs.io/en/master/01_tutorials/07_cheatsheet.html#multi-agent-reinforcement-learning)) - Support for logging based on both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) - Support for multi-GPU training (see [usage](https://tianshou.readthedocs.io/en/master/01_tutorials/07_cheatsheet.html#multi-gpu)) - Comprehensive documentation, PEP8 code-style checking, type checking and thorough [tests](https://github.com/thu-ml/tianshou/actions) In Chinese, Tianshou means "divinely ordained", which by extension denotes a gift one is born with. Tianshou is a reinforcement learning platform, and RL algorithms, by their nature, do not learn from humans. So the name "Tianshou" conveys that there is no teacher to learn from; instead, the agent learns by itself through constant interaction with the environment. “天授”意指上天所授,引申为与生俱有的天赋。天授是强化学习平台,而强化学习算法并不是向人类学习的,所以取“天授”意思是没有老师来教,而是自己通过跟环境不断交互来进行学习。 ## Installation Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/) and [conda-forge](https://github.com/conda-forge/tianshou-feedstock). It requires Python >= 3.11.
For installing the most recent version of Tianshou, the best way is to clone the repository and install it with [poetry](https://python-poetry.org/) (which you need to install on your system first): ```bash git clone git@github.com:thu-ml/tianshou.git cd tianshou poetry install ``` You can also install the dev requirements by adding `--with dev`, or the extras for, say, MuJoCo and acceleration via [envpool](https://github.com/sail-sg/envpool) by adding `--extras "mujoco envpool"`. If you wish to install multiple extras, ensure that you include them in a single command. Sequential calls to `poetry install --extras xxx` will overwrite prior installations, leaving only the last specified extras installed. Alternatively, you may install all of the following extras by adding `--all-extras`. Available extras are: - `atari` (for Atari environments) - `box2d` (for Box2D environments) - `classic_control` (for classic control (discrete) environments) - `mujoco` (for MuJoCo environments) - `mujoco-py` (for legacy mujoco-py environments[^1]) - `pybullet` (for pybullet environments) - `robotics` (for gymnasium-robotics environments) - `vizdoom` (for ViZDoom environments) - `envpool` (for [envpool](https://github.com/sail-sg/envpool/) integration) - `argparse` (in order to be able to run the high-level API examples) [^1]: `mujoco-py` is a legacy package and is not recommended for new projects. It is only included for compatibility with older projects. Also note that there may be compatibility issues with macOS newer than Monterey.
Otherwise, you can install the latest release from PyPI (currently far behind the master) with the following command: ```bash $ pip install tianshou ``` If you are using Anaconda or Miniconda, you can install Tianshou from conda-forge: ```bash $ conda install tianshou -c conda-forge ``` Alternatively to the poetry install, you can also install the latest source version through GitHub: ```bash $ pip install git+https://github.com/thu-ml/tianshou.git@master --upgrade ``` Finally, you may check the installation via your Python console as follows: ```python import tianshou print(tianshou.__version__) ``` If no errors are reported, you have successfully installed Tianshou. ## Documentation Find example scripts in the [test/]( https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders. Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/). ## Why Tianshou? ### Comprehensive Functionality ### High Software Engineering Standards | RL Platform | Documentation | Code Coverage | Type Hints | Last Update | | ------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------ | ----------------------------------------------------------------------------------------------------------------- | | [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) | [![coverage 
report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) | :heavy_check_mark: | ![GitHub last commit](https://img.shields.io/github/last-commit/DLR-RM/stable-baselines3?label=last%20update) | | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![](https://readthedocs.org/projects/ray/badge/?version=master)](http://docs.ray.io/en/master/rllib.html) | :heavy_minus_sign:(1) | :heavy_check_mark: | ![GitHub last commit](https://img.shields.io/github/last-commit/ray-project/ray?label=last%20update) | | [SpinningUp](https://github.com/openai/spinningup) | [![](https://img.shields.io/readthedocs/spinningup)](https://spinningup.openai.com/) | :x: | :x: | ![GitHub last commit](https://img.shields.io/github/last-commit/openai/spinningup?label=last%20update) | | [Dopamine](https://github.com/google/dopamine) | [![](https://img.shields.io/badge/docs-passing-green)](https://github.com/google/dopamine/tree/master/docs) | :x: | :x: | ![GitHub last commit](https://img.shields.io/github/last-commit/google/dopamine?label=last%20update) | | [ACME](https://github.com/deepmind/acme) | [![](https://img.shields.io/badge/docs-passing-green)](https://github.com/deepmind/acme/blob/master/docs/index.md) | :heavy_minus_sign:(1) | :heavy_check_mark: | ![GitHub last commit](https://img.shields.io/github/last-commit/deepmind/acme?label=last%20update) | | [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [:heavy_minus_sign:](https://arxiv.org/abs/2006.11751) | [![codecov](https://codecov.io/gh/alex-petrenko/sample-factory/branch/master/graph/badge.svg)](https://codecov.io/gh/alex-petrenko/sample-factory) | :x: | ![GitHub last commit](https://img.shields.io/github/last-commit/alex-petrenko/sample-factory?label=last%20update) | | | | | | | | [Tianshou](https://github.com/thu-ml/tianshou) | [![Read the 
Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master) | [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) | :heavy_check_mark: | ![GitHub last commit](https://img.shields.io/github/last-commit/thu-ml/tianshou?label=last%20update) | (1): it has continuous integration, but the coverage rate is not available ### Reproducible, High-Quality Results Tianshou is rigorously tested. In contrast to other RL platforms, **our tests include the full agent training procedure for all of the implemented algorithms**. Our tests fail if any of the agents fails to achieve a consistent level of performance within a limited number of epochs. Our tests thus ensure reproducibility. Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page for more detail. Atari and MuJoCo benchmark results can be found in the [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders, respectively. **Our MuJoCo results reach or exceed the level of performance of most existing benchmarks.** ### Algorithm Abstraction Reinforcement learning algorithms are built on abstractions for - on-policy algorithms (`OnPolicyAlgorithm`), - off-policy algorithms (`OffPolicyAlgorithm`), and - offline algorithms (`OfflineAlgorithm`), all of which clearly separate the core algorithm from the training process and the respective environment interactions. In each case, the implementation of an algorithm necessarily involves only the implementation of methods for - pre-processing a batch of data, augmenting it with necessary information/sufficient statistics for learning (`_preprocess_batch`), - updating model parameters based on an augmented batch of data (`_update_with_batch`). The implementation of these methods suffices for a new algorithm to be applicable within Tianshou, making experimentation with new approaches particularly straightforward.
## Quick Start Tianshou provides two API levels: - the high-level interface, which provides ease of use for end users seeking to run deep reinforcement learning applications - the procedural interface, which provides a maximum of control, especially for very advanced users and developers of reinforcement learning algorithms. In the following, let us consider an example application using the _CartPole_ gymnasium environment. We shall apply the deep Q-network (DQN) learning algorithm using both APIs. ### High-Level API In the high-level API, the basis for an RL experiment is an `ExperimentBuilder` with which we can build the experiment we then seek to run. Since we want to use DQN, we use the specialization `DQNExperimentBuilder`. The high-level API provides largely declarative semantics, i.e. the code is almost exclusively concerned with configuration that controls what to do (rather than how to do it). ```python from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.env import ( EnvFactoryRegistered, VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig from tianshou.highlevel.params.algorithm_params import DQNParams from tianshou.highlevel.trainer import ( EpochStopCallbackRewardThreshold, ) experiment = ( DQNExperimentBuilder( EnvFactoryRegistered( task="CartPole-v1", venv_type=VectorEnvType.DUMMY, training_seed=0, test_seed=10, ), ExperimentConfig( persistence_enabled=False, watch=True, watch_render=1 / 35, watch_num_episodes=100, ), OffPolicyTrainingConfig( max_epochs=10, epoch_num_steps=10000, batch_size=64, num_training_envs=10, num_test_envs=100, buffer_size=20000, collection_step_num_env_steps=10, update_step_num_gradient_steps_per_sample=1 / 10, ), ) .with_dqn_params( DQNParams( lr=1e-3, gamma=0.9, n_step_return_horizon=3, target_update_freq=320, eps_training=0.3, eps_inference=0.0, ), ) .with_model_factory_default(hidden_sizes=(64, 64)) 
.with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) .build() ) experiment.run() ``` The experiment builder takes three arguments: - the environment factory for the creation of environments. In this case, we use an existing factory implementation for gymnasium environments. - the experiment configuration, which controls persistence and the overall experiment flow. In this case, we have configured that we want to observe the agent's behavior after it is trained (`watch=True`) for a number of episodes (`watch_num_episodes=100`). We have disabled persistence, because we do not want to save training logs, the agent or its configuration for future use. - the training configuration, which controls fundamental training parameters, such as the total number of epochs we run the experiment for (`max_epochs=10`) and the number of environment steps each epoch shall consist of (`epoch_num_steps=10000`). Every epoch consists of a series of data collection (rollout) steps and training steps. The parameter `collection_step_num_env_steps` controls the amount of data that is collected in each collection step; after each collection step, we perform a training step, applying a gradient-based update based on a sample of data (`batch_size=64`) taken from the buffer of data that has been collected. For further details, see the documentation of the configuration class. We then proceed to configure some of the parameters of the DQN algorithm itself: For instance, we control the epsilon parameter for exploration. We want to use random exploration during rollouts for training (`eps_training`), but not when evaluating the agent's performance in the test environments (`eps_inference`). Furthermore, we configure model parameters of the network for the Q function, parametrising the sizes of the hidden layers of the default MLP factory (`hidden_sizes=(64, 64)`). Find the script in [examples/discrete/discrete_dqn_hl.py](examples/discrete/discrete_dqn_hl.py). Here's a run (with the training time cut short):

Find many further applications of the high-level API in the `examples/` folder; look for scripts ending with `_hl.py`. Note that most of these examples require the extra `argparse` (install it by adding `--extras argparse` when invoking poetry). ### Procedural API Let us now consider an analogous example in the procedural API. Find the full script in [examples/discrete/discrete_dqn.py](https://github.com/thu-ml/tianshou/blob/master/examples/discrete/discrete_dqn.py). First, import the relevant packages: ```python import gymnasium as gym import tianshou as ts import torch from tianshou.algorithm.modelfree.dqn import DQN, DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import CollectStats from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo from torch.utils.tensorboard import SummaryWriter ``` Define hyper-parameters: ```python task = 'CartPole-v1' lr, epoch, batch_size = 1e-3, 10, 64 num_training_envs, num_test_envs = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 epoch_num_steps, collection_step_num_env_steps = 10000, 10 ``` Initialize the logger: ```python logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) ``` Create the environments: ```python # You can also try SubprocVectorEnv, which will use parallelization training_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) ``` Create the network, policy, and algorithm: ```python # Create the network # Note: You can easily define other networks.
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network env = gym.make(task, render_mode="human") assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) state_shape = space_info.observation_info.obs_shape action_shape = space_info.action_info.action_shape net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = AdamOptimizerFactory(lr=lr) # Create the policy policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=eps_train, eps_inference=eps_test ) # Create the algorithm with the policy and optimizer factory algorithm = DQN( policy=policy, optim=optim, gamma=gamma, n_step_return_horizon=n_step, target_update_freq=target_freq ) ``` Set up the collectors: ```python training_collector = ts.data.Collector[CollectStats]( algorithm, training_envs, ts.data.VectorReplayBuffer(buffer_size, num_training_envs), exploration_noise=True, ) test_collector = ts.data.Collector[CollectStats]( algorithm, test_envs, exploration_noise=True, ) ``` Let's train the model using the algorithm: ```python result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=num_test_envs, batch_size=batch_size, update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps, stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, logger=logger, test_in_training=True, ) ) print(f"Finished training in {result.timing.total_time} seconds") ``` This is how you could manually save/load the trained policy (it's exactly the same as saving/loading a `torch.nn.Module`): ```python torch.save(policy.state_dict(), 'dqn.pth') policy.load_state_dict(torch.load('dqn.pth')) ``` Now let's watch the agent at 35 FPS: ```python
collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) ``` Inspect the data saved in TensorBoard: ```bash $ tensorboard --logdir log/dqn ``` Please read the [documentation](https://tianshou.readthedocs.io) for advanced usage. ## Contributing Tianshou is still under development. Further algorithms and features are continuously being added, and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.org/en/master/04_contributing/04_contributing.html). ## Citing Tianshou If you find Tianshou useful, please cite it in your publications. ```latex @article{tianshou, author = {Jiayi Weng and Huayu Chen and Dong Yan and Kaichao You and Alexis Duburcq and Minghao Zhang and Yi Su and Hang Su and Jun Zhu}, title = {Tianshou: A Highly Modularized Deep Reinforcement Learning Library}, journal = {Journal of Machine Learning Research}, year = {2022}, volume = {23}, number = {267}, pages = {1--6}, url = {http://jmlr.org/papers/v23/21-1127.html} } ``` ## Acknowledgments Tianshou is supported by [appliedAI Institute for Europe](https://www.appliedai-institute.de/en/), who is committed to providing long-term support and development. Tianshou was previously a reinforcement learning platform based on TensorFlow. You can check out the branch [`priv`](https://github.com/thu-ml/tianshou/tree/priv) for more detail. Many thanks to [Haosheng Zou](https://github.com/HaoshengZou)'s pioneering work for Tianshou before version 0.1.1. We would like to thank [TSAIL](http://ml.cs.tsinghua.edu.cn/) and [Institute for Artificial Intelligence, Tsinghua University](http://ml.cs.tsinghua.edu.cn/thuai/) for providing such an excellent AI research platform. 
================================================ FILE: benchmark/run_benchmark.py ================================================ """Benchmark orchestration script for evaluating Tianshou's algorithm implementations. This module provides automated benchmarking capabilities for reinforcement learning algorithms across different environments (Atari, MuJoCo). It manages parallel experiment execution using tmux sessions, handles experiment lifecycle, and aggregates results. Key features: - Discovers and runs multiple RL algorithm scripts in parallel - Manages concurrency limits to prevent resource exhaustion - Each script runs in its own isolated tmux session for easy monitoring - Supports multiple tasks/environments per benchmark run - Aggregates rliable evaluation results into a unified format - Configurable experiment parameters (epochs, environments, parallel workers) - Filtering capabilities to run subsets of algorithms or tasks The script is designed to be run from the command line, allowing easy customization of benchmark parameters without code modification. 
Example usage: python run_benchmark.py --benchmark_type mujoco --num_experiments 5 --max_concurrent_sessions 4 """ import json import subprocess import sys import time from pathlib import Path from typing import Literal from sensai.util import logging from sensai.util.logging import datetime_tag TMUX_SESSION_PREFIX = "tianshou" # Sleep durations in seconds TMUX_SESSION_START_DELAY = 2 SESSION_CHECK_INTERVAL = 5 COMPLETION_CHECK_INTERVAL = 10 log = logging.getLogger("benchmark") # Default tasks for each benchmark type DEFAULT_TASKS = { "mujoco": [ "Ant-v4", "HalfCheetah-v4", "Hopper-v4", "Humanoid-v4", "InvertedDoublePendulum-v4", "InvertedPendulum-v4", "Reacher-v4", "Swimmer-v4", "Walker2d-v4", ], "atari": [ "PongNoFrameskip-v4", "BreakoutNoFrameskip-v4", "EnduroNoFrameskip-v4", "QbertNoFrameskip-v4", "MsPacmanNoFrameskip-v4", "SeaquestNoFrameskip-v4", "SpaceInvadersNoFrameskip-v4", ], } def find_script_paths( benchmark_type: str, exclude_filter: str | None = None, include_filter: str = "**/*_hl.py" ) -> list[str]: """Return all Python scripts matching the glob filter under examples/.""" base_dir = Path(__file__).parent.parent / "examples" / benchmark_type if not base_dir.exists(): raise FileNotFoundError(f"Directory '{base_dir}' does not exist.") scripts = sorted(str(p) for p in base_dir.glob(include_filter)) if not scripts: raise FileNotFoundError( f"Did not find any scripts matching '{include_filter}' in '{base_dir}'." ) # Apply exclusion filter if provided if exclude_filter: scripts = [s for s in scripts if not Path(s).match(exclude_filter)] if not scripts: raise FileNotFoundError( f"No scripts remaining after applying exclude filter '{exclude_filter}'." 
) return scripts def get_current_tmux_sessions(benchmark_type: str) -> list[str]: """List active tmux sessions starting with TMUX_SESSION_PREFIX.""" try: output = subprocess.check_output(["tmux", "list-sessions"], stderr=subprocess.DEVNULL) sessions = [ line.split(b":")[0].decode() for line in output.splitlines() if line.startswith(f"{TMUX_SESSION_PREFIX}_{benchmark_type}".encode()) ] return sessions except subprocess.CalledProcessError: return [] def start_tmux_session( script_path: str, persistence_base_dir: Path | str, num_experiments: int, benchmark_type: str, task: str, max_epochs: int | None = None, epoch_num_steps: int | None = None, experiment_launcher: Literal["sequential", "joblib"] | None = None, num_training_envs: int | None = None, num_test_envs: int | None = None, ) -> bool: """Start a tmux session running the given experiment script, returning True on success.""" # Normalize paths for Git Bash / Windows compatibility python_exec = sys.executable.replace("\\", "/") script_path = script_path.replace("\\", "/") persistence_base_dir = str(persistence_base_dir).replace("\\", "/") # Include task name in session to avoid collisions when running multiple tasks script_name = Path(script_path).name.replace("_hl.py", "") # Remove benchmark_type from name since we add it explicitly below script_name = script_name.replace(benchmark_type, "").strip("_") session_name = f"{TMUX_SESSION_PREFIX}_{benchmark_type}_{task}_{script_name}" # Build command with optional max_epochs and epoch_num_steps cmd_args = f"{python_exec} {script_path} --num_experiments {num_experiments} --persistence_base_dir {persistence_base_dir} --task {task}" if max_epochs is not None: cmd_args += f" --max_epochs {max_epochs}" if epoch_num_steps is not None: cmd_args += f" --epoch_num_steps {epoch_num_steps}" if experiment_launcher is not None: cmd_args += f" --experiment_launcher {experiment_launcher}" if num_training_envs is not None: cmd_args += f" --num_training_envs {num_training_envs}" if 
num_test_envs is not None: cmd_args += f" --num_test_envs {num_test_envs}" cmd = [ "tmux", "new-session", "-d", "-s", session_name, f"{cmd_args}; echo 'Finished {script_path}'; tmux kill-session -t {session_name}", ] try: subprocess.run(cmd, check=True) log.info( f"Started {script_path} in session '{session_name}'. Attach with:\ntmux attach -t {session_name}" ) return True except subprocess.CalledProcessError as e: log.error(f"Failed to start {script_path} (session {session_name}): {e}") return False def aggregate_rliable_results(task_results_dir: str | Path) -> None: """Aggregate rliable results from all experiments into a single results.json per environment. This form is expected by `benchmark.js` in the docs. """ task_results_dir = Path(task_results_dir) if not task_results_dir.exists(): log.warning(f"Benchmark results directory does not exist: '{task_results_dir}'") return experiment_dirs = [d for d in task_results_dir.iterdir() if d.is_dir()] aggregated_results = [] for experiment_dir in experiment_dirs: agent_name = experiment_dir.name.split("Experiment")[0] if not agent_name: log.warning( f"Could not extract agent name from directory: '{experiment_dir.name}', skipping..." 
) continue rliable_file = experiment_dir / "rliable_evaluation_test.json" if not rliable_file.exists(): log.warning(f"Missing rliable results file: '{rliable_file}', skipping...") continue try: with open(rliable_file) as f: result_entries = json.load(f) for result_entry in result_entries: result_entry["agent"] = agent_name aggregated_results.append(result_entry) except (OSError, json.JSONDecodeError) as e: log.error(f"Failed to read or parse '{rliable_file}': {e}") continue if not aggregated_results: log.warning(f"No results to aggregate for directory '{task_results_dir}'") return aggregated_results_path = task_results_dir / "results.json" try: with open(aggregated_results_path, "w") as f: json.dump(aggregated_results, f, indent=4) log.info(f"Aggregated {len(aggregated_results)} results to '{aggregated_results_path}'.") except OSError as e: log.error(f"Failed to write aggregated results to '{aggregated_results_path}': {e}") def main( max_concurrent_sessions: int | None = None, benchmark_type: Literal["mujoco", "atari"] = "atari", num_experiments: int = 1, max_scripts: int = -1, tasks: list[str] | None = None, max_tasks: int = -1, max_epochs: int | None = None, epoch_num_steps: int | None = None, num_training_envs: int | None = None, num_test_envs: int | None = None, experiment_launcher: Literal["sequential", "joblib"] | None = "sequential", include_filter: str = "**/*_hl.py", exclude_filter: str | None = None, ) -> None: """ Run the benchmarking by executing each selected script in its default configuration (apart from explicitly overridden parameters) in its own tmux session in parallel. Note that if you have unclosed tmux sessions from previous runs, they might count towards the max_concurrent_sessions limit. You can terminate all sessions with `tmux kill-server`. 
:param max_concurrent_sessions: optionally restrict how many tmux sessions to open in parallel, each script will run in a tmux session :param benchmark_type: mujoco or atari :param num_experiments: number of experiments to run per script :param max_scripts: maximum number of scripts to run, -1 for all. Set this to a low number for testing. :param tasks: optional list of task names to run benchmarks on. If None, uses default tasks for the benchmark_type. :param max_tasks: maximum number of tasks to run, -1 for all. Set this to a low number for testing. :param max_epochs: optional maximum number of training epochs to pass to all scripts. If None, uses script defaults. :param epoch_num_steps: optional number of environment steps per epoch to pass to all scripts. If None, uses script defaults. :param num_training_envs: optional number of training environments to pass to all scripts. If None, uses script defaults. :param num_test_envs: optional number of test environments to pass to all scripts. If None, uses script defaults. :param experiment_launcher: type of experiment launcher to use, only has an effect if `num_experiments>1`. By default, will use the experiment launchers defined in the individual scripts. :param include_filter: glob pattern to include scripts :param exclude_filter: optional glob pattern to exclude scripts (e.g., "*ddpg*") :return: """ # Use default tasks if none provided if tasks is None: tasks = DEFAULT_TASKS.get(benchmark_type, []) if not tasks: raise ValueError( f"No default tasks found for benchmark_type '{benchmark_type}'. Please provide tasks manually." 
) # Limit number of tasks if specified if max_tasks > 0: log.info(f"Limiting to first {max_tasks}/{len(tasks)} tasks.") tasks = tasks[:max_tasks] log.info(f"Running benchmarks for {len(tasks)} task(s): {tasks}") persistence_base_dir = Path(__file__).parent / "logs" / benchmark_type / datetime_tag() # file logger for the global benchmarking logs, each individual experiment will log to its own file log_file = persistence_base_dir / "benchmarking_run.txt" log_file.parent.mkdir(parents=True, exist_ok=True) logging.add_file_logger(log_file, append=False) scripts = find_script_paths( benchmark_type, exclude_filter=exclude_filter, include_filter=include_filter ) if max_scripts > 0: log.info(f"Limiting to first {max_scripts}/{len(scripts)} scripts.") scripts = scripts[:max_scripts] if max_concurrent_sessions is None: max_concurrent_sessions = len(scripts) # Run benchmarks for each task for i_task, task in enumerate(tasks, 1): log.info( f"=== Starting benchmark batch for '{benchmark_type}' on task '{task}' ({i_task}/{len(tasks)}) " f"for {len(scripts)} scripts with {max_concurrent_sessions} concurrent jobs ===" ) for i_script, script in enumerate(scripts, start=1): # Wait for free slot has_printed_waiting_message = False while len(get_current_tmux_sessions(benchmark_type)) >= max_concurrent_sessions: if not has_printed_waiting_message: log.info( f"Max concurrent sessions reached ({max_concurrent_sessions}). " f"Current sessions:\n{get_current_tmux_sessions(benchmark_type)}\nWaiting for a free slot..." 
) has_printed_waiting_message = True time.sleep(SESSION_CHECK_INTERVAL) log.info(f"Starting script {i_script}/{len(scripts)} for task '{task}'") session_started = start_tmux_session( script, benchmark_type=benchmark_type, persistence_base_dir=persistence_base_dir, num_experiments=num_experiments, task=task, max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, experiment_launcher=experiment_launcher, num_training_envs=num_training_envs, num_test_envs=num_test_envs, ) if session_started: time.sleep(TMUX_SESSION_START_DELAY) # Give tmux a moment to start the session has_printed_final_waiting_message = False # Wait for all sessions to complete before moving to next task while len(get_current_tmux_sessions(benchmark_type)) > 0: if not has_printed_final_waiting_message: log.info( f"All scripts for task '{task}' have been started, waiting for completion of remaining tmux sessions:\n" f"{get_current_tmux_sessions(benchmark_type)}" ) has_printed_final_waiting_message = True time.sleep(COMPLETION_CHECK_INTERVAL) log.info(f"All tmux sessions for task '{task}' have completed.") # Aggregate results for this specific task (scripts create task-named directory automatically) task_results_dir = persistence_base_dir / task log.info(f"Aggregating results for task '{task}' from directory: {task_results_dir}") try: aggregate_rliable_results(str(task_results_dir)) except Exception as e: log.error(f"Failed to aggregate rliable results for task '{task}': {e}\nContinuing...") log.info( f"=== Benchmark batch completed for all {len(scripts)} scripts and all {len(tasks)} task(s) ===" ) if __name__ == "__main__": logging.run_cli(main) ================================================ FILE: docs/.gitignore ================================================ /03_api/* jupyter_execute _toc.yml .jupyter_cache ================================================ FILE: docs/01_user_guide/00_training_process.md ================================================ # The Reinforcement Learning Process The 
following diagram illustrates the key mechanisms underlying the learning process in model-free reinforcement learning algorithms. It shows how the agent interacts with the environment, collects experiences, and periodically updates its policy based on those experiences. Accordingly, the key entities involved in the learning process are: * The **environment**: This is the system the agent interacts with. It provides the agent with observable states and rewards based on the actions taken by the agent. * The agent's **policy**: This is the strategy used by the agent to decide which action to take in a given state. The policy can be deterministic or stochastic and is typically represented by a neural network in deep reinforcement learning. * The **replay buffer**: This is a data structure used to store the agent's experiences, which consist of state transitions, actions taken, and rewards received. The agent learns from past experience by sampling mini-batches from the buffer during the policy update phase. * The **learning algorithm**: This defines how the agent updates its policy based on the experiences stored in the replay buffer. Different algorithms have different update mechanisms, which can significantly affect the learning performance. In some cases, the algorithm may also involve additional components (specifically neural networks), such as target networks or value functions. These entities have direct correspondences in Tianshou's codebase: * The environment is represented by an instance of a class that inherits from `gymnasium.Env`, which is a standard interface for reinforcement learning environments. In practice, environments are typically vectorized to enable parallel interactions, increasing efficiency. * The policy is encapsulated in the {class}`~tianshou.algorithm.algorithm_base.Policy` class, which provides methods for action selection. * The replay buffer is implemented in the {class}`~tianshou.data.buffer.buffer_base.ReplayBuffer` class. 
A {class}`~tianshou.data.collector.Collector` instance is used to manage the addition of new experiences to the replay buffer as the agent interacts with the environment. During the learning phase, the replay buffer may be sampled, providing an instance of {class}`~tianshou.data.batch.Batch` for the policy update. * The abstraction for learning algorithms is given by the {class}`~tianshou.algorithm.algorithm_base.Algorithm` class, which defines how to update the policy using data from the replay buffer. (structuring-the-process)= ## Structuring the Process The learning process itself is reified in Tianshou's {class}`~tianshou.trainer.trainer.Trainer` class, which orchestrates the interaction between the agent and the environment, manages the replay buffer, and coordinates the policy updates according to the specified learning algorithm. In general, the process can be described as executing a number of epochs as follows: * **epoch**: * repeat until a sufficient number of steps is reached (for online learning, typically environment step count) * **training step**: * for online learning algorithms … * **collection step**: collect state transitions in the environment by running the agent * (optionally) conduct a test step if collected data indicates promising behaviour * **update step**: apply gradient updates using the algorithm’s update logic. The update is based on … * data from the preceding collection step only (on-policy learning) * data from the collection step and previous data (off-policy learning) * data from a user-provided replay buffer (offline learning) * **test step** * collect test episodes from dedicated test environments and evaluate agent performance * (optionally) stop training early if performance is sufficiently high ```{admonition} Glossary :class: note The above introduces some of the key terms used throughout Tianshou. 
``` Note that the above description encompasses several modes of model-free reinforcement learning, including: * online learning (where the agent continuously interacts with the environment in order to collect new experiences) * on-policy learning (where the policy is updated based on data collected using the current policy only) * off-policy learning (where the policy is updated based on data collected using the current and previous policies) * offline learning (where the replay buffer is pre-filled and not updated during training) In Tianshou, the {class}`~tianshou.trainer.trainer.Trainer` and {class}`~tianshou.algorithm.algorithm_base.Algorithm` classes are specialised to handle these different modes accordingly. ================================================ FILE: docs/01_user_guide/01_apis.md ================================================ # Dual APIs Tianshou provides two distinct APIs to serve different use cases and user preferences: 1. **high-level API**: a declarative, configuration-based interface designed for ease of use 2. **procedural API**: a flexible, imperative interface providing maximum control Both APIs access the same underlying algorithm implementations, allowing you to choose the level of abstraction that best fits your needs without sacrificing functionality. ## Overview ### High-Level API The high-level API is built around the **builder pattern** and **declarative semantics**. Instead of writing procedural code that sequentially constructs and connects components, you declare _what_ you want through configuration objects and let Tianshou handle _how_ to build and execute the experiment. **Key characteristics:** - centered around {class}`~tianshou.highlevel.experiment.ExperimentBuilder` classes (e.g., {class}`~tianshou.highlevel.experiment.DQNExperimentBuilder`, {class}`~tianshou.highlevel.experiment.PPOExperimentBuilder`, etc.) 
- uses configuration dataclasses and factories for all relevant parameters - automatically handles component creation and "wiring" - provides sensible defaults that adapt to the nature of your environment - includes built-in persistence, logging, and experiment management - full type hints (but object structure is not flat; a proper IDE is required for seamless user experience) ### Procedural API The procedural API provides explicit control over every component in the RL pipeline. You manually create environments, networks, policies, algorithms, collectors, and trainers, then wire them together. **Key characteristics:** - direct instantiation of all components - explicit control over the training loop - lower-level access to internal mechanisms - minimal abstraction (closer to the implementation) - ideal for algorithm development and research ## When to Use Which API Use the high-level API when ... - **you're applying existing algorithms** to new problems - **you want to get started quickly** with minimal boilerplate - **you need experiment management** with persistence, logging, and reproducibility - **you prefer declarative code** that focuses on configuration - **you're building applications** rather than developing new algorithms Use the procedural API when: - **you're developing new algorithms** or modifying existing ones - **you need fine-grained control** over the training process - **you want to understand** the internal workings of Tianshou - **you're implementing custom components** not supported by the high-level API - **you prefer imperative programming** where each step is explicit - **you need maximum flexibility** for experimental research ## Comparison by Example Let's compare both APIs by implementing the same DQN learning task on the CartPole environment. 
### High-Level API Example ```python from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.env import EnvFactoryRegistered, VectorEnvType from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig from tianshou.highlevel.params.algorithm_params import DQNParams from tianshou.highlevel.trainer import EpochStopCallbackRewardThreshold # Build the experiment through configuration experiment = ( DQNExperimentBuilder( # Environment configuration EnvFactoryRegistered( task="CartPole-v1", venv_type=VectorEnvType.DUMMY, training_seed=0, test_seed=10, ), # Experiment settings ExperimentConfig( persistence_enabled=False, watch=True, watch_render=1 / 35, watch_num_episodes=100, ), # Training configuration OffPolicyTrainingConfig( max_epochs=10, epoch_num_steps=10000, batch_size=64, num_training_envs=10, num_test_envs=100, buffer_size=20000, collection_step_num_env_steps=10, update_step_num_gradient_steps_per_sample=1 / 10, ), ) # Algorithm-specific parameters .with_dqn_params( DQNParams( lr=1e-3, gamma=0.9, n_step_return_horizon=3, target_update_freq=320, eps_training=0.3, eps_inference=0.0, ), ) # Network architecture .with_model_factory_default(hidden_sizes=(64, 64)) # Stop condition .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) .build() ) # Run the experiment experiment.run() ``` **What's happening here:** 1. We create an {class}`~tianshou.highlevel.experiment.ExperimentBuilder` with three main configuration objects 2. We chain builder methods to specify algorithm parameters, model architecture, and callbacks 3. We call `.build()` to construct the experiment 4. We call `.run()` to execute the entire training pipeline The high-level API handles ... 
- creating and configuring environments - building the neural network - instantiating the policy and algorithm - setting up collectors and replay buffer - managing the training loop - watching the trained agent ### Procedural API Example ```python import gymnasium as gym import tianshou as ts from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import CollectStats from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo from torch.utils.tensorboard import SummaryWriter # Define hyperparameters task = "CartPole-v1" lr, epoch, batch_size = 1e-3, 10, 64 num_training_envs, num_test_envs = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 epoch_num_steps, collection_step_num_env_steps = 10000, 10 # Set up logging logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn")) # Create environments training_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) # Build the network env = gym.make(task, render_mode="human") space_info = SpaceInfo.from_env(env) state_shape = space_info.observation_info.obs_shape action_shape = space_info.action_info.action_shape net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) # Create policy and algorithm policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=eps_train, eps_inference=eps_test, ) algorithm = ts.algorithm.DQN( policy=policy, optim=AdamOptimizerFactory(lr=lr), gamma=gamma, n_step_return_horizon=n_step, target_update_freq=target_freq, ) # Set up collectors training_collector = ts.data.Collector[CollectStats]( algorithm, training_envs, ts.data.VectorReplayBuffer(buffer_size, num_training_envs), 
exploration_noise=True, ) test_collector = ts.data.Collector[CollectStats]( algorithm, test_envs, exploration_noise=True, ) # Define stop condition def stop_fn(mean_rewards: float) -> bool: if env.spec and env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold return False # Train the algorithm result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=num_test_envs, batch_size=batch_size, update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps, stop_fn=stop_fn, logger=logger, test_in_training=True, ) ) print(f"Finished training in {result.timing.total_time} seconds") # Watch the trained agent collector = ts.data.Collector[CollectStats](algorithm, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) ``` **What's happening here:** 1. We explicitly define all hyperparameters as variables 2. We manually create the logger 3. We construct training and test environments 4. We build the neural network by extracting space information from the environment 5. We create the policy and algorithm objects 6. We set up collectors with a replay buffer 7. We define callback functions 8. We call `algorithm.run_training()` with explicit parameters 9. We manually set up and run the evaluation collector The procedural API requires ... - explicit creation of every component - manual extraction of environment properties - direct specification of all connections ## Key Concepts in the High-Level API ### ExperimentBuilder The {class}`~tianshou.highlevel.experiment.ExperimentBuilder` is the core abstraction. 
Each algorithm has its own builder (e.g., {class}`~tianshou.highlevel.experiment.DQNExperimentBuilder`, {class}`~tianshou.highlevel.experiment.PPOExperimentBuilder`, {class}`~tianshou.highlevel.experiment.SACExperimentBuilder`). **Some methods you will find in experiment builders:** - `.with_<algorithm>_params()` (e.g., `.with_dqn_params()`) - Set algorithm-specific parameters - `.with_model_factory()`, `.with_model_factory_default()` - Configure network architecture - `.with_critic_factory()` - Configure critic network (for actor-critic methods) - `.with_epoch_train_callback()` - Add function to be called at the beginning of the training step in each epoch - `.with_epoch_test_callback()` - Add function to be called at the beginning of the test step in each epoch - `.with_epoch_stop_callback()` - Define stopping conditions - `.with_algorithm_wrapper_factory()` - Add algorithm wrappers (e.g., ICM) ### Configuration Objects Three main configuration objects are required when constructing an experiment builder: 1. **Environment Configuration** ({class}`~tianshou.highlevel.env.EnvFactory` subclasses) - Defines how to create and configure environments - Existing factories: - {class}`~tianshou.highlevel.env.EnvFactoryRegistered` - For the creation of environments registered in Gymnasium - {class}`~tianshou.highlevel.env.atari.atari_wrapper.AtariEnvFactory` - For Atari environments with preprocessing - Custom factories for your own environments can be created by subclassing {class}`~tianshou.highlevel.env.EnvFactory` 2. **Experiment Configuration** ({class}`~tianshou.highlevel.experiment.ExperimentConfig`): General settings for the experiment, particularly related to - logging - randomization - persistence - watching the trained agent's performance after training 3.
**Training Configuration** ({class}`~tianshou.highlevel.config.OffPolicyTrainingConfig`, {class}`~tianshou.highlevel.config.OnPolicyTrainingConfig`): Defines all parameters related to the training process ### Parameter Classes Algorithm parameters are defined in dataclasses specific to each algorithm (e.g., {class}`~tianshou.highlevel.params.algorithm_params.DQNParams`, {class}`~tianshou.highlevel.params.algorithm_params.PPOParams`). The parameters are extensively documented. ```{note} Make sure to use a modern IDE to take advantage of auto-completion and inline documentation! ``` ### Factories The high-level API uses factories extensively: - **Model Factories**: Create neural networks (e.g., {class}`~tianshou.highlevel.module.intermediate.IntermediateModuleFactoryAtariDQN`) - **Environment Factories**: Create and configure environments - **Optimizer Factories**: Create optimizers with specific configurations ### Extensibility The high-level API is designed to be extensible. You can create custom factories (e.g. for your own models or your own environments) by subclassing the appropriate base classes and then use them in the experiment builder. If we have created a torch module in `CustomNetwork`, which we want to use within our policy, we simply need to define a factory for it in order to apply it in the high-level API: ```python from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory, IntermediateModule class CustomNetFactory(IntermediateModuleFactory): def __init__(self, hidden_sizes: tuple[int, ...] 
= (128, 128)): self.hidden_sizes = hidden_sizes def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: obs_shape = envs.get_observation_shape() action_shape = envs.get_action_shape() # Your custom network creation logic net = CustomNetwork( obs_shape=obs_shape, action_shape=action_shape, hidden_sizes=self.hidden_sizes, ).to(device) return IntermediateModule(net, net.output_dim) experiment = ( DQNExperimentBuilder(...) .with_model_factory(CustomNetFactory(hidden_sizes=(256, 256))) .build() ) ``` ## Key Concepts in the Procedural API ### Core Components You manually create and connect ... 1. **environments**: e.g. using `gym.make()` and vectorization ({class}`~tianshou.env.DummyVectorEnv`, {class}`~tianshou.env.SubprocVectorEnv`) 2. **networks**: using {class}`~tianshou.utils.net.common.Net` or other PyTorch modules 3. **policies**: using algorithm-specific policy classes (e.g., {class}`~tianshou.algorithm.modelfree.dqn.DiscreteQLearningPolicy`) 4. **algorithms**: using algorithm classes (e.g., {class}`~tianshou.algorithm.modelfree.dqn.DQN`, {class}`~tianshou.algorithm.modelfree.ppo.PPO`, {class}`~tianshou.algorithm.modelfree.sac.SAC`) 5. **collectors**: using {class}`~tianshou.data.Collector` to gather experience 6. **buffers**: using {class}`~tianshou.data.buffer.VectorReplayBuffer` or {class}`~tianshou.data.buffer.ReplayBuffer` 7. **trainers**: using the respective trainer class and corresponding parameter class (e.g., {class}`~tianshou.trainer.OffPolicyTrainer` and {class}`~tianshou.trainer.OffPolicyTrainerParams`) ### Training Loop The training is executed via `algorithm.run_training()`, which takes a trainer parameter object. You can alternatively implement custom training loops (or even your own trainer class) for maximum flexibility. 
## Additional Resources - **high-level API examples**: See `examples/` directory (scripts ending in `_hl.py`) - **procedural API examples**: See `examples/` directory (scripts without the `_hl` suffix) ================================================ FILE: docs/01_user_guide/02_core_abstractions.md ================================================ # Core Abstractions Tianshou's architecture is built around a number of key abstractions that work together to provide a modular and flexible reinforcement learning framework. This document describes the conceptual foundation and functionality of each abstraction, helping you understand how they interact to enable RL agent training. Knowing these abstractions is primarily relevant when using the procedural API – and particularly when implementing one's own learning algorithms. ## Algorithm The **{class}`~tianshou.algorithm.algorithm_base.Algorithm`** is the central abstraction representing the core of a reinforcement learning method (such as DQN, PPO, or SAC). It implements the key steps within the {ref}`learning process <structuring-the-process>`, containing a {ref}`policy` and defining how to update it from experience data. Since an Algorithm contains neural networks and manages their training, the class inherits from `torch.nn.Module`. ### Core Responsibilities An Algorithm implements the details of an {ref}`update step <structuring-the-process>`: 1. **preprocessing**: Before the actual update begins, the algorithm prepares the training data. This includes computing derived quantities that depend on temporal sequences, such as n-step returns, GAE advantages, or terminal state handling. The {meth}`~tianshou.algorithm.algorithm_base.Algorithm._preprocess_batch` method handles this phase, often leveraging static methods like {meth}`~tianshou.algorithm.algorithm_base.Algorithm.compute_nstep_return` and {meth}`~tianshou.algorithm.algorithm_base.Algorithm.compute_episodic_return` to efficiently compute returns using the buffer's temporal structure. 2.
**network update**: The algorithm performs the actual neural network updates based on its specific learning method. Each algorithm implements its own {meth}`~tianshou.algorithm.algorithm_base.Algorithm._update_with_batch` logic that defines how to update the policy networks using the preprocessed batch data. 3. **postprocessing**: After the update, the algorithm may perform cleanup operations, such as updating prioritized replay buffer weights or other algorithm-specific bookkeeping. ### Learning Orchestration The Algorithm orchestrates the {ref}`update step <structuring-the-process>` through its {meth}`~tianshou.algorithm.algorithm_base.Algorithm.update` method, which ensures these three phases execute in proper sequence. It also manages optimizer state and learning rate schedulers, making them available for state persistence through {meth}`~tianshou.algorithm.algorithm_base.Algorithm.state_dict` and {meth}`~tianshou.algorithm.algorithm_base.Algorithm.load_state_dict` methods. Each algorithm type (on-policy, off-policy, offline) creates its appropriate trainer through the {meth}`~tianshou.algorithm.algorithm_base.Algorithm.create_trainer` method, establishing the connection between the learning logic and the training loop. (policy)= ## Policy The **{class}`~tianshou.algorithm.algorithm_base.Policy`** represents the agent's decision-making component, i.e. the mapping from observations to actions. While the Algorithm defines how to learn, the Policy defines what is learned and how to act. Like Algorithm, the class inherits from `torch.nn.Module`. ### States of Operation A Policy operates in two main modes: - **training mode**: During training, the policy may employ exploration strategies, sample from action distributions, or add noise to encourage discovery.
Training mode is further divided into: - *collecting state*: When gathering experience from environment interaction - *updating state*: When performing network updates during learning - **testing/inference mode**: During evaluation, the policy typically acts deterministically or uses the mode of predicted distributions to showcase learned behavior without exploration. The flag `is_within_training_step` controls the collection strategy, distinguishing between training and inference behavior. ### Key Methods The Policy provides several essential methods: - **{meth}`~tianshou.algorithm.algorithm_base.Policy.forward`**: The core computation method that processes batched observations to produce action distributions or Q-values. It takes a batch of environment data and optional hidden state (for recurrent policies), returning a batch containing at minimum the "act" key, and potentially "state" (hidden state) and "policy" (intermediate results to be stored in the buffer). - **{meth}`~tianshou.algorithm.algorithm_base.Policy.compute_action`**: A convenient method for inference that takes a single observation and returns a concrete action suitable for the environment. This method internally calls `forward` with proper batching and unbatching. - **{meth}`~tianshou.algorithm.algorithm_base.Policy.map_action`**: Transforms the raw neural network output to the environment's action space format, handling any necessary scaling or discretization. The separation between `forward` (which works with batches) and {meth}`~tianshou.algorithm.algorithm_base.Policy.compute_action` (which works with single observations) provides efficiency during training and convenience during inference. ## Collector The class **{class}`~tianshou.data.Collector`** bridges the gap between the policy and the environment(s), managing the process of gathering experience data. It enables efficient interaction with both single environments and vectorized environments (multiple parallel environments). 
### Data Collection The Collector's primary method, {meth}`~tianshou.data.Collector.collect`, orchestrates the environment interaction loop. It can collect either: - a specified number of steps (`n_step`): useful for maintaining consistent training batch sizes - a specified number of episodes (`n_episode`): useful for evaluation or when episode-level statistics are important During collection, the Collector ... 1. obtains observations from the environment(s), 2. calls the policy to compute actions, 3. steps the environment(s) with these actions, 4. stores the resulting transitions (observation, action, reward, next observation, termination flags, and info) in the replay buffer, 5. manages episode boundaries and reset logic, 6. collects statistics such as episode returns, lengths, and collection speed. ### Hooks and Extensibility The Collector supports customization through hooks that can be triggered at different points in the collection process: - **step hooks**: called after each environment step - **episode done hooks**: called when episodes complete These hooks enable custom logging, curriculum learning, or other dynamic behaviors during data collection. ### Vectorized Environments The Collector seamlessly handles vectorized environments, where multiple environment instances run in parallel. This significantly speeds up data collection while maintaining correct episode boundaries and statistics for each environment instance. ## Trainer The **{class}`~tianshou.trainer.Trainer`** orchestrates the complete training loop, coordinating data collection, policy updates, and evaluation. It provides the high-level control flow that brings all components together. ### Trainer Types Tianshou provides three main trainer types, each suited to different algorithm families: - **{class}`~tianshou.trainer.OnPolicyTrainer`**: for algorithms that must learn from freshly collected data (e.g., PPO, A2C). 
After each collection phase, the buffer is used for updates and thereafter is cleared. - **{class}`~tianshou.trainer.OffPolicyTrainer`**: for algorithms that can learn from any past experience (e.g., DQN, SAC, DDPG). Data accumulates in the replay buffer over time, and updates sample from this growing pool of experience. - **{class}`~tianshou.trainer.OfflineTrainer`**: for algorithms that learn exclusively from a fixed dataset without any environment interaction (e.g., BCQ, CQL). ### Training Loop Structure The training process is organized into epochs, where each epoch consists of: 1. **data collection**: The trainer uses the train collector to gather experience according to its algorithm type's needs 2. **policy update**: The algorithm performs one or more update steps using the collected data 3. **evaluation**: Periodically, the trainer uses the test collector to evaluate the current policy's performance 4. **logging**: Statistics from collection, updates, and evaluation are logged 5. **checkpointing**: The best policy (according to a scoring function) is saved The trainer handles the detailed choreography of these steps, including determining when to collect more data, how many update steps to perform, when to evaluate, and when to stop training (based on maximum epochs, timesteps, or early stopping criteria). ### Configuration Trainers are configured through parameter dataclasses ({class}`~tianshou.trainer.OnPolicyTrainerParams`, {class}`~tianshou.trainer.OffPolicyTrainerParams`, {class}`~tianshou.trainer.OfflineTrainerParams`) that specify in particular: - training duration (number of epochs, steps per epoch) - collectors for training and testing - update frequency and batch size - evaluation frequency - logging and checkpointing settings - early stopping criteria ## Batch The class **{class}`~tianshou.data.Batch`** is Tianshou's flexible data structure for passing information between components. 
It serves as the lingua franca of the framework, carrying everything from raw environment observations to computed returns and policy outputs. ### Design Philosophy Batch is designed to be ... - **flexible**: can contain any key-value pairs, with nested structures supported, - **numpy/torch-compatible**: automatically converts lists to arrays and seamlessly works with both NumPy arrays and PyTorch tensors, - **sliceable**: supports indexing and slicing operations that work across all contained data, - **composable**: can be concatenated, stacked, and split to support batching operations. ### Type Safety with BatchProtocol While `Batch` provides a flexible, dictionary-like structure for holding arbitrary data, this flexibility can make it challenging to statically type-check which attributes are present in a batch at any given point in the code. To address this, Tianshou uses **{class}`~tianshou.data.batch.BatchProtocol`** and derived protocols to specify the expected attributes while keeping the actual runtime type as `Batch`. BatchProtocol is a Python `Protocol` (from `typing.Protocol`) that defines the interface of a Batch object, specifying which operations and attributes should be available. 
More importantly, Tianshou provides a rich set of derived protocols in {mod}`tianshou.data.types` that describe batches with specific sets of attributes commonly used throughout the framework: - **{class}`~tianshou.data.types.ObsBatchProtocol`**: Contains `obs` and `info` - the minimal batch for policy forward passes - **{class}`~tianshou.data.types.RolloutBatchProtocol`**: Adds `obs_next`, `act`, `rew`, `terminated`, and `truncated` - typical data from replay buffer sampling - **{class}`~tianshou.data.types.BatchWithReturnsProtocol`**: Extends RolloutBatchProtocol with `returns` computed from rewards - **{class}`~tianshou.data.types.BatchWithAdvantagesProtocol`**: Includes `adv` (advantages) and `v_s` (value estimates) for policy gradient methods - **{class}`~tianshou.data.types.ActStateBatchProtocol`**: Contains `act` and `state` for policy outputs, especially with RNN support - **{class}`~tianshou.data.types.ModelOutputBatchProtocol`**: Adds `logits` to action and state information - **{class}`~tianshou.data.types.DistBatchProtocol`**: Contains action distributions (`dist`) for stochastic policies - **{class}`~tianshou.data.types.PrioBatchProtocol`**: Includes `weight` for prioritized experience replay These protocols serve as type hints in function signatures throughout Tianshou, making it explicit what attributes are expected and available. For example, a policy's `forward` method might accept an `ObsBatchProtocol` and return an `ActStateBatchProtocol`, clearly documenting the data contract. Despite these type annotations, the actual objects remain flexible `Batch` instances at runtime, preserving Tianshou's dynamic nature while improving code clarity and IDE support. ### Common Use Cases Batches flow through the system carrying different types of information: 1. **environment data**: observations, rewards, done flags, and info from environment steps 2. **policy outputs**: actions, hidden states, and intermediate computations 3. 
**training data**: returns, advantages, and other computed quantities needed for learning 4. **sampling results**: batches sampled from the replay buffer for training ### Operations Key operations on batches include: - **attribute access**: dot notation (`batch.obs`) or dictionary-style access (`batch['obs']`) - **slicing**: extract subsets with standard indexing (`batch[0:10]`, `batch[[1,3,5]]`) - **stacking**: combine multiple batches along a new dimension - **type conversion**: convert between NumPy and PyTorch with `to_numpy()` and `to_torch()` - **null handling**: detect and remove null values with `hasnull()`, `isnull()`, and `dropnull()` The first dimension of all data in a Batch represents the batch size, enabling vectorized operations. ## Buffer A **buffer** (i.e. class {class}`~tianshou.data.buffer.ReplayBuffer` and its variants) manages the storage and retrieval of experience data. It acts as the memory of the learning system, preserving the temporal structure of episodes while providing efficient access patterns. ### Storage Structure Buffers store data in a circular queue fashion with a fixed maximum size. When the buffer fills, new data overwrites the oldest stored experiences. 
All data is stored within a single underlying Batch object, with the buffer managing: - **pointer tracking**: current insertion position - **episode boundaries**: which transitions belong to which episodes - **temporal relationships**: the sequential order of transitions ### Reserved Keys Buffers use a standard set of keys for storing transitions: - `obs`: Observation at time t - `act`: Action taken at time t - `rew`: Reward received at time t - `terminated`: True if the episode ended naturally at time t - `truncated`: True if the episode was cut off at time t (e.g., time limit) - `done`: Automatically inferred as `terminated or truncated` - `obs_next`: Observation at time t+1 - `info`: Additional information from the environment - `policy`: Intermediate policy computations to be stored ### Core Operations **adding data**: The {meth}`~tianshou.data.buffer.buffer_base.ReplayBuffer.add` method stores new transitions, automatically handling episode boundaries and computing episode statistics (return, length) when episodes complete. **sampling**: The {meth}`~tianshou.data.buffer.buffer_base.ReplayBuffer.sample` method retrieves batches of experiences for training, returning both the sampled batch and the corresponding indices. The sample size can be specified, or set to 0 to retrieve all available data. **temporal navigation**: The {meth}`~tianshou.data.buffer.buffer_base.ReplayBuffer.prev` and {meth}`~tianshou.data.buffer.buffer_base.ReplayBuffer.next` methods enable traversal along the temporal sequence, respecting episode boundaries. This is essential for computing n-step returns and other time-dependent quantities. **persistence**: Buffers support saving and loading via pickle or HDF5 format, enabling dataset collection and offline learning.
### Buffer Variants Tianshou provides specialized buffer types: - **{class}`~tianshou.data.buffer.buffer_base.ReplayBuffer`**: the standard buffer for single environments - **{class}`~tianshou.data.buffer.vecbuf.VectorReplayBuffer`**: manages separate sub-buffers for multiple parallel environments while maintaining chronological order - **{class}`~tianshou.data.buffer.prio.PrioritizedReplayBuffer`**: samples transitions based on their TD-error or other priority metrics, using an efficient segment tree implementation ### Advanced Features Buffers support sophisticated use cases: - **frame stacking**: automatically stacks consecutive observations (useful for RNN inputs or Atari) - **memory optimization**: option to skip storing next observations (useful for Atari where they can be inferred) - **multi-modal observations**: handle observations with multiple components (e.g., image + vector) ## Logger The **{class}`~tianshou.utils.logger.logger_base.BaseLogger`** abstraction provides a unified interface for recording and tracking training progress, metrics, and statistics. It decouples the training loop from the specifics of where and how data is logged. 
### Purpose Loggers serve several essential functions: - **progress tracking**: record timesteps, episodes, and epochs as training progresses - **metric collection**: store performance indicators like rewards, losses, and success rates - **experiment organization**: manage different data scopes (training, testing, updating) - **reproducibility**: save training curves and hyperparameters for later analysis ### Logging Scopes The framework organizes logged data into distinct scopes: - **train data**: metrics from the training collector (episode returns, steps, collection speed) - **test data**: evaluation metrics from the test collector - **update data**: learning statistics from the algorithm (losses, gradients, learning rates) - **info data**: additional custom metrics or metadata Each scope has a corresponding log method (`log_train_data`, `log_test_data`, `log_update_data`, `log_info_data`) that the trainer calls at appropriate times. ### Implementations Tianshou provides several logger implementations: - **{class}`~tianshou.utils.logger.tensorboard.TensorboardLogger`**: writes to TensorBoard format for visualization with TensorBoard - **{class}`~tianshou.utils.logger.wandb.WandbLogger`**: integrates with Weights & Biases for cloud-based experiment tracking All implementations inherit from {class}`~tianshou.utils.logger.logger_base.BaseLogger` and share a common interface, making it easy to switch between logging backends or use multiple loggers simultaneously. ### Data Preparation Before writing, loggers prepare data through the `prepare_dict_for_logging` method, which can filter, transform, or aggregate metrics. The `write` method then persists the prepared data to the logging backend with an associated step count. ## How They Work Together These seven abstractions collaborate to enable reinforcement learning: 1. The **Trainer** initializes and orchestrates the training process. 2. The **Collector** uses the **Policy** to gather experience from environments. 3. 
Collected transitions are stored in the **Buffer**, from which they are later retrieved as **Batches**. 4. The **Algorithm** samples from the **Buffer**, preprocesses the data, and updates the **Policy**. 5. The **Logger** records metrics throughout the process. 6. The cycle repeats until training completes. This modular design allows each component to focus on its specific responsibility while maintaining clean interfaces. You can customize individual components (e.g., implementing a new Algorithm or Buffer) without affecting the others, making Tianshou both powerful and flexible. ================================================ FILE: docs/01_user_guide/index.rst ================================================ User Guide ========== The user guide provides an introduction to core concepts, establishes the glossary of terms, introduces Tianshou's dual API architecture and provides an overview of important abstractions. ================================================ FILE: docs/02_deep_dives/0_intro.md ================================================ # Deep Dives Our deep dives are a collection of executable tutorials on some of the internal representations used by Tianshou. They are provided as notebooks, so you can run them directly in Colab or download them and run them locally. ================================================ FILE: docs/02_deep_dives/L1_Batch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Batch: Tianshou's Core Data Structure\n", "\n", "The `Batch` class is Tianshou's fundamental data structure for efficiently storing and manipulating heterogeneous data in reinforcement learning.
This tutorial provides comprehensive guidance on understanding its conceptual foundations, operational behavior, and best practices.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "from typing import cast\n", "\n", "import numpy as np\n", "import torch\n", "from torch.distributions import Categorical, Normal\n", "\n", "from tianshou.data import Batch\n", "from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Introduction: Why Batch?\n", "\n", "### The Challenge in Reinforcement Learning\n", "\n", "Reinforcement learning algorithms face a fundamental data management challenge:\n", "\n", "1. **Diverse Data Requirements**: Different RL algorithms need different data fields:\n", " - Basic algorithms: `state`, `action`, `reward`, `done`, `next_state`\n", " - Actor-Critic: additionally `advantages`, `returns`, `values`\n", " - Policy Gradient: additionally `log_probs`, `old_log_probs`\n", " - Off-policy: additionally `priority_weights`\n", "\n", "2. **Heterogeneous Observation Spaces**: Environments return diverse observation types:\n", " - Simple: vectors (`np.array([1.0, 2.0, 3.0])`)\n", " - Complex: images (`np.array(shape=(84, 84, 3))`)\n", " - Hybrid: dictionaries combining multiple modalities\n", " ```python\n", " obs = {\n", " 'camera': np.array(shape=(64, 64, 3)),\n", " 'velocity': np.array([1.2, 0.5]),\n", " 'inventory': np.array([5, 2, 0])\n", " }\n", " ```\n", "\n", "3. 
**Data Flow Across Components**: Data must flow seamlessly through:\n", " - Collectors (gathering experience from environments)\n", " - Replay Buffers (storing and sampling transitions)\n", " - Policies and Algorithms (learning and inference)\n", "\n", "### Why Not Alternatives?\n", "\n", "#### Plain Dictionaries\n", "Dictionaries lack essential features:\n", "```python\n", "data = {'obs': np.array([1, 2]), 'reward': np.array([1.0, 2.0])}\n", "```\n", "\n", "They would work in principle but have no shape/length semantics, no indexing, and no type safety.\n", "\n", "#### TensorDict\n", "While `TensorDict` (used in TorchRL) is a powerful alternative:\n", "- **Batch supports arbitrary objects**, not just tensors (useful for object-dtype arrays, custom types)\n", "- **Batch has better type checking** via `BatchProtocol` (enables IDE autocompletion)\n", "- **Batch preceded TensorDict** and provides a stable foundation for Tianshou\n", "- **TensorDict isn't part of core PyTorch** (external dependency)\n", "\n", "### What is Batch?\n", "\n", "**Batch = Dictionary + Array hybrid with RL-specific features**\n", "\n", "Key capabilities:\n", "- **Dict-like**: Key-value storage with attribute access (`batch.obs`, `batch.reward`)\n", "- **Array-like**: Shape, indexing, slicing (`batch[0]`, `batch[:10]`, `batch.shape`)\n", "- **Hierarchical**: Nested structures for complex data\n", "- **Type-safe**: Protocol-based typing for IDE support\n", "- **RL-aware**: Special handling for distributions, missing values, heterogeneous aggregation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Core Concepts\n", "\n", "### Hierarchical Named Tensors\n", "\n", "Batch stores **hierarchical named tensors** - collections of tensors whose identifiers form a structured hierarchy. Consider tensors `[t1, t2, t3, t4]` with names `[name1, name2, name3, name4]`, where `name1` and `name2` are under namespace `name0`.
The fully qualified name of `t1` is `name0.name1`.\n", "\n", "### Tree Structure Visualization\n", "\n", "The structure can be visualized as a tree with:\n", "- **Root**: The Batch object itself\n", "- **Internal nodes**: Keys (names)\n", "- **Leaf nodes**: Values (scalars, arrays, tensors)\n", "\n", "```mermaid\n", "graph TD\n", " root[\"Batch (root)\"]\n", " root --> obs[\"obs\"]\n", " root --> act[\"act\"]\n", " root --> rew[\"rew\"]\n", " obs --> camera[\"camera\"]\n", " obs --> sensory[\"sensory\"]\n", " camera --> cam_data[\"np.array(3,3)\"]\n", " sensory --> sens_data[\"np.array(5,)\"]\n", " act --> act_data[\"np.array(2,)\"]\n", " rew --> rew_data[\"3.66\"]\n", " \n", " style root fill:#e1f5ff\n", " style obs fill:#fff4e1\n", " style act fill:#fff4e1\n", " style rew fill:#fff4e1\n", " style camera fill:#ffe1f5\n", " style sensory fill:#ffe1f5\n", " style cam_data fill:#e8f5e1\n", " style sens_data fill:#e8f5e1\n", " style act_data fill:#e8f5e1\n", " style rew_data fill:#e8f5e1\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Example: hierarchical structure\n", "data = {\n", " \"action\": np.array([1.0, 2.0, 3.0]),\n", " \"reward\": 3.66,\n", " \"obs\": {\n", " \"camera\": np.zeros((3, 3)),\n", " \"sensory\": np.ones(5),\n", " },\n", "}\n", "\n", "batch = Batch(data)\n", "print(batch)\n", "print(\"\\nAccessing nested values:\")\n", "print(f\"batch.obs.camera.shape = {batch.obs.camera.shape}\")\n", "print(f\"batch.obs.sensory = {batch.obs.sensory}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data Flow in RL Pipeline\n", "\n", "Batch facilitates data flow throughout the RL pipeline:\n", "\n", "```mermaid\n", "graph LR\n", " A[Collector] -->|ActBatchProtocol| B[Environment]\n", " B[Environment + Action] -->|RolloutBatchProtocol| C[Replay Buffer]\n", " C -->|RolloutBatchProtocol| D[Policy]\n", " D -->|ActBatchProtocol| A\n", " D -->|BatchWithAdvantages| E[Algorithm/Trainer]\n", " 
E --> D\n", " \n", " style A fill:#e1f5ff\n", " style B fill:#fff4e1\n", " style C fill:#ffe1f5\n", " style D fill:#e8f5e1\n", " style E fill:#f5e1e1\n", "```\n", "\n", "Each arrow represents a specific `BatchProtocol` that defines what fields are expected at that stage." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Basic Operations\n", "\n", "### 3.1 Construction\n", "\n", "Batch objects can be constructed in several ways:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# From keyword arguments\n", "batch1 = Batch(a=4, b=[5, 5], c=\"hello\")\n", "print(\"From kwargs:\", batch1)\n", "\n", "# From dictionary\n", "batch2 = Batch({\"a\": 4, \"b\": [5, 5], \"c\": \"hello\"})\n", "print(\"\\nFrom dict:\", batch2)\n", "\n", "# From list of dictionaries (automatically stacked)\n", "batch3 = Batch([{\"a\": 1, \"b\": 2}, {\"a\": 3, \"b\": 4}])\n", "print(\"\\nFrom list of dicts:\", batch3)\n", "\n", "# Nested batch\n", "batch4 = Batch(obs=Batch(x=1, y=2), act=5)\n", "print(\"\\nNested:\", batch4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2 Content Rules\n", "\n", "Understanding what Batch can store and how it converts data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Keys must be strings\n", "batch = Batch()\n", "batch.key1 = \"value\"\n", "batch.key2 = np.array([1, 2, 3])\n", "print(\"Keys:\", list(batch.keys()))\n", "\n", "# Automatic conversions\n", "demo = Batch(\n", " scalar_int=5, # → np.array(5)\n", " scalar_float=3.14, # → np.array(3.14)\n", " list_nums=[1, 2, 3], # → np.array([1, 2, 3])\n", " list_mixed=[1, \"hello\", None], # → np.array([1, \"hello\", None], dtype=object)\n", " dict_val={\"x\": 1, \"y\": 2}, # → Batch(x=1, y=2)\n", ")\n", "\n", "print(\"\\nAutomatic conversions:\")\n", "print(f\"scalar_int type: {type(demo.scalar_int)}, value: {demo.scalar_int}\")\n", "print(f\"list_nums type: 
{type(demo.list_nums)}, dtype: {demo.list_nums.dtype}\")\n", "print(f\"list_mixed dtype: {demo.list_mixed.dtype}\")\n", "print(f\"dict_val type: {type(demo.dict_val)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Important conversions:**\n", "- Lists of numbers → NumPy arrays\n", "- Lists with mixed types → Object-dtype arrays\n", "- Dictionaries → Batch objects (recursively)\n", "- Scalars → NumPy scalars" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.3 Access Patterns\n", "\n", "**Important: Understanding Iteration**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch = Batch(a=[1, 2, 3], b=[4, 5, 6])\n", "\n", "# Attribute vs dictionary access (equivalent)\n", "print(\"Attribute access:\", batch.a)\n", "print(\"Dict access:\", batch[\"a\"])\n", "\n", "# Getting keys\n", "print(\"\\nKeys:\", list(batch.keys()))\n", "\n", "# Gotcha: Iteration is array like, not over keys\n", "print(\"\\nIteration behavior:\")\n", "print(\"for x in batch iterates over batch[0], batch[1], ..., NOT keys!\")\n", "for i, item in enumerate(batch):\n", " print(f\"batch[{i}] = {item}\")\n", "\n", "# This is different from dict behavior!\n", "regular_dict = {\"a\": [1, 2, 3], \"b\": [4, 5, 6]}\n", "print(\"\\nCompare with dict iteration (iterates over keys):\")\n", "for key in regular_dict:\n", " print(f\"key = {key}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "" }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.4 Indexing & Slicing\n", "\n", "Batch supports NumPy-like indexing and slicing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5.0, -5.0], [1.0, -2.0]])\n", "\n", "print(\"Original batch shape:\", batch.shape)\n", "print(\"Original batch length:\", len(batch))\n", "\n", "# Single index\n", 
"print(\"\\nbatch[0]:\")\n", "print(batch[0])\n", "\n", "# Slicing\n", "print(\"\\nbatch[:1]:\")\n", "print(batch[:1])\n", "\n", "# Advanced indexing\n", "print(\"\\nbatch[[0, 1]]:\")\n", "print(batch[[0, 1]])\n", "\n", "# Multi-dimensional indexing\n", "print(\"\\nbatch[:, 0] (first column of all arrays):\")\n", "print(batch[:, 0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Broadcasting and in-place operations\n", "batch[:, 1] += 10\n", "print(\"After batch[:, 1] += 10:\")\n", "print(batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.5 Stack, Concatenate, and Split\n", "\n", "Combining and splitting batches:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "# Stack: adds a new dimension\nbatch1 = Batch(a=np.array([1, 2]), b=np.array([5, 6]))\nbatch2 = Batch(a=np.array([3, 4]), b=np.array([7, 8]))\n\nstacked = Batch.stack([batch1, batch2])\nprint(\"Stacked:\")\nprint(stacked)\nprint(f\"Shape: {stacked.shape}\")\n\n# Concatenate: extends along existing dimension\nconcatenated = Batch.cat([batch1, batch2])\nprint(\"\\nConcatenated:\")\nprint(concatenated)\nprint(f\"Shape: {concatenated.shape}\")" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Split\n", "batch = Batch(a=np.arange(10), b=np.arange(10, 20))\n", "splits = list(batch.split(size=3, shuffle=False))\n", "print(f\"Split into {len(splits)} batches:\")\n", "for i, split in enumerate(splits):\n", " print(f\"Split {i}: a={split.a}, length={len(split)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.6 Data Type Conversion\n", "\n", "Converting between NumPy and PyTorch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create batch with NumPy arrays\n", "batch = Batch(a=np.zeros((3, 4)), b=np.ones(5))\n", "print(\"Original (NumPy):\")\n", "print(f\"batch.a type: 
{type(batch.a)}\")\n", "\n", "# Convert to PyTorch (in-place)\n", "batch.to_torch_(dtype=torch.float32, device=\"cpu\")\n", "print(\"\\nAfter to_torch_():\")\n", "print(f\"batch.a type: {type(batch.a)}\")\n", "print(f\"batch.a dtype: {batch.a.dtype}\")\n", "\n", "# Convert back to NumPy (in-place)\n", "batch.to_numpy_()\n", "print(\"\\nAfter to_numpy_():\")\n", "print(f\"batch.a type: {type(batch.a)}\")\n", "\n", "# Non-in-place versions return a new batch\n", "batch_torch = batch.to_torch()\n", "print(\"\\nOriginal batch unchanged:\", type(batch.a))\n", "print(\"New batch:\", type(batch_torch.a))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Type Safety with Protocols\n", "\n", "### Why Protocols?\n", "\n", "Batch needs to be **flexible** (not fixed fields like dataclasses) but we still want **type safety** and **IDE autocompletion**. Protocols provide the best of both worlds:\n", "\n", "- **Runtime flexibility**: Add any fields dynamically\n", "- **Static type checking**: Type checkers (mypy, pyright) verify correct usage\n", "- **IDE support**: Autocompletion for expected fields\n", "\n", "### What is BatchProtocol?\n", "\n", "A `Protocol` defines an interface without implementation. 
Think of it as a contract: \"any object with these fields is valid.\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Creating a typed batch using cast\n", "# This enables IDE autocompletion and type checking\n", "\n", "# ActBatchProtocol: just needs 'act' field\n", "act_batch = cast(ActBatchProtocol, Batch(act=np.array([1, 2, 3])))\n", "print(\"ActBatchProtocol:\", act_batch.act)\n", "\n", "# ObsBatchProtocol: needs 'obs' and 'info' fields\n", "obs_batch = cast(\n", " ObsBatchProtocol,\n", " Batch(obs=np.array([[1.0, 2.0], [3.0, 4.0]]), info=np.array([{}, {}], dtype=object)),\n", ")\n", "print(\"\\nObsBatchProtocol:\", obs_batch.obs)\n", "\n", "# RolloutBatchProtocol: needs obs, obs_next, act, rew, terminated, truncated\n", "rollout_batch = cast(\n", " RolloutBatchProtocol,\n", " Batch(\n", " obs=np.array([[1.0, 2.0], [3.0, 4.0]]),\n", " obs_next=np.array([[2.0, 3.0], [4.0, 5.0]]),\n", " act=np.array([0, 1]),\n", " rew=np.array([1.0, 2.0]),\n", " terminated=np.array([False, True]),\n", " truncated=np.array([False, False]),\n", " info=np.array([{}, {}], dtype=object),\n", " ),\n", ")\n", "print(\"\\nRolloutBatchProtocol reward:\", rollout_batch.rew)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Protocol Hierarchy\n", "\n", "Tianshou defines a hierarchy of protocols for different use cases:\n", "\n", "```mermaid\n", "graph TD\n", " BP[BatchProtocol
Base protocol] --> OBP[ObsBatchProtocol
obs, info]\n", " BP --> ABP[ActBatchProtocol
act]\n", " ABP --> ASBP[ActStateBatchProtocol
act, state]\n", " OBP --> RBP[RolloutBatchProtocol
+obs_next, act, rew,
terminated, truncated]\n", " RBP --> BWRP[BatchWithReturnsProtocol
+returns]\n", " BWRP --> BWAP[BatchWithAdvantagesProtocol
+adv, v_s]\n", " ASBP --> MOBP[ModelOutputBatchProtocol
+logits]\n", " MOBP --> DBP[DistBatchProtocol
+dist]\n", " DBP --> DLPBP[DistLogProbBatchProtocol
+log_prob]\n", " BWAP --> LOPBP[LogpOldProtocol
+logp_old]\n", " \n", " style BP fill:#e1f5ff\n", " style OBP fill:#fff4e1\n", " style ABP fill:#fff4e1\n", " style RBP fill:#ffe1f5\n", " style BWRP fill:#e8f5e1\n", " style BWAP fill:#e8f5e1\n", " style DBP fill:#f5e1e1\n", " style LOPBP fill:#e1e1f5\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using Protocols in Functions\n", "\n", "Protocols enable type-safe function signatures:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def process_observations(batch: ObsBatchProtocol) -> np.ndarray:\n", " \"\"\"Function that expects observations.\n", "\n", " IDE will autocomplete batch.obs and batch.info!\n", " Type checker will verify these fields exist.\n", " \"\"\"\n", " # IDE knows batch.obs exists\n", " return batch.obs if isinstance(batch.obs, np.ndarray) else np.array(batch.obs)\n", "\n", "\n", "def compute_advantage(batch: RolloutBatchProtocol) -> np.ndarray:\n", " \"\"\"Function that expects rollout data.\n", "\n", " IDE will autocomplete batch.rew, batch.obs_next, etc.\n", " \"\"\"\n", " # Simplified advantage computation\n", " return batch.rew # IDE knows this exists\n", "\n", "\n", "# Example usage\n", "obs_data = Batch(obs=np.array([1, 2, 3]), info=np.array([{}], dtype=object))\n", "result = process_observations(obs_data)\n", "print(\"Processed obs:\", result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Key Protocol Types:**\n", "\n", "- `ActBatchProtocol`: Just actions (for simple policies)\n", "- `ObsBatchProtocol`: Observations and info\n", "- `RolloutBatchProtocol`: Complete transitions (obs, obs_next, act, rew, terminated, truncated)\n", "- `BatchWithReturnsProtocol`: Rollouts + computed returns\n", "- `BatchWithAdvantagesProtocol`: Returns + advantages and values\n", "- `DistBatchProtocol`: Contains distribution objects\n", "- `LogpOldProtocol`: For importance sampling (PPO, etc.)\n", "\n", "See `tianshou/data/types.py` for the complete list!"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Distribution Slicing\n", "\n", "### Why Special Handling?\n", "\n", "PyTorch `Distribution` objects need special slicing because they're not simple arrays. When you slice `batch[0:2]`, Tianshou needs to slice the underlying distribution parameters correctly.\n", "\n", "### Supported Distributions\n", "\n", "Tianshou supports slicing for:\n", "- `Categorical`: Discrete distributions\n", "- `Normal`: Continuous Gaussian distributions\n", "- `Independent`: Wraps other distributions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Categorical distribution\n", "probs = torch.tensor([[0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])\n", "dist = Categorical(probs=probs)\n", "batch = Batch(dist=dist, values=np.array([1, 2, 3]))\n", "\n", "print(\"Original batch length:\", len(batch))\n", "print(\"Original dist probs shape:\", batch.dist.probs.shape)\n", "\n", "# Slicing automatically handles the distribution\n", "sliced = batch[0:2]\n", "print(\"\\nSliced batch length:\", len(sliced))\n", "print(\"Sliced dist probs shape:\", sliced.dist.probs.shape)\n", "print(\"Sliced values:\", sliced.values)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Normal distribution\n", "loc = torch.tensor([0.0, 1.0, 2.0])\n", "scale = torch.tensor([1.0, 1.0, 1.0])\n", "normal_dist = Normal(loc=loc, scale=scale)\n", "batch_normal = Batch(dist=normal_dist, actions=np.array([0.5, 1.5, 2.5]))\n", "\n", "print(\"Normal distribution batch:\")\n", "print(f\"Original mean: {batch_normal.dist.mean}\")\n", "\n", "# Index a single element\n", "single = batch_normal[1]\n", "print(f\"\\nSingle element mean: {single.dist.mean}\")\n", "print(f\"Single element action: {single.actions}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Converting to At Least 2D\n", "\n", "Sometimes you need to ensure distributions have a batch dimension:" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tianshou.data.batch import dist_to_atleast_2d\n", "\n", "# Scalar distribution (no batch dimension)\n", "scalar_dist = Categorical(probs=torch.tensor([0.3, 0.7]))\n", "print(\"Scalar dist batch_shape:\", scalar_dist.batch_shape)\n", "\n", "# Convert to have batch dimension\n", "batched_dist = dist_to_atleast_2d(scalar_dist)\n", "print(\"Batched dist batch_shape:\", batched_dist.batch_shape)\n", "\n", "# For entire batch\n", "scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3)))\n", "print(\"\\nBefore to_at_least_2d:\", scalar_batch.dist.batch_shape)\n", "\n", "batch_2d = scalar_batch.to_at_least_2d()\n", "print(\"After to_at_least_2d:\", batch_2d.dist.batch_shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Use Cases\n", "\n", "Distribution slicing is used in:\n", "- **Policy sampling**: When policies output distributions, slicing batches preserves distribution structure\n", "- **Replay buffer sampling**: Distributions are stored and retrieved correctly\n", "- **Advantage computation**: Computing log probabilities on subsets of data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Advanced Topics\n", "\n", "### 6.1 Key Reservation\n", "\n", "Sometimes you know what keys you'll need but don't have values yet. 
Reserve keys using empty `Batch()` objects:\n", "\n", "```mermaid\n", "graph TD\n", " root[\"Batch\"]\n", " root --> a[\"key1: np.array([1,2,3])\"]\n", " root --> b[\"key2: Batch() (reserved)\"]\n", " root --> c[\"key3\"]\n", " c --> c1[\"subkey1: Batch() (reserved)\"]\n", " c --> c2[\"subkey2: np.array([4,5])\"]\n", " \n", " style root fill:#e1f5ff\n", " style a fill:#e8f5e1\n", " style b fill:#ffcccc\n", " style c fill:#fff4e1\n", " style c1 fill:#ffcccc\n", " style c2 fill:#e8f5e1\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reserving keys\n", "batch = Batch(\n", " known_field=np.array([1, 2]),\n", " future_field=Batch(), # Reserved for later\n", ")\n", "print(\"Batch with reserved key:\")\n", "print(batch)\n", "\n", "# Later, assign actual data\n", "batch.future_field = np.array([3, 4])\n", "print(\"\\nAfter assignment:\")\n", "print(batch)\n", "\n", "# Nested reservation\n", "batch2 = Batch(\n", " obs=Batch(\n", " camera=Batch(), # Reserved\n", " lidar=np.zeros(10),\n", " )\n", ")\n", "print(\"\\nNested reservation:\")\n", "print(batch2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.2 Length and Shape Semantics\n", "\n", "Understanding when `len()` works and what `shape` means:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Normal case: all tensors same length\n", "batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5, 6]))\n", "print(\"Normal batch:\")\n", "print(f\"len(batch1) = {len(batch1)}\")\n", "print(f\"batch1.shape = {batch1.shape}\")\n", "\n", "# Scalars have no length\n", "batch2 = Batch(a=5, b=10)\n", "print(\"\\nScalar batch:\")\n", "print(f\"batch2.shape = {batch2.shape}\")\n", "try:\n", " print(f\"len(batch2) = {len(batch2)}\")\n", "except TypeError as e:\n", " print(f\"len(batch2) raises TypeError: {e}\")\n", "\n", "# Mixed lengths: returns minimum\n", "batch3 = Batch(a=[1, 2], b=[3, 4, 5])\n", "print(\"\\nMixed 
length batch:\")\n", "print(f\"len(batch3) = {len(batch3)} (minimum of 2 and 3)\")\n", "\n", "# Reserved keys are ignored\n", "batch4 = Batch(a=[1, 2, 3], reserved=Batch())\n", "print(\"\\nBatch with reserved key:\")\n", "print(f\"len(batch4) = {len(batch4)} (reserved key ignored)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.3 Empty Batches\n", "\n", "Understanding different meanings of \"empty\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 1. No keys at all\n", "empty1 = Batch()\n", "print(\"No keys:\")\n", "print(f\"len(empty1.get_keys()) = {len(list(empty1.get_keys()))}\")\n", "print(f\"len(empty1) = {len(empty1)}\")\n", "\n", "# 2. Has keys but they're all reserved\n", "empty2 = Batch(a=Batch(), b=Batch())\n", "print(\"\\nReserved keys only:\")\n", "print(f\"len(empty2.get_keys()) = {len(list(empty2.get_keys()))}\")\n", "print(f\"len(empty2) = {len(empty2)}\")\n", "\n", "# 3. Has data but length is 0\n", "empty3 = Batch(a=np.array([]), b=np.array([]))\n", "print(\"\\nZero-length arrays:\")\n", "print(f\"len(empty3.get_keys()) = {len(list(empty3.get_keys()))}\")\n", "print(f\"len(empty3) = {len(empty3)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Checking emptiness:**\n", "- `len(batch.get_keys()) == 0`: No keys (completely empty)\n", "- `len(batch) == 0`: No data elements (may have reserved keys)\n", "\n", "**The `.empty()` and `.empty_()` methods:**\n", "These reset values to zeros/None, different from checking emptiness:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch = Batch(a=[1, 2, 3], b=[\"x\", \"y\", \"z\"])\n", "print(\"Original:\", batch)\n", "\n", "# Empty specific index\n", "batch[0] = Batch.empty(batch[0])\n", "print(\"\\nAfter emptying index 0:\")\n", "print(batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.4 Heterogeneous Aggregation\n", "\n", 
"Stacking/concatenating batches with different keys:\n", "\n", "```mermaid\n", "graph LR\n", " A[\"Batch(a=[1,2], c=5)\"] --> C[\"Batch.stack\"]\n", " B[\"Batch(b=[3,4], c=6)\"] --> C\n", " C --> D[\"Batch(a=[[1,2],[0,0]],
b=[[0,0],[3,4]],
c=[5,6])\"]\n", " \n", " style A fill:#e1f5ff\n", " style B fill:#fff4e1\n", " style C fill:#ffe1f5\n", " style D fill:#e8f5e1\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stack with different keys (missing keys padded with zeros)\n", "batch_a = Batch(a=np.ones((2, 3)), shared=np.array([1, 2]))\n", "batch_b = Batch(b=np.zeros((2, 4)), shared=np.array([3, 4]))\n", "\n", "stacked = Batch.stack([batch_a, batch_b])\n", "print(\"Stacked batch:\")\n", "print(f\"a.shape = {stacked.a.shape} (padded with zeros for batch_b)\")\n", "print(f\"b.shape = {stacked.b.shape} (padded with zeros for batch_a)\")\n", "print(f\"shared.shape = {stacked.shared.shape} (in both batches)\")\n", "print(stacked)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.5 Missing Values\n", "\n", "Handling `None` and `NaN` values:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Batch with missing values\n", "batch = Batch(a=[1, 2, None, 4], b=[5.0, np.nan, 7.0, 8.0], c=[[1, 2], [3, 4], [5, 6], [7, 8]])\n", "\n", "# Check for nulls\n", "print(\"Has null?\", batch.hasnull())\n", "\n", "# Get null mask\n", "null_mask = batch.isnull()\n", "print(\"\\nNull mask:\")\n", "print(f\"a: {null_mask.a}\")\n", "print(f\"b: {null_mask.b}\")\n", "\n", "# Drop rows with any null\n", "clean_batch = batch.dropnull()\n", "print(\"\\nAfter dropnull() (keeps rows 0 and 3):\")\n", "print(f\"Length: {len(clean_batch)}\")\n", "print(f\"a: {clean_batch.a}\")\n", "print(f\"b: {clean_batch.b}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.6 Value Transformations\n", "\n", "Applying functions to all values recursively:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch = Batch(a=np.array([1, 2, 3]), nested=Batch(b=np.array([4.0, 5.0]), c=np.array([6, 7, 8])))\n", "\n", "# Apply transformation (returns new batch)\n", 
"doubled = batch.apply_values_transform(lambda x: x * 2)\n", "print(\"Original batch a:\", batch.a)\n", "print(\"Doubled batch a:\", doubled.a)\n", "print(\"Doubled nested.b:\", doubled.nested.b)\n", "\n", "# In-place transformation\n", "batch.apply_values_transform(lambda x: x + 10, inplace=True)\n", "print(\"\\nAfter in-place +10:\")\n", "print(\"a:\", batch.a)\n", "print(\"nested.b:\", batch.nested.b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Surprising Behaviors & Gotchas\n", "\n", "### Iteration Does NOT Iterate Over Keys!\n", "\n", "**This is the most common source of confusion:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch = Batch(a=[1, 2, 3], b=[4, 5, 6])\n", "\n", "print(\"WRONG: This doesn't iterate over keys!\")\n", "for item in batch:\n", " print(f\"item = {item}\") # Prints batch[0], batch[1], batch[2]\n", "\n", "print(\"\\nCORRECT: To iterate over keys:\")\n", "for key in batch.keys():\n", " print(f\"key = {key}\")\n", "\n", "print(\"\\nCORRECT: To iterate over key-value pairs:\")\n", "for key, value in batch.items():\n", " print(f\"{key} = {value}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Automatic Type Conversions\n", "\n", "Be aware of these automatic conversions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Lists become arrays\n", "batch = Batch(a=[1, 2, 3])\n", "print(\"List → array:\", type(batch.a), batch.a.dtype)\n", "\n", "# Dicts become Batch\n", "batch = Batch(a={\"x\": 1, \"y\": 2})\n", "print(\"Dict → Batch:\", type(batch.a))\n", "\n", "# Scalars become numpy scalars\n", "batch = Batch(a=5)\n", "print(\"Scalar → np.ndarray:\", type(batch.a), batch.a)\n", "\n", "# Mixed types → object dtype\n", "batch = Batch(a=[1, \"hello\", None])\n", "print(\"Mixed → object:\", batch.a.dtype, batch.a)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Length Edge Cases" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 1. Scalars have no length\n", "batch_scalar = Batch(a=5, b=10)\n", "try:\n", " len(batch_scalar)\n", "except TypeError as e:\n", " print(f\"Scalar batch: {e}\")\n", "\n", "# 2. Empty nested batches ignored in len()\n", "batch_empty_nested = Batch(a=[1, 2, 3], b=Batch())\n", "print(f\"\\nWith empty nested: len = {len(batch_empty_nested)} (ignores b)\")\n", "\n", "# 3. Different lengths: returns minimum\n", "batch_different = Batch(a=[1, 2], b=[1, 2, 3, 4])\n", "print(f\"Different lengths: len = {len(batch_different)} (minimum)\")\n", "\n", "# 4. None values don't affect length\n", "batch_none = Batch(a=[1, 2, 3], b=None)\n", "print(f\"With None: len = {len(batch_none)} (None ignored)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### String Keys Only" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Integer keys not allowed\n", "try:\n", " batch = Batch({1: \"value\", 2: \"other\"})\n", "except AssertionError as e:\n", " print(\"Integer keys not allowed:\", e)\n", "\n", "# String keys work\n", "batch = Batch({\"key1\": \"value\", \"key2\": \"other\"})\n", "print(\"\\nString keys work:\", list(batch.keys()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cat vs Stack Behavior\n", "\n", "Unlike `Batch.stack`, which pads missing keys, `Batch.cat` requires all batches to share the same structure (the same keys); concatenating batches with different keys raises a `ValueError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stack pads missing keys with zeros\n", "b1 = Batch(a=[1, 2])\n", "b2 = Batch(b=[3, 4])\n", "stacked = Batch.stack([b1, b2])\n", "print(\"Stack (different keys):\")\n", "print(f\" a: {stacked.a} (b2.a padded with 0)\")\n", "print(f\" b: {stacked.b} (b1.b padded with 0)\")\n", "\n", "# Cat requires same structure now\n", "b3 = Batch(a=[1, 2], b=[3, 4])\n", "b4 = Batch(a=[5, 6], b=[7, 8])\n", "concatenated = Batch.cat([b3, b4])\n", 
"print(\"\\nCat (same keys):\")\n", "print(f\" a: {concatenated.a}\")\n", "print(f\" b: {concatenated.b}\")\n", "\n", "# Cat with different structures raises error\n", "try:\n", " Batch.cat([b1, b2]) # Different keys!\n", "except ValueError:\n", " print(\"\\nCat with different keys: ValueError raised\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Best Practices\n", "\n", "### When to Use Batch\n", "\n", "**Good use cases:**\n", "- Collecting environment data (transitions, episodes)\n", "- Storing replay buffer data\n", "- Passing data between components (collector → buffer → policy)\n", "- Handling heterogeneous observations (dict spaces)\n", "\n", "**Consider alternatives:**\n", "- Simple scalar tracking (use regular variables)\n", "- Pure tensor operations (use PyTorch tensors directly)\n", "- Deeply nested arbitrary structures (use dataclasses)\n", "\n", "### Structuring Your Batches\n", "\n", "**Use protocols for type safety:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Good: Use protocols for clear interfaces\n", "def train_step(batch: RolloutBatchProtocol) -> float:\n", " \"\"\"IDE knows what fields exist.\"\"\"\n", " loss = ((batch.rew - 0.5) ** 2).mean() # Type-safe\n", " return float(loss)\n", "\n", "\n", "# Create properly typed batch\n", "train_batch = cast(\n", " RolloutBatchProtocol,\n", " Batch(\n", " obs=np.random.randn(10, 4),\n", " obs_next=np.random.randn(10, 4),\n", " act=np.random.randint(0, 2, 10),\n", " rew=np.random.randn(10),\n", " terminated=np.zeros(10, dtype=bool),\n", " truncated=np.zeros(10, dtype=bool),\n", " info=np.array([{}] * 10, dtype=object),\n", " ),\n", ")\n", "\n", "loss = train_step(train_batch)\n", "print(f\"Loss: {loss:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Consistent key naming:**\n", "- Follow Tianshou conventions: `obs`, `act`, `rew`, `terminated`, `truncated`\n", "- Use descriptive names: `camera_obs` not 
`co`\n", "- Avoid name collisions with Batch methods: don't use `keys`, `items`, `get`, etc.\n", "\n", "**When to nest vs flatten:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Good: Nest related data\n", "batch_nested = Batch(\n", " obs=Batch(\n", " camera=np.zeros((32, 64, 64, 3)), lidar=np.zeros((32, 100)), position=np.zeros((32, 3))\n", " ),\n", " act=np.zeros(32),\n", ")\n", "print(\"Nested structure for related obs:\")\n", "print(f\" Access: batch.obs.camera.shape = {batch_nested.obs.camera.shape}\")\n", "\n", "# Less good: Flat structure loses semantic grouping\n", "batch_flat = Batch(\n", " camera=np.zeros((32, 64, 64, 3)),\n", " lidar=np.zeros((32, 100)),\n", " position=np.zeros((32, 3)),\n", " act=np.zeros(32),\n", ")\n", "print(\"\\nFlat structure (works but less clear):\")\n", "print(f\" Access: batch.camera.shape = {batch_flat.camera.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Performance Tips\n", "\n", "**Use in-place operations:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "batch = Batch(a=np.random.randn(1000, 100))\n", "\n", "# Creates copy\n", "start = time.time()\n", "for _ in range(100):\n", " _ = batch.to_torch()\n", "time_copy = time.time() - start\n", "\n", "# In-place (faster)\n", "start = time.time()\n", "for _ in range(100):\n", " batch.to_torch_()\n", " batch.to_numpy_()\n", "time_inplace = time.time() - start\n", "\n", "print(f\"Copy: {time_copy:.4f}s\")\n", "print(f\"In-place: {time_inplace:.4f}s\")\n", "print(f\"Speedup: {time_copy / time_inplace:.1f}x\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Be mindful of copies:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arr = np.array([1, 2, 3])\n", "\n", "# Default: creates reference (be careful!)\n", "batch1 = Batch(a=arr)\n", "batch1.a[0] = 999\n", 
"print(f\"Original array modified: {arr}\") # Changed!\n", "\n", "# Explicit copy when needed\n", "arr = np.array([1, 2, 3])\n", "batch2 = Batch(a=arr, copy=True)\n", "batch2.a[0] = 999\n", "print(f\"Original array preserved: {arr}\") # Unchanged" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Avoid unnecessary conversions:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Inefficient: multiple conversions\n", "batch = Batch(a=np.random.randn(100, 10))\n", "batch.to_torch_()\n", "batch.to_numpy_() # Unnecessary if we just need NumPy\n", "\n", "# Efficient: convert once, use many times\n", "batch = Batch(a=np.random.randn(100, 10))\n", "batch.to_torch_() # Convert once\n", "# ... do torch operations ...\n", "# Keep as torch if that's what you need!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Common Patterns\n", "\n", "**Pattern 1: Building batches incrementally**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Collect data from multiple steps\n", "step_data = []\n", "for i in range(5):\n", " step_data.append({\"obs\": np.random.randn(4), \"act\": i, \"rew\": np.random.randn()})\n", "\n", "# Convert to batch (automatically stacks)\n", "episode_batch = Batch(step_data)\n", "print(\"Episode batch shape:\", episode_batch.shape)\n", "print(\"obs shape:\", episode_batch.obs.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Pattern 2: Slicing for mini-batches**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Large batch\n", "large_batch = Batch(obs=np.random.randn(100, 4), act=np.random.randint(0, 2, 100))\n", "\n", "# Split into mini-batches\n", "batch_size = 32\n", "for mini_batch in large_batch.split(batch_size, shuffle=True):\n", " print(f\"Mini-batch size: {len(mini_batch)}\")\n", " # Train on mini_batch...\n", " break # Just show one iteration" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "**Pattern 3: Extending batches**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Start with some data\n", "batch = Batch(obs=np.array([[1, 2], [3, 4]]), act=np.array([0, 1]))\n", "print(\"Initial:\", len(batch))\n", "\n", "# Add more data\n", "new_data = Batch(obs=np.array([[5, 6]]), act=np.array([1]))\n", "batch.cat_(new_data)\n", "print(\"After cat_:\", len(batch))\n", "print(\"obs:\", batch.obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. Summary\n", "\n", "### Key Takeaways\n", "\n", "1. **Batch = Dict + Array**: Combines key-value storage with array operations\n", "2. **Hierarchical Structure**: Perfect for complex RL data (nested observations, etc.)\n", "3. **Type Safety via Protocols**: Use `BatchProtocol` subclasses for IDE support and type checking\n", "4. **Special RL Features**: Distribution slicing, heterogeneous aggregation, missing value handling\n", "5. **Remember**: Iteration is over indices, NOT keys!\n", "\n", "### Quick Reference\n", "\n", "| Operation | Code | Notes |\n", "|-----------|------|-------|\n", "| Create | `Batch(a=1, b=[2, 3])` | Auto-converts types |\n", "| Access | `batch.a` or `batch[\"a\"]` | Equivalent |\n", "| Index | `batch[0]`, `batch[:10]` | Returns sliced Batch |\n", "| Iterate indices | `for item in batch:` | Yields batch[0], batch[1], ... 
|\n", "| Iterate keys | `for k in batch.keys():` | Like dict |\n", "| Stack | `Batch.stack([b1, b2])` | Adds dimension |\n", "| Concatenate | `Batch.cat([b1, b2])` | Extends dimension |\n", "| Split | `batch.split(size=10)` | Returns iterator |\n", "| To PyTorch | `batch.to_torch_()` | In-place |\n", "| To NumPy | `batch.to_numpy_()` | In-place |\n", "| Transform | `batch.apply_values_transform(fn)` | Recursive |\n", "\n", "### Next Steps\n", "\n", "- **Collector Deep Dive**: See how Batch flows through data collection\n", "- **Buffer Deep Dive**: Understand how Batch is stored and sampled\n", "- **Policy Guide**: Learn how policies work with BatchProtocol\n", "- **API Reference**: Full details at [Batch API documentation](https://tianshou.org/en/stable/api/tianshou.data.html#tianshou.data.Batch)\n", "\n", "### Questions?\n", "\n", "- Check the [Tianshou GitHub discussions](https://github.com/thu-ml/tianshou/discussions)\n", "- Review [issue tracker](https://github.com/thu-ml/tianshou/issues) for known gotchas\n", "- Read the [source code](https://github.com/thu-ml/tianshou/blob/master/tianshou/data/batch.py) - it's well-documented!" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Appendix: Serialization & Advanced Topics\n", "\n", "### Pickle Support" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Batch objects are picklable\n", "original = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n", "\n", "# Serialize and deserialize\n", "serialized = pickle.dumps(original)\n", "restored = pickle.loads(serialized)\n", "\n", "print(\"Original obs.a:\", original.obs.a)\n", "print(\"Restored obs.a:\", restored.obs.a)\n", "print(\"Equal:\", original == restored)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Advanced Indexing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Multi-dimensional data\n", "batch = Batch(a=np.random.randn(5, 3, 2))\n", "print(\"Original shape:\", batch.a.shape)\n", "\n", "# Various indexing operations\n", "print(\"batch[0].a.shape:\", batch[0].a.shape)\n", "print(\"batch[:, 0].a.shape:\", batch[:, 0].a.shape)\n", "print(\"batch[[0, 2, 4]].a.shape:\", batch[[0, 2, 4]].a.shape)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: docs/02_deep_dives/L2_Buffer.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Buffer: Experience Replay in Tianshou\n", "\n", "The replay buffer is a fundamental component in reinforcement learning, particularly for off-policy algorithms. 
Tianshou's buffer implementation extends beyond simple data storage to provide sophisticated trajectory tracking, efficient sampling, and seamless integration with the RL training pipeline.\n", "\n", "This tutorial provides comprehensive coverage of Tianshou's buffer system, from basic concepts to advanced features and integration patterns." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "import tempfile\n", "\n", "import numpy as np\n", "\n", "from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer, VectorReplayBuffer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Introduction: Why Buffers in Reinforcement Learning?\n", "\n", "### The Role of Experience Replay\n", "\n", "Experience replay is a critical technique in modern reinforcement learning that addresses three fundamental challenges:\n", "\n", "1. **Breaking Temporal Correlation**: Sequential experiences from an agent are highly correlated. Training directly on these sequences can lead to unstable learning. By storing experiences and sampling randomly, we break these correlations.\n", "\n", "2. **Sample Efficiency**: In RL, collecting data through environment interaction is often expensive. Experience replay allows us to reuse each experience multiple times for training, dramatically improving sample efficiency.\n", "\n", "3. **Mini-batch Training**: Modern deep learning requires mini-batch gradient descent. 
Buffers enable efficient batching of experiences for neural network training.\n", "\n", "### Why Not Alternatives?\n", "\n", "**Plain Python Lists**\n", "- No efficient random sampling\n", "- No automatic circular queue behavior\n", "- No trajectory boundary tracking\n", "- Poor memory management for large datasets\n", "\n", "**Simple Batch Storage**\n", "- No automatic overwriting when full\n", "- No episode metadata (returns, lengths)\n", "- No methods for boundary navigation (prev/next)\n", "- No specialized sampling strategies\n", "\n", "### Buffer = Batch + Trajectory Management + Sampling\n", "\n", "Tianshou's buffers build on the `Batch` class to provide:\n", "- **Circular queue storage**: Automatic overwriting of oldest data\n", "- **Trajectory tracking**: Episode boundaries, returns, and lengths\n", "- **Efficient sampling**: Random access with various strategies\n", "- **Integration utilities**: Seamless connection to Collector and Policy\n", "\n", "### Use Cases\n", "\n", "- **Off-policy algorithms**: DQN, SAC, TD3, DDPG require experience replay\n", "- **On-policy with replay**: Some PPO implementations reuse buffer data\n", "- **Offline RL**: Loading and using pre-collected datasets\n", "- **Multi-environment training**: VectorReplayBuffer for parallel collection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Buffer Types and Hierarchy\n", "\n", "Tianshou provides several buffer implementations, each designed for specific use cases. Understanding this hierarchy is crucial for choosing the right buffer.\n", "\n", "### Buffer Hierarchy\n", "\n", "```mermaid\n", "graph TD\n", " RB[ReplayBuffer
Single environment
Circular queue] --> RBM[ReplayBufferManager
Manages multiple buffers
Contiguous memory]\n", " RBM --> VRB[VectorReplayBuffer
Parallel environments
Maintains temporal order]\n", " \n", " RB --> PRB[PrioritizedReplayBuffer
TD-error based sampling
Importance weights]\n", " PRB --> PVRB[PrioritizedVectorReplayBuffer
Prioritized + Parallel]\n", " \n", " RBM --> CRB[CachedReplayBuffer
Main buffer + per-env caches
Keeps episodes contiguous]\n", " \n", " RB --> HERB[HERReplayBuffer
Hindsight Experience Replay
Goal-conditioned RL]\n", " HERB --> HVRB[HERVectorReplayBuffer
HER + Parallel]\n", " \n", " style RB fill:#e1f5ff\n", " style RBM fill:#fff4e1\n", " style VRB fill:#ffe1f5\n", " style PRB fill:#e8f5e1\n", " style CRB fill:#f5e1e1\n", " style HERB fill:#e1e1f5\n", "```\n", "\n", "### When to Use Which Buffer\n", "\n", "**ReplayBuffer**: Single environment scenarios\n", "- Simple setup and testing\n", "- Debugging algorithms\n", "- Low-parallelism training\n", "\n", "**VectorReplayBuffer**: Multiple parallel environments (most common)\n", "- Standard production use case\n", "- Efficient parallel data collection\n", "- Maintains per-environment episode boundaries\n", "\n", "**PrioritizedReplayBuffer**: DQN variants with prioritization\n", "- Rainbow DQN\n", "- Algorithms requiring importance sampling\n", "- When some transitions are more valuable than others\n", "\n", "**CachedReplayBuffer**: A main buffer plus one cache buffer per parallel environment\n", "- Each environment's in-progress episode is first collected in its own cache\n", "- When an episode ends, it is moved into the main buffer as one contiguous block\n", "- Useful when complete episodes must stay contiguous in the main buffer\n", "\n", "**HERReplayBuffer**: Goal-conditioned reinforcement learning\n", "- Sparse reward environments\n", "- Robotics tasks with explicit goals\n", "- Relabeling failed experiences with achieved goals" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.
Basic Operations\n", "\n", "### 3.1 Construction and Configuration\n", "\n", "The ReplayBuffer constructor accepts several important parameters that control its behavior:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a buffer with all configuration options\n", "buf = ReplayBuffer(\n", " size=20, # Maximum capacity (transitions)\n", " stack_num=1, # Frame stacking for RNNs (default: 1, no stacking)\n", " ignore_obs_next=False, # Save memory by not storing obs_next\n", " save_only_last_obs=False, # For temporal stacking (Atari-style)\n", " sample_avail=False, # Sample only valid indices for frame stacking\n", " random_seed=42, # Reproducible sampling\n", ")\n", "\n", "print(f\"Buffer created: {buf}\")\n", "print(f\"Max size: {buf.maxsize}\")\n", "print(f\"Current length: {len(buf)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Parameter Explanations**:\n", "\n", "- `size`: Maximum number of transitions the buffer can hold. When full, oldest data is overwritten.\n", "- `stack_num`: Number of consecutive frames to stack. Used for RNN inputs or frame-based policies (Atari).\n", "- `ignore_obs_next`: If True, obs_next is not stored, saving memory. The buffer reconstructs it from the next obs when needed.\n", "- `save_only_last_obs`: For temporal stacking. Only saves the last observation in a stack.\n", "- `sample_avail`: When True with stack_num > 1, only samples indices where a complete stack is available.\n", "- `random_seed`: Seeds the random number generator for reproducible sampling." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2 Reserved Keys and the Done Flag System\n", "\n", "ReplayBuffer uses nine reserved keys that integrate with Gymnasium conventions. Understanding the done flag system is critical." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The nine reserved keys\n", "print(\"Reserved keys:\")\n", "print(ReplayBuffer._reserved_keys)\n", "print(\"\\nKeys required for add():\")\n", "print(ReplayBuffer._required_keys_for_add)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Important: Understanding done, terminated, and truncated**\n", "\n", "Gymnasium (the successor to OpenAI Gym) introduced a crucial distinction:\n", "\n", "- `terminated`: Episode ended naturally (agent reached goal or failed)\n", " - Examples: CartPole fell over, agent reached goal state\n", " - Marks a true terminal state: value targets must NOT bootstrap from the next observation\n", "\n", "- `truncated`: Episode was cut off artificially (time limit, external interruption)\n", " - Examples: Maximum episode length reached, environment reset externally \n", " - Not a true terminal state: value targets should still bootstrap from the next observation (the episode could have continued)\n", "\n", "- `done`: Computed automatically as `terminated OR truncated`\n", " - Used internally for episode boundary tracking\n", " - You should NEVER manually set this field\n", "\n", "**Best Practice**: Always use the `info` dictionary for custom metadata rather than adding top-level keys:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# GOOD: Custom metadata in info dictionary\n", "good_batch = Batch(\n", " obs=np.array([1.0, 2.0]),\n", " act=0,\n", " rew=1.0,\n", " terminated=False,\n", " truncated=False,\n", " obs_next=np.array([1.5, 2.5]),\n", " info={\"custom_metric\": 0.95, \"step_count\": 10}, # Custom data here\n", ")\n", "\n", "# BAD: Don't add custom top-level keys (may conflict with future buffer features)\n", "# bad_batch = Batch(..., custom_metric=0.95) # Don't do this!\n", "\n", "print(\"Good batch structure:\")\n", "print(good_batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.3 Circular Queue Storage\n", "\n", "The buffer implements a circular
queue: when it reaches maximum capacity, new data overwrites the oldest entries." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a small buffer to demonstrate circular behavior\n", "demo_buf = ReplayBuffer(size=5)\n", "\n", "print(\"Adding 3 transitions:\")\n", "for i in range(3):\n", " demo_buf.add(\n", " Batch(\n", " obs=i,\n", " act=i,\n", " rew=float(i),\n", " terminated=False,\n", " truncated=False,\n", " obs_next=i + 1,\n", " info={},\n", " )\n", " )\n", "print(f\"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}\")\n", "print(f\"Observations: {demo_buf.obs[: len(demo_buf)]}\")\n", "\n", "print(\"\\nAdding 5 more transitions (total 8, exceeds capacity 5):\")\n", "for i in range(3, 8):\n", " demo_buf.add(\n", " Batch(\n", " obs=i,\n", " act=i,\n", " rew=float(i),\n", " terminated=False,\n", " truncated=False,\n", " obs_next=i + 1,\n", " info={},\n", " )\n", " )\n", "print(f\"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}\")\n", "print(f\"Observations: {demo_buf.obs[: len(demo_buf)]}\")\n", "print(\"\\nNotice: First 3 transitions (0,1,2) were overwritten by (5,6,7)\")\n", "print(\"Buffer now holds obs 3-7, in storage order: [5, 6, 7, 3, 4]\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.4 Batch-Compatible Operations\n", "\n", "ReplayBuffer is not a `Batch` subclass, but it supports Batch-style indexing and slicing, and each lookup returns a `Batch`. Note that indices refer to storage slots, not insertion order: after wrap-around, `demo_buf[-1]` is the last slot (obs 4 here), not the most recently added transition (obs 7):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Indexing and slicing\n", "print(\"Last storage slot (index -1):\")\n", "print(demo_buf[-1])\n", "\n", "print(\"\\nLast 3 storage slots (indices 2-4):\")\n", "print(demo_buf[-3:])\n", "\n", "print(\"\\nSpecific indices [0, 2, 4]:\")\n", "print(demo_buf[np.array([0, 2, 4])])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.
Trajectory Management\n", "\n", "A key distinguishing feature of ReplayBuffer is its automatic tracking of episode boundaries and metadata.\n", "\n", "### 4.1 Episode Tracking and Metadata\n", "\n", "The `add()` method returns four values that provide episode information:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a fresh buffer for trajectory demonstration\n", "traj_buf = ReplayBuffer(size=20)\n", "\n", "print(\"Episode 1: 4 steps, terminates naturally\")\n", "for i in range(4):\n", " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", " Batch(\n", " obs=i,\n", " act=i,\n", " rew=float(i + 1), # Rewards: 1, 2, 3, 4\n", " terminated=i == 3, # Last step terminates\n", " truncated=False,\n", " obs_next=i + 1,\n", " info={},\n", " )\n", " )\n", " print(f\" Step {i}: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len}, ep_start={ep_start}\")\n", "\n", "print(\"\\nNotice: Episode return (10.0) and length (4) only appear at the end!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Return Values Explained**:\n", "\n", "1. `idx`: Index where the transition was inserted (np.ndarray of shape (1,))\n", "2. `ep_rew`: Episode return, only non-zero when `done=True` (np.ndarray of shape (1,))\n", "3. `ep_len`: Episode length, only non-zero when `done=True` (np.ndarray of shape (1,))\n", "4. `ep_start`: Index where the episode started (np.ndarray of shape (1,))\n", "\n", "This automatic computation eliminates manual episode tracking during data collection." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Continue with Episode 2: 5 steps\n", "print(\"Episode 2: 5 steps, truncated (time limit)\")\n", "for i in range(4, 9):\n", " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", " Batch(\n", " obs=i,\n", " act=i,\n", " rew=float(i + 1),\n", " terminated=False,\n", " truncated=i == 8, # Last step truncated\n", " obs_next=i + 1,\n", " info={},\n", " )\n", " )\n", " if i == 8:\n", " print(\n", " f\" Final step: idx={idx}, ep_rew={ep_rew[0]:.1f}, ep_len={ep_len[0]}, ep_start={ep_start}\"\n", " )\n", "\n", "# Episode 3: Ongoing (not finished)\n", "print(\"\\nEpisode 3: 3 steps, ongoing (not done)\")\n", "for i in range(9, 12):\n", " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", " Batch(\n", " obs=i,\n", " act=i,\n", " rew=float(i + 1),\n", " terminated=False,\n", " truncated=False, # Episode continues\n", " obs_next=i + 1,\n", " info={},\n", " )\n", " )\n", " if i == 11:\n", " print(\n", " f\" Latest step: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len} (zeros because not done)\"\n", " )\n", "\n", "print(f\"\\nBuffer state: {len(traj_buf)} transitions across 2 complete + 1 ongoing episode\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.2 Boundary Navigation: prev() and next()\n", "\n", "The buffer provides methods to navigate within episodes while respecting episode boundaries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Examine the buffer structure\n", "print(\"Buffer contents:\")\n", "print(f\"Indices: {np.arange(len(traj_buf))}\")\n", "print(f\"Obs: {traj_buf.obs[: len(traj_buf)]}\")\n", "print(f\"Terminated: {traj_buf.terminated[: len(traj_buf)]}\")\n", "print(f\"Truncated: {traj_buf.truncated[: len(traj_buf)]}\")\n", "print(f\"Done: {traj_buf.done[: len(traj_buf)]}\")\n", "print(\"\\nEpisode boundaries: indices 3 (terminated) and 8 (truncated)\")" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "# prev() returns the previous index within the same episode\n", "# It STOPS at episode boundaries\n", "test_indices = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])\n", "prev_indices = traj_buf.prev(test_indices)\n", "\n", "print(\"prev() behavior:\")\n", "print(f\"Index: {test_indices}\")\n", "print(f\"Prev: {prev_indices}\")\n", "print(\"\\nObservations:\")\n", "print(\"- Index 0 stays at 0 (start of episode 1)\")\n", "print(\"- Index 4 stays at 4 (start of episode 2, can't go back to episode 1)\")\n", "print(\"- Index 9 stays at 9 (start of episode 3, can't go back to episode 2)\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# next() returns the next index within the same episode\n", "# It STOPS at episode boundaries\n", "next_indices = traj_buf.next(test_indices)\n", "\n", "print(\"next() behavior:\")\n", "print(f\"Index: {test_indices}\")\n", "print(f\"Next: {next_indices}\")\n", "print(\"\\nObservations:\")\n", "print(\"- Index 3 stays at 3 (end of episode 1, terminated)\")\n", "print(\"- Index 8 stays at 8 (end of episode 2, truncated)\")\n", "print(\"- Indices 9-10 advance normally; index 11 stays at 11 (most recent transition of the ongoing episode 3)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Use Cases for prev() and next()**:\n", "\n", "These methods are essential for computing algorithmic quantities:\n", "- **N-step returns**: Use next() to look ahead up to N steps within an episode\n", "- **Frame stacking**: With `stack_num > 1`, prev() gathers the earlier observations of the same episode\n", "- **GAE (Generalized Advantage Estimation)**: Navigate backwards through episodes\n", "- **Episode extraction**: Find episode start/end indices\n", "- **Temporal difference targets**: Ensure you don't bootstrap across episode boundaries" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.3 Identifying Unfinished Episodes\n", "\n", "The `unfinished_index()` method returns indices of ongoing episodes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], 
"source": [ "unfinished = traj_buf.unfinished_index()\n", "print(f\"Unfinished episode indices: {unfinished}\")\n", "print(f\"Latest step of ongoing episode: obs={traj_buf.obs[unfinished[0]]}\")\n", "\n", "# After finishing episode 3\n", "traj_buf.add(\n", " Batch(\n", " obs=12,\n", " act=12,\n", " rew=13.0,\n", " terminated=True,\n", " truncated=False,\n", " obs_next=13,\n", " info={},\n", " )\n", ")\n", "\n", "unfinished_after = traj_buf.unfinished_index()\n", "print(\"\\nAfter finishing episode 3:\")\n", "print(f\"Unfinished episodes: {unfinished_after} (empty array)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Sampling Strategies\n", "\n", "Efficient sampling is critical for RL training. The buffer provides several sampling methods and strategies.\n", "\n", "### 5.1 Basic Sampling" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a buffer with some data\n", "sample_buf = ReplayBuffer(size=100)\n", "for i in range(50):\n", " sample_buf.add(\n", " Batch(\n", " obs=i,\n", " act=i % 4,\n", " rew=np.random.random(),\n", " terminated=(i + 1) % 10 == 0,\n", " truncated=False,\n", " obs_next=i + 1,\n", " info={},\n", " )\n", " )\n", "\n", "# Sample with batch_size\n", "batch, indices = sample_buf.sample(batch_size=8)\n", "print(f\"Sampled batch size: {len(batch)}\")\n", "print(f\"Sampled indices: {indices}\")\n", "print(f\"Sampled observations: {batch.obs}\")\n", "\n", "# batch_size=None: return all data in random order\n", "all_data, all_indices = sample_buf.sample(batch_size=None)\n", "print(f\"\\nSample all (batch_size=None): {len(all_data)} transitions\")\n", "\n", "# batch_size=0: return all data in buffer order\n", "ordered_data, ordered_indices = sample_buf.sample(batch_size=0)\n", "print(f\"Get all in order (batch_size=0): {len(ordered_data)} transitions\")\n", "print(f\"Indices in order: {ordered_indices[:10]}...\") # Show first 10" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "**Sampling Behavior Summary**:\n", "\n", "- `batch_size > 0`: Random sample of specified size\n", "- `batch_size = None`: `len(buffer)` indices drawn uniformly at random with replacement (so some transitions may repeat and others be missing)\n", "- `batch_size = 0`: All data in insertion order\n", "- `batch_size < 0`: Empty array (edge case handling)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.2 Frame Stacking\n", "\n", "The `stack_num` parameter enables automatic frame stacking, useful for RNN inputs or Atari-style environments where temporal context matters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create buffer with frame stacking\n", "stack_buf = ReplayBuffer(size=20, stack_num=4)\n", "\n", "# Add observations: 0, 1, 2, ..., 9\n", "for i in range(10):\n", " stack_buf.add(\n", " Batch(\n", " obs=np.array([i]), # Single frame\n", " act=0,\n", " rew=1.0,\n", " terminated=i == 9,\n", " truncated=False,\n", " obs_next=np.array([i + 1]),\n", " info={},\n", " )\n", " )\n", "\n", "# Get stacked frames for index 6\n", "# Should return [3, 4, 5, 6] (4 consecutive frames ending at 6)\n", "stacked = stack_buf.get(index=6, key=\"obs\")\n", "print(\"Frame stacking demo:\")\n", "print(\"Requested index: 6\")\n", "print(f\"Stacked frames shape: {stacked.shape}\")\n", "print(f\"Stacked frames: {stacked.flatten()}\")\n", "print(\"\\nExplanation: stack_num=4, so index 6 returns [obs[3], obs[4], obs[5], obs[6]]\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Demonstrate episode boundary handling with frame stacking\n", "boundary_buf = ReplayBuffer(size=20, stack_num=4)\n", "\n", "# Episode 1: indices 0-4\n", "for i in range(5):\n", " boundary_buf.add(\n", " Batch(\n", " obs=np.array([i]),\n", " act=0,\n", " rew=1.0,\n", " terminated=i == 4,\n", " truncated=False,\n", " obs_next=np.array([i + 1]),\n", " info={},\n", " )\n", " )\n", "\n", "# Episode 2: indices 5-9\n", "for i in range(5, 10):\n", " boundary_buf.add(\n", "
Batch(\n", " obs=np.array([i]),\n", " act=0,\n", " rew=1.0,\n", " terminated=i == 9,\n", " truncated=False,\n", " obs_next=np.array([i + 1]),\n", " info={},\n", " )\n", " )\n", "\n", "# Try to get stacked frames at episode boundary\n", "boundary_stack = boundary_buf.get(index=6, key=\"obs\") # Early in episode 2\n", "print(\"\\nFrame stacking at episode boundary:\")\n", "print(f\"Index 6 stacked frames: {boundary_stack.flatten()}\")\n", "print(\"Notice: Frames don't cross episode boundary (5,5,5,6 not 3,4,5,6)\")\n", "print(\"The buffer uses prev() internally, which respects episode boundaries\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Frame Stacking Use Cases**:\n", "\n", "- **RNN/LSTM inputs**: Provide temporal context to recurrent networks\n", "- **Atari games**: Stack 4 frames to capture motion (as in DQN paper)\n", "- **Velocity estimation**: Multiple frames allow computing derivatives\n", "- **Partially observable environments**: Build up state estimates\n", "\n", "**Important Notes**:\n", "- Frame stacking respects episode boundaries (won't stack across episodes)\n", "- Set `sample_avail=True` to only sample indices where full stacks are available\n", "- `save_only_last_obs=True` saves memory in Atari-style setups" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. VectorReplayBuffer: Parallel Environment Support\n", "\n", "VectorReplayBuffer is essential for modern RL training with parallel environments. 
It maintains separate subbuffers for each environment while providing a unified interface.\n", "\n", "### 6.1 Motivation and Architecture\n", "\n", "When training with multiple parallel environments (e.g., 8 environments running simultaneously), we need:\n", "- **Per-environment episode tracking**: Each environment has its own episode boundaries\n", "- **Temporal ordering**: Preserve the sequence of events within each environment\n", "- **Unified sampling**: Sample uniformly across all environments for training\n", "\n", "```mermaid\n", "graph LR\n", " E1[Env 1] --> B1[Subbuffer 1<br/>2500 capacity]\n", " E2[Env 2] --> B2[Subbuffer 2<br/>2500 capacity]\n", " E3[Env 3] --> B3[Subbuffer 3<br/>2500 capacity]\n", " E4[Env 4] --> B4[Subbuffer 4<br/>2500 capacity]\n", " \n", " B1 --> VRB[VectorReplayBuffer<br/>Total: 10000<br/>Unified Sampling]\n", " B2 --> VRB\n", " B3 --> VRB\n", " B4 --> VRB\n", " \n", " VRB --> Policy[Policy Training]\n", " \n", " style E1 fill:#e1f5ff\n", " style E2 fill:#e1f5ff\n", " style E3 fill:#e1f5ff\n", " style E4 fill:#e1f5ff\n", " style B1 fill:#fff4e1\n", " style B2 fill:#fff4e1\n", " style B3 fill:#fff4e1\n", " style B4 fill:#fff4e1\n", " style VRB fill:#ffe1f5\n", " style Policy fill:#e8f5e1\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create VectorReplayBuffer for 4 parallel environments\n", "vec_buf = VectorReplayBuffer(\n", " total_size=100, # Total capacity across all subbuffers\n", " buffer_num=4, # Number of parallel environments\n", ")\n", "\n", "print(\"VectorReplayBuffer created:\")\n", "print(f\"Total size: {vec_buf.maxsize}\")\n", "print(f\"Number of subbuffers: {vec_buf.buffer_num}\")\n", "print(f\"Size per subbuffer: {vec_buf.maxsize // vec_buf.buffer_num}\")\n", "print(f\"Subbuffer edges: {vec_buf.subbuffer_edges}\")\n", "print(\"\\nSubbuffer edges define the boundary indices: [0, 25, 50, 75, 100]\")\n", "print(\"Subbuffer 0: indices 0-24, Subbuffer 1: indices 25-49, etc.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.2 The buffer_ids Parameter\n", "\n", "This is one of the most confusing aspects for new users. The `buffer_ids` parameter specifies which subbuffer each transition belongs to."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Simulate data from 4 parallel environments\n", "# Each environment produces one transition\n", "parallel_batch = Batch(\n", " obs=np.array([[0.1, 0.2], [1.1, 1.2], [2.1, 2.2], [3.1, 3.2]]), # 4 observations\n", " act=np.array([0, 1, 0, 1]), # 4 actions\n", " rew=np.array([1.0, 2.0, 3.0, 4.0]), # 4 rewards\n", " terminated=np.array([False, False, False, False]),\n", " truncated=np.array([False, False, False, False]),\n", " obs_next=np.array([[0.2, 0.3], [1.2, 1.3], [2.2, 2.3], [3.2, 3.3]]),\n", " info=np.array([{}, {}, {}, {}], dtype=object),\n", ")\n", "\n", "print(\"Parallel batch shape:\", parallel_batch.obs.shape)\n", "print(\"This represents 4 transitions, one from each environment\")\n", "\n", "# Add with buffer_ids specifying which subbuffer each transition goes to\n", "indices, ep_rews, ep_lens, ep_starts = vec_buf.add(\n", " parallel_batch,\n", " buffer_ids=[0, 1, 2, 3], # Transition 0→Subbuf 0, 1→Subbuf 1, etc.\n", ")\n", "\n", "print(f\"\\nAdded to indices: {indices}\")\n", "print(\"Notice: Indices are in different subbuffers:\")\n", "print(f\" Index {indices[0]} in subbuffer 0 (range 0-24)\")\n", "print(f\" Index {indices[1]} in subbuffer 1 (range 25-49)\")\n", "print(f\" Index {indices[2]} in subbuffer 2 (range 50-74)\")\n", "print(f\" Index {indices[3]} in subbuffer 3 (range 75-99)\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Add more data to demonstrate buffer_ids\n", "# Environments don't always produce data in order 0,1,2,3\n", "# For example, if only environments 1 and 3 are ready:\n", "partial_batch = Batch(\n", " obs=np.array([[1.2, 1.3], [3.2, 3.3]]), # Only 2 observations\n", " act=np.array([0, 1]),\n", " rew=np.array([2.5, 4.5]),\n", " terminated=np.array([False, False]),\n", " truncated=np.array([False, False]),\n", " obs_next=np.array([[1.3, 1.4], [3.3, 3.4]]),\n", " info=np.array([{}, 
{}], dtype=object),\n", ")\n", "\n", "# Only environments 1 and 3 produced data\n", "indices2, _, _, _ = vec_buf.add(\n", " partial_batch,\n", " buffer_ids=[1, 3], # Only these two subbuffers receive data\n", ")\n", "\n", "print(\"Added partial batch (only envs 1 and 3):\")\n", "print(f\"Indices: {indices2}\")\n", "print(f\"Subbuffer 1 received data at index {indices2[0]}\")\n", "print(f\"Subbuffer 3 received data at index {indices2[1]}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Important: buffer_ids Requirements**:\n", "\n", "For `VectorReplayBuffer`:\n", "- `buffer_ids` length must match batch size\n", "- Values must be in range [0, buffer_num)\n", "- Can be partial (not all environments at once)\n", "\n", "For regular `ReplayBuffer`:\n", "- If `buffer_ids` is not None, it must be [0]\n", "- Batch must have shape (1, data_length)\n", "- This is for API compatibility with VectorReplayBuffer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.3 Subbuffer Edges and Episode Handling\n", "\n", "Subbuffer edges prevent episodes from spanning across subbuffers, ensuring data from different environments doesn't get mixed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The subbuffer_edges property defines boundaries\n", "print(f\"Subbuffer edges: {vec_buf.subbuffer_edges}\")\n", "print(\"\\nThis creates 4 subbuffers:\")\n", "for i in range(vec_buf.buffer_num):\n", " start = vec_buf.subbuffer_edges[i]\n", " end = vec_buf.subbuffer_edges[i + 1]\n", " print(f\"Subbuffer {i}: indices [{start}, {end})\")\n", "\n", "# Episodes cannot cross these boundaries\n", "# prev() and next() respect subbuffer edges just like episode boundaries\n", "test_idx = np.array([24, 25, 49, 50]) # At subbuffer edges\n", "prev_result = vec_buf.prev(test_idx)\n", "next_result = vec_buf.next(test_idx)\n", "\n", "print(\"\\nBoundary navigation test:\")\n", "print(f\"Indices: {test_idx}\")\n", "print(f\"prev(): 
{prev_result}\")\n", "print(f\"next(): {next_result}\")\n", "print(\"\\nNotice: prev/next don't cross subbuffer boundaries\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.4 Sampling from VectorReplayBuffer\n", "\n", "Sampling is uniform over all stored transitions, so each subbuffer contributes samples in proportion to its current fill level:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Add more data to have enough for sampling\n", "for _step in range(10):\n", " batch = Batch(\n", " obs=np.random.randn(4, 2),\n", " act=np.random.randint(0, 2, size=4),\n", " rew=np.random.random(4),\n", " terminated=np.zeros(4, dtype=bool),\n", " truncated=np.zeros(4, dtype=bool),\n", " obs_next=np.random.randn(4, 2),\n", " info=np.array([{}] * 4, dtype=object),\n", " )\n", " vec_buf.add(batch, buffer_ids=[0, 1, 2, 3])\n", "\n", "# Sample batch\n", "sampled, indices = vec_buf.sample(batch_size=16)\n", "print(f\"Sampled {len(sampled)} transitions\")\n", "print(f\"Sample indices (from different subbuffers): {indices}\")\n", "print(\"\\nNotice indices span across all subbuffer ranges\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. 
Specialized Buffer Variants\n", "\n", "### 7.1 PrioritizedReplayBuffer\n", "\n", "Implements prioritized experience replay where transitions are sampled based on their TD-error magnitudes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "# Create prioritized buffer\nprio_buf = PrioritizedReplayBuffer(\n size=100,\n alpha=0.6, # Prioritization exponent (0=uniform, 1=fully prioritized)\n beta=0.4, # Importance sampling correction (annealed to 1)\n)\n\n# Add some transitions\nfor i in range(20):\n prio_buf.add(\n Batch(\n obs=np.array([i]),\n act=i % 4,\n rew=np.random.random(),\n terminated=False,\n truncated=False,\n obs_next=np.array([i + 1]),\n info={},\n )\n )\n\n# Sample returns batch and indices\n# Importance weights are INSIDE the batch as batch.weight\nbatch, indices = prio_buf.sample(batch_size=8)\nprint(f\"Sampled batch size: {len(batch)}\")\nprint(f\"Indices: {indices}\")\nprint(f\"Importance weights (batch.weight): {batch.weight}\")\nprint(\"\\nWeights are stored in batch.weight and compensate for biased sampling\")" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# After computing TD-errors from the sampled batch, update priorities\n", "# In practice, these would be actual TD-errors: |Q(s,a) - (r + γ*max Q(s',a'))|\n", "fake_td_errors = np.random.random(len(indices)) * 10 # Simulated TD-errors\n", "\n", "# Update priorities (higher TD-error = higher priority)\n", "prio_buf.update_weight(indices, fake_td_errors)\n", "\n", "print(\"Updated priorities based on TD-errors\")\n", "print(\"Transitions with higher TD-errors will be sampled more frequently\")\n", "\n", "# Demonstrate beta annealing\n", "prio_buf.set_beta(0.6) # Increase beta over training\n", "print(f\"\\nAnnealed beta to: {prio_buf.options['beta']}\")\n", "print(\"Beta typically starts at 0.4 and anneals to 1.0 over training\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"**PrioritizedReplayBuffer Use Cases**:\n", "- Rainbow DQN and variants\n", "- Any algorithm where some transitions are more \"surprising\" and valuable\n", "- Environments with rare but important events\n", "\n", "**Key Parameters**:\n", "- `alpha`: Controls how much prioritization affects sampling (0=uniform, 1=fully proportional to priority)\n", "- `beta`: Importance sampling correction to remain unbiased (anneal from ~0.4 to 1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7.2 Other Specialized Buffers\n", "\n", "**CachedReplayBuffer**: Maintains a primary buffer plus auxiliary caches\n", "- Use case: Imitation learning where you want separate expert and agent buffers\n", "- Example: GAIL (Generative Adversarial Imitation Learning)\n", "- Allows different sampling ratios from different sources\n", "\n", "**HERReplayBuffer**: Hindsight Experience Replay for goal-conditioned tasks\n", "- Use case: Sparse reward robotics tasks\n", "- Relabels failed episodes with achieved goals as if they were intended\n", "- Dramatically improves learning in goal-reaching tasks\n", "- See the HER documentation for detailed examples\n", "\n", "For detailed usage of these specialized buffers, refer to the Tianshou API documentation and algorithm-specific tutorials." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. 
Serialization and Persistence\n", "\n", "Buffers support multiple serialization formats for saving and loading data.\n", "\n", "### 8.1 Pickle Serialization\n", "\n", "The simplest method, preserving all buffer state including trajectory metadata:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create and populate a buffer\n", "save_buf = ReplayBuffer(size=50)\n", "for i in range(30):\n", " save_buf.add(\n", " Batch(\n", " obs=np.array([i, i + 1]),\n", " act=i % 4,\n", " rew=float(i),\n", " terminated=(i + 1) % 10 == 0,\n", " truncated=False,\n", " obs_next=np.array([i + 1, i + 2]),\n", " info={\"step\": i},\n", " )\n", " )\n", "\n", "print(f\"Original buffer: {len(save_buf)} transitions\")\n", "\n", "# Serialize with pickle\n", "pickled_data = pickle.dumps(save_buf)\n", "print(f\"Serialized size: {len(pickled_data)} bytes\")\n", "\n", "# Deserialize\n", "loaded_buf = pickle.loads(pickled_data)\n", "print(f\"Loaded buffer: {len(loaded_buf)} transitions\")\n", "print(f\"Data preserved: obs[0] = {loaded_buf.obs[0]}\")\n", "print(f\"Metadata preserved: info[0] = {loaded_buf.info[0]}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 8.2 HDF5 Serialization\n", "\n", "HDF5 is recommended for large datasets and cross-platform compatibility:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save to HDF5\n", "with tempfile.NamedTemporaryFile(suffix=\".hdf5\", delete=False) as tmp:\n", " hdf5_path = tmp.name\n", "\n", "save_buf.save_hdf5(hdf5_path, compression=\"gzip\")\n", "print(f\"Saved to HDF5: {hdf5_path}\")\n", "\n", "# Load from HDF5\n", "loaded_hdf5_buf = ReplayBuffer.load_hdf5(hdf5_path)\n", "print(f\"Loaded from HDF5: {len(loaded_hdf5_buf)} transitions\")\n", "print(f\"Data matches: {np.array_equal(save_buf.obs, loaded_hdf5_buf.obs)}\")\n", "\n", "# Clean up\n", "import os\n", "\n", "os.unlink(hdf5_path)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "**When to Use HDF5**:\n", "- Large datasets (> 1GB)\n", "- Offline RL with pre-collected data\n", "- Sharing data across platforms\n", "- Need for compression\n", "- Integration with external tools (many scientific tools read HDF5)\n", "\n", "**When to Use Pickle**:\n", "- Quick saves during development\n", "- Small buffers\n", "- Python-only workflow\n", "- Simpler serialization needs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 8.3 Loading from Raw Data with from_data()\n", "\n", "For offline RL, you can create a buffer from raw arrays:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Simulate pre-collected offline dataset\n", "import h5py\n", "\n", "# Create temporary HDF5 file with raw data\n", "with tempfile.NamedTemporaryFile(suffix=\".hdf5\", delete=False) as tmp:\n", " offline_path = tmp.name\n", "\n", "with h5py.File(offline_path, \"w\") as f:\n", " # Create datasets\n", " n = 100\n", " f.create_dataset(\"obs\", data=np.random.randn(n, 4))\n", " f.create_dataset(\"act\", data=np.random.randint(0, 2, n))\n", " f.create_dataset(\"rew\", data=np.random.randn(n))\n", " f.create_dataset(\"terminated\", data=np.random.random(n) < 0.1)\n", " f.create_dataset(\"truncated\", data=np.zeros(n, dtype=bool))\n", " f.create_dataset(\"done\", data=np.random.random(n) < 0.1)\n", " f.create_dataset(\"obs_next\", data=np.random.randn(n, 4))\n", "\n", "# Load into buffer\n", "with h5py.File(offline_path, \"r\") as f:\n", " offline_buf = ReplayBuffer.from_data(\n", " obs=f[\"obs\"],\n", " act=f[\"act\"],\n", " rew=f[\"rew\"],\n", " terminated=f[\"terminated\"],\n", " truncated=f[\"truncated\"],\n", " done=f[\"done\"],\n", " obs_next=f[\"obs_next\"],\n", " )\n", "\n", "print(f\"Loaded offline dataset: {len(offline_buf)} transitions\")\n", "print(f\"Observation shape: {offline_buf.obs.shape}\")\n", "\n", "# Clean up\n", "os.unlink(offline_path)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "This is the standard approach for offline RL where you have pre-collected datasets from other sources." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. Integration with the RL Pipeline\n", "\n", "Understanding how buffers integrate with other Tianshou components is essential for effective usage.\n", "\n", "### 9.1 Data Flow in RL Training\n", "\n", "```mermaid\n", "graph LR\n", " ENV[Vectorized
Environments] -->|observations| COL[Collector]\n", " POL[Policy] -->|actions| COL\n", " COL -->|transitions| BUF[Buffer]\n", " BUF -->|sampled batches| POL\n", " POL -->|forward pass| ALG[Algorithm]\n", " ALG -->|loss & gradients| POL\n", " \n", " style ENV fill:#e1f5ff\n", " style COL fill:#fff4e1\n", " style BUF fill:#ffe1f5\n", " style POL fill:#e8f5e1\n", " style ALG fill:#f5e1e1\n", "```\n", "\n", "### 9.2 Typical Training Loop Pattern\n", "\n", "Here's how buffers are typically used in a training loop:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Pseudocode for typical RL training loop\n", "# (This is illustrative; actual implementation would use Trainer)\n", "\n", "\n", "def training_loop_pseudocode():\n", " \"\"\"\n", " Illustrative training loop showing buffer integration.\n", "\n", " In practice, use Tianshou's Trainer class which handles this.\n", " \"\"\"\n", " # Setup (illustration only)\n", " # env = make_vectorized_env(num_envs=8)\n", " # policy = make_policy()\n", " # buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\n", " # collector = Collector(policy, env, buffer)\n", "\n", " # Training loop\n", " # for epoch in range(num_epochs):\n", " # # 1. Collect data from environments\n", " # collect_result = collector.collect(n_step=1000)\n", " # # Collector automatically adds transitions to buffer with correct buffer_ids\n", " #\n", " # # 2. Train on multiple batches\n", " # for _ in range(update_per_collect):\n", " # # Sample batch from buffer\n", " # batch, indices = buffer.sample(batch_size=256)\n", " #\n", " # # Compute loss and update policy\n", " # loss = policy.learn(batch)\n", " #\n", " # # For prioritized buffers, update priorities\n", " # # if isinstance(buffer, PrioritizedReplayBuffer):\n", " # # buffer.update_weight(indices, td_errors)\n", "\n", " print(\"This pseudocode illustrates the buffer's role:\")\n", " print(\"1. 
Collector fills buffer from environment interaction\")\n", " print(\"2. Buffer provides random samples for training\")\n", " print(\"3. Policy learns from sampled batches\")\n", " print(\"\\nIn practice, use Tianshou's Trainer for this workflow\")\n", "\n", "\n", "training_loop_pseudocode()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 9.3 Collector Integration\n", "\n", "The Collector class handles the complexity of:\n", "- Calling policy to get actions\n", "- Stepping environments\n", "- Adding transitions to buffer with correct buffer_ids\n", "- Tracking episode statistics\n", "\n", "When you create a Collector, you pass it a buffer (or let it create a default `VectorReplayBuffer`), and it automatically:\n", "- Checks that the buffer fits the environments: vectorized environments need a `VectorReplayBuffer` (or another vector variant) with at least one subbuffer per environment, while a plain `ReplayBuffer` only works with a single environment\n", "- Sets buffer_ids based on which environments are ready\n", "- Handles episode resets and boundary tracking\n", "\n", "See the Collector tutorial for detailed examples of this integration." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10. Advanced Topics and Edge Cases\n", "\n", "### 10.1 Buffer Overflow and Episode Boundaries\n", "\n", "What happens when the buffer fills up mid-episode?"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Small buffer to demonstrate overflow\n", "overflow_buf = ReplayBuffer(size=8)\n", "\n", "# Add a long episode (12 steps, buffer size is only 8)\n", "print(\"Adding 12-step episode to buffer with size 8:\")\n", "for i in range(12):\n", " idx, ep_rew, ep_len, ep_start = overflow_buf.add(\n", " Batch(\n", " obs=i,\n", " act=0,\n", " rew=1.0,\n", " terminated=i == 11,\n", " truncated=False,\n", " obs_next=i + 1,\n", " info={},\n", " )\n", " )\n", " if i in [7, 11]:\n", " print(f\" Step {i}: idx={idx}, buffer_len={len(overflow_buf)}\")\n", "\n", "print(\"\\nFinal buffer contents (most recent 8 steps):\")\n", "print(f\"Observations: {overflow_buf.obs[: len(overflow_buf)]}\")\n", "print(f\"Episode return: {ep_rew[0]} (sum of all 12 steps, tracked correctly!)\")\n", "print(\"\\nNote: Buffer overwrote old data but episode statistics are still correct\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Important**: Episode returns and lengths are tracked internally and remain correct even when the episode spans buffer overflows. The buffer maintains `_ep_return`, `_ep_len`, and `_ep_start_idx` to track ongoing episodes." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10.2 Episodes Wrapping Around Within a Subbuffer\n", "\n", "In VectorReplayBuffer, each subbuffer is circular, so an episode can wrap around from the end of its subbuffer back to its beginning without ever crossing into a neighboring subbuffer:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create small VectorReplayBuffer to demonstrate edge crossing\n", "edge_buf = VectorReplayBuffer(total_size=20, buffer_num=2) # 10 per subbuffer\n", "\n", "print(f\"Subbuffer edges: {edge_buf.subbuffer_edges}\")\n", "print(\"Subbuffer 0: indices 0-9, Subbuffer 1: indices 10-19\\n\")\n", "\n", "# Fill subbuffer 0 with 12 steps (wraps around since capacity is 10)\n", "for i in range(12):\n", " batch = Batch(\n", " obs=np.array([[i]]),\n", " act=np.array([0]),\n", " rew=np.array([1.0]),\n", " terminated=np.array([i == 11]),\n", " truncated=np.array([False]),\n", " obs_next=np.array([[i + 1]]),\n", " info=np.array([{}], dtype=object),\n", " )\n", " idx, _, _, _ = edge_buf.add(batch, buffer_ids=[0])\n", " if i >= 10:\n", " print(f\"Step {i} added at index {idx[0]} (wrapped around in subbuffer 0)\")\n", "\n", "# get_buffer_indices handles this correctly\n", "episode_indices = edge_buf.get_buffer_indices(start=8, stop=2) # Crosses edge\n", "print(f\"\\nEpisode spanning edge (from 8 to 1): {episode_indices}\")\n", "print(\"Correctly retrieves [8, 9, 0, 1] within subbuffer 0\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3 ignore_obs_next Memory Optimization\n", "\n", "For memory-constrained scenarios, you can avoid storing obs_next:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Buffer that doesn't store obs_next\n", "memory_buf = ReplayBuffer(size=10, ignore_obs_next=True)\n", "\n", "# Add transitions (obs_next is ignored)\n", "for i in range(5):\n", " memory_buf.add(\n", " Batch(\n", " obs=np.array([i, i + 1]),\n", " act=i,\n", " rew=1.0,\n", " terminated=False,\n", " truncated=False,\n", " obs_next=np.array([i + 
1, i + 2]), # Provided but not stored\n", " info={},\n", " )\n", " )\n", "\n", "# When sampling, obs_next is reconstructed from next obs\n", "sample, _ = memory_buf.sample(batch_size=1)\n", "print(f\"Sampled obs: {sample.obs}\")\n", "print(f\"Sampled obs_next: {sample.obs_next}\")\n", "print(\"\\nobs_next was reconstructed, not stored directly\")\n", "print(\"This saves memory at the cost of slightly more complex retrieval\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is particularly useful for Atari environments with large observation spaces (84x84x4 frames)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 11. Surprising Behaviors and Gotchas\n", "\n", "### 11.1 Most Common Mistake: buffer_ids Confusion\n", "\n", "The buffer_ids parameter is the most common source of errors:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# COMMON ERROR 1: Forgetting buffer_ids with VectorReplayBuffer\n", "vec_demo = VectorReplayBuffer(total_size=100, buffer_num=4)\n", "\n", "parallel_data = Batch(\n", " obs=np.random.randn(4, 2),\n", " act=np.array([0, 1, 0, 1]),\n", " rew=np.array([1.0, 2.0, 3.0, 4.0]),\n", " terminated=np.array([False, False, False, False]),\n", " truncated=np.array([False, False, False, False]),\n", " obs_next=np.random.randn(4, 2),\n", " info=np.array([{}, {}, {}, {}], dtype=object),\n", ")\n", "\n", "# WRONG: Omitting buffer_ids (defaults to [0,1,2,3] which is OK here)\n", "# But if you have partial data, this will fail\n", "vec_demo.add(parallel_data) # Works by default\n", "\n", "# CORRECT: Always explicit\n", "vec_demo.add(parallel_data, buffer_ids=[0, 1, 2, 3])\n", "print(\"Always specify buffer_ids explicitly for clarity\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# COMMON ERROR 2: Shape mismatch with buffer_ids\n", "try:\n", " # Trying to add 2 transitions but specifying 4 buffer_ids\n", " wrong_batch = 
Batch(\n", " obs=np.random.randn(2, 2), # Only 2 transitions!\n", " act=np.array([0, 1]),\n", " rew=np.array([1.0, 2.0]),\n", " terminated=np.array([False, False]),\n", " truncated=np.array([False, False]),\n", " obs_next=np.random.randn(2, 2),\n", " info=np.array([{}, {}], dtype=object),\n", " )\n", " vec_demo.add(wrong_batch, buffer_ids=[0, 1, 2, 3]) # MISMATCH!\n", "except (IndexError, ValueError) as e:\n", " print(f\"Error caught: {type(e).__name__}\")\n", " print(\"Lesson: buffer_ids length must match batch size\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 11.2 Done Flag Confusion\n", "\n", "Never manually set the `done` flag:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# WRONG: Manually setting done\n", "wrong_batch = Batch(\n", " obs=1,\n", " act=0,\n", " rew=1.0,\n", " terminated=True,\n", " truncated=False,\n", " # done=True, # DON'T DO THIS! It will be overwritten anyway\n", " obs_next=2,\n", " info={},\n", ")\n", "\n", "# CORRECT: Only set terminated and truncated\n", "# done is automatically computed as (terminated OR truncated)\n", "correct_batch = Batch(\n", " obs=1,\n", " act=0,\n", " rew=1.0,\n", " terminated=True, # Episode ended naturally\n", " truncated=False, # Not cut off\n", " obs_next=2,\n", " info={},\n", ")\n", "\n", "demo = ReplayBuffer(size=10)\n", "demo.add(correct_batch)\n", "print(f\"Terminated: {demo.terminated[0]}\")\n", "print(f\"Truncated: {demo.truncated[0]}\")\n", "print(f\"Done (auto-computed): {demo.done[0]}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 11.3 Sampling from Empty or Near-Empty Buffers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Edge case: Sampling more than available\n", "small_buf = ReplayBuffer(size=100)\n", "for i in range(5): # Only 5 transitions\n", " small_buf.add(\n", " Batch(obs=i, act=0, rew=1.0, terminated=False, truncated=False, obs_next=i + 1, 
info={})\n", " )\n", "\n", "# Request 20 but only 5 available - samples with replacement\n", "batch, indices = small_buf.sample(batch_size=20)\n", "print(f\"Requested 20, buffer has {len(small_buf)}, got {len(batch)}\")\n", "print(f\"Indices: {indices}\")\n", "print(\"Notice: Some indices repeat (sampling with replacement)\")\n", "\n", "# Defensive pattern: Check buffer size\n", "if len(small_buf) >= 128:\n", " batch, _ = small_buf.sample(128)\n", "else:\n", " print(f\"Buffer has {len(small_buf)} < 128, waiting for more data\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 11.4 Frame Stacking Valid Indices\n", "\n", "With stack_num > 1, not all indices are valid for sampling:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# With frame stacking, early indices can't form complete stacks\n", "stack_demo = ReplayBuffer(size=20, stack_num=4, sample_avail=True)\n", "\n", "for i in range(10):\n", " stack_demo.add(\n", " Batch(\n", " obs=np.array([i]),\n", " act=0,\n", " rew=1.0,\n", " terminated=i == 9,\n", " truncated=False,\n", " obs_next=np.array([i + 1]),\n", " info={},\n", " )\n", " )\n", "\n", "# With sample_avail=True, only valid indices are sampled\n", "sampled, indices = stack_demo.sample(batch_size=5)\n", "print(f\"Sampled indices with stack_num=4, sample_avail=True: {indices}\")\n", "print(\"All indices >= 3 (can form complete 4-frame stacks)\")\n", "\n", "# Without sample_avail, any index can be sampled (may have incomplete stacks)\n", "stack_demo2 = ReplayBuffer(size=20, stack_num=4, sample_avail=False)\n", "for i in range(10):\n", " stack_demo2.add(\n", " Batch(\n", " obs=np.array([i]),\n", " act=0,\n", " rew=1.0,\n", " terminated=False,\n", " truncated=False,\n", " obs_next=np.array([i + 1]),\n", " info={},\n", " )\n", " )\n", "\n", "sampled2, indices2 = stack_demo2.sample(batch_size=5)\n", "print(f\"\\nSampled indices with sample_avail=False: {indices2}\")\n", "print(\"May include 
indices < 3 (incomplete stacks repeated from boundary)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 12. Best Practices\n", "\n", "### 12.1 Choosing the Right Buffer\n", "\n", "**Decision Tree**:\n", "\n", "1. Are you using parallel environments?\n", " - Yes → Use `VectorReplayBuffer`\n", " - No → Continue to 2\n", "\n", "2. Do you need prioritized experience replay?\n", " - Yes → Use `PrioritizedReplayBuffer` or `PrioritizedVectorReplayBuffer`\n", " - No → Continue to 3\n", "\n", "3. Is it goal-conditioned RL with sparse rewards?\n", " - Yes → Use `HERReplayBuffer` or `HERVectorReplayBuffer`\n", " - No → Continue to 4\n", "\n", "4. Do you need per-environment cache buffers whose finished episodes are moved into a single main buffer?\n", " - Yes → Use `CachedReplayBuffer`\n", " - No → Use `ReplayBuffer` (single env) or `VectorReplayBuffer` (standard choice)\n", "\n", "**Most Common Setup**: `VectorReplayBuffer` for production training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 12.2 Buffer Sizing Guidelines\n", "\n", "**Rule of Thumb by Domain**:\n", "\n", "- **Atari games**: 1,000,000 transitions (1e6)\n", "- **Continuous control (MuJoCo)**: 100,000-1,000,000 (1e5-1e6)\n", "- **Robotics**: 100,000-500,000 (1e5-5e5)\n", "- **Simple environments (CartPole)**: 10,000-50,000 (1e4-5e4)\n", "\n", "**Factors to Consider**:\n", "- Available RAM (each transition ~observation_size * 2 + metadata)\n", "- Training time vs sample efficiency tradeoff\n", "- Algorithm requirements (some need larger buffers)\n", "\n", "**Memory Estimation**:\n", "```python\n", "# For environments with observation shape (84, 84, 4) (Atari):\n", "# Each transition: 2 * 84 * 84 * 4 bytes (obs + obs_next) + ~100 bytes overhead\n", "# = ~56KB per transition\n", "# 1M transitions = ~56GB (use ignore_obs_next to halve this!)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 12.3 Configuration Best Practices\n", "\n", "**When to use stack_num > 1**:\n", "- RNN/LSTM policies need temporal 
context\n", "- Frame-based policies (Atari with 4-frame stacking)\n", "- Velocity estimation from positions\n", "\n", "**When to use ignore_obs_next=True**:\n", "- Memory-constrained environments\n", "- Atari (large observation spaces)\n", "- When obs_next can be reconstructed from the obs of the following transition\n", "\n", "**When to use save_only_last_obs=True**:\n", "- Atari with temporal stacking in environment wrapper\n", "- When observations already contain frame history\n", "\n", "**When to use sample_avail=True**:\n", "- Always use with stack_num > 1 for correctness\n", "- Ensures samples have complete frame stacks\n", "- Small performance cost but worth it for data quality" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 12.4 Integration Patterns\n", "\n", "**Pattern 1: Standard Off-Policy Setup**\n", "```python\n", "# env = make_vectorized_env(num_envs=8)\n", "# buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\n", "# policy = SACPolicy(...)\n", "# collector = Collector(policy, env, buffer)\n", "# \n", "# # Collect and train\n", "# collector.collect(n_step=1000)\n", "# for _ in range(10):\n", "# batch, indices = buffer.sample(256)\n", "# policy.learn(batch)\n", "```\n", "\n", "**Pattern 2: Pre-fill Buffer Before Training**\n", "```python\n", "# # Collect random exploration data\n", "# collector.collect(n_step=10000) # Fill buffer\n", "# \n", "# # Then start training\n", "# while not converged:\n", "# collector.collect(n_step=100)\n", "# for _ in range(10):\n", "# batch, _ = buffer.sample(256)\n", "# policy.learn(batch)\n", "```\n", "\n", "**Pattern 3: Offline RL**\n", "```python\n", "# # Load pre-collected dataset\n", "# buffer = ReplayBuffer.load_hdf5(\"expert_data.hdf5\")\n", "# \n", "# # Train without further collection\n", "# for epoch in range(num_epochs):\n", "# for _ in range(updates_per_epoch):\n", "# batch, _ = buffer.sample(256)\n", "# policy.learn(batch)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 12.5 Performance 
Tips\n", "\n", "**Tip 1: Pre-allocate buffer size appropriately**\n", "- Don't make buffer too large (wastes memory)\n", "- Don't make it too small (loses important old experiences)\n", "- Start with domain defaults and adjust based on performance\n", "\n", "**Tip 2: Use HDF5 for large offline datasets**\n", "- Compression saves disk space\n", "- Faster loading than pickle for large files\n", "- Better for sharing across systems\n", "\n", "**Tip 3: Batch sampling efficiently**\n", "- Sample once and use multiple times if possible\n", "- Don't sample more than you need\n", "- For multi-GPU training, sample once and split\n", "\n", "**Tip 4: Monitor buffer usage**\n", "```python\n", "# print(f\"Buffer usage: {len(buffer)}/{buffer.maxsize}\")\n", "# if len(buffer) < batch_size:\n", "# print(\"Warning: Sampling with replacement!\")\n", "```\n", "\n", "**Tip 5: Consider ignore_obs_next for large observation spaces**\n", "- Can halve memory usage\n", "- Small computational overhead on sampling\n", "- Especially valuable for image-based RL" ] }, { "cell_type": "markdown", "metadata": {}, "source": "## 13. 
Quick Reference\n\n### Method Summary\n\n| Method | Purpose | Returns | Notes |\n|--------|---------|---------|-------|\n| `add(batch, buffer_ids)` | Add transition(s) | `(idx, ep_rew, ep_len, ep_start)` | ep_rew/ep_len only non-zero when done=True |\n| `sample(size)` | Random sample | `(batch, indices)` | size=None for all (random), 0 for all (ordered) |\n| `prev(idx)` | Previous in episode | `indices` | Stops at episode boundaries |\n| `next(idx)` | Next in episode | `indices` | Stops at episode boundaries |\n| `get(idx, key, stack_num)` | Get with stacking | `data` | Returns stacked frames if stack_num > 1 |\n| `get_buffer_indices(start, stop)` | Episode range | `indices` | Handles edge-crossing episodes |\n| `unfinished_index()` | Ongoing episodes | `indices` | Returns last step of unfinished episodes |\n| `save_hdf5(path)` | Save to HDF5 | - | Recommended for large datasets |\n| `load_hdf5(path)` | Load from HDF5 | `buffer` | Class method |\n| `from_data(...)` | Create from arrays | `buffer` | For offline RL datasets |\n| `reset()` | Clear buffer | - | Optionally keep episode statistics |\n| `sample_indices(size)` | Get indices only | `indices` | For custom sampling logic |\n\n### Common Patterns Cheatsheet\n\n**Single Environment**:\n```python\nbuffer = ReplayBuffer(size=10000)\nbuffer.add(Batch(obs=..., act=..., rew=..., terminated=..., truncated=..., obs_next=..., info={}))\nbatch, indices = buffer.sample(batch_size=256)\n```\n\n**Parallel Environments**:\n```python\nbuffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\nbuffer.add(parallel_batch, buffer_ids=[0,1,2,3,4,5,6,7])\nbatch, indices = buffer.sample(batch_size=256)\n```\n\n**Frame Stacking**:\n```python\nbuffer = ReplayBuffer(size=100000, stack_num=4, sample_avail=True)\nstacked_obs = buffer.get(index=50, key=\"obs\") # Returns 4 stacked frames\n```\n\n**Prioritized Replay**:\n```python\nbuffer = PrioritizedReplayBuffer(size=100000, alpha=0.6, beta=0.4)\nbatch, indices = 
buffer.sample(batch_size=256)\nweights = batch.weight # Importance weights are inside the batch\n# ... compute TD errors ...\nbuffer.update_weight(indices, td_errors)\n```\n\n**Offline RL**:\n```python\nbuffer = ReplayBuffer.load_hdf5(\"dataset.hdf5\")\n# Or:\nwith h5py.File(\"dataset.hdf5\", \"r\") as f:\n buffer = ReplayBuffer.from_data(obs=f[\"obs\"], act=f[\"act\"], ...)\n```\n\n**Episode Retrieval**:\n```python\n# Find episode boundaries, then:\nepisode_indices = buffer.get_buffer_indices(start=ep_start_idx, stop=ep_end_idx+1)\nepisode = buffer[episode_indices]\n```" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary and Next Steps\n", "\n", "This tutorial covered Tianshou's buffer system comprehensively:\n", "\n", "1. **Buffer fundamentals**: Why buffers are essential for RL\n", "2. **Buffer hierarchy**: Understanding different buffer types\n", "3. **Basic operations**: Construction, configuration, and data management\n", "4. **Trajectory management**: Episode tracking and boundary navigation\n", "5. **Sampling strategies**: Basic sampling and frame stacking\n", "6. **VectorReplayBuffer**: Critical for parallel environments\n", "7. **Specialized buffers**: Prioritized, cached, and HER variants\n", "8. **Serialization**: Pickle and HDF5 persistence\n", "9. **Integration**: How buffers fit in the RL pipeline\n", "10. **Advanced topics**: Edge cases and overflow handling\n", "11. **Gotchas**: Common mistakes and how to avoid them\n", "12. **Best practices**: Configuration, sizing, and performance\n", "13. 
**Quick reference**: Method summary and common patterns\n", "\n", "### Next Steps\n", "\n", "- **Collector Deep Dive**: Learn how Collector fills buffers from environments\n", "- **Policy Tutorial**: Understand how policies sample from buffers for training\n", "- **Algorithm Examples**: See buffer usage in specific algorithms (DQN, SAC, PPO)\n", "- **API Reference**: Full details at [Buffer API documentation](https://tianshou.org/en/stable/api/tianshou.data.html)\n", "\n", "### Further Resources\n", "\n", "- [Tianshou GitHub](https://github.com/thu-ml/tianshou) for source code and examples\n", "- [Gymnasium Documentation](https://gymnasium.farama.org/) for environment conventions\n", "- Research papers on experience replay and prioritized sampling" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: docs/02_deep_dives/L3_Environments.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Environments\n", "\n", "In reinforcement learning, agents interact with environments to improve their performance through trial and error. This tutorial explores how Tianshou handles environments, from basic single-environment setups to advanced vectorized and parallel configurations.\n", "\n", "
\n", "
\n", "The agent-environment interaction loop\n", "
\n", "\n", "Tianshou maintains full compatibility with the [Gymnasium](https://gymnasium.farama.org/) API (formerly OpenAI Gym), making it easy to use any Gymnasium-compatible environment." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Bottleneck Problem\n", "\n", "In a standard Gymnasium environment, each interaction follows a sequential pattern:\n", "\n", "1. Agent selects an action\n", "2. Environment processes the action and returns observation and reward\n", "3. Repeat\n", "\n", "This sequential process can become a significant bottleneck in deep reinforcement learning experiments, especially when:\n", "- The environment simulation is computationally intensive\n", "- Network training is fast but data collection is slow\n", "- You have multiple CPU cores available but aren't using them\n", "\n", "Tianshou addresses this bottleneck through **vectorized environments**, which allow parallel sampling across multiple CPU cores." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vectorized Environments\n", "\n", "Vectorized environments enable you to run multiple environment instances in parallel, dramatically accelerating data collection. Let's see this in action." 
] }, { "cell_type": "code", "metadata": {}, "source": [ "import time\n", "\n", "import gymnasium as gym\n", "import numpy as np\n", "\n", "from tianshou.env import DummyVectorEnv, SubprocVectorEnv" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Performance Comparison\n", "\n", "Let's compare the sampling speed with different numbers of parallel environments:" ] }, { "cell_type": "code", "metadata": {}, "source": [ "num_cpus = [1, 2, 5]\n", "\n", "for num_cpu in num_cpus:\n", " # Create vectorized environment with multiple processes\n", " env = SubprocVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(num_cpu)])\n", " env.reset()\n", "\n", " sampled_steps = 0\n", " time_start = time.time()\n", "\n", " # Sample 1000 steps\n", " while sampled_steps < 1000:\n", " act = np.random.choice(2, size=num_cpu)\n", " obs, rew, terminated, truncated, info = env.step(act)\n", "\n", " # Reset terminated environments\n", " if np.sum(terminated):\n", " env.reset(np.where(terminated)[0])\n", "\n", " sampled_steps += num_cpu\n", "\n", " time_used = time.time() - time_start\n", " print(f\"Sampled 1000 steps in {time_used:.3f}s using {num_cpu} CPU(s)\")\n", " print(f\" → Speed: {1000 / time_used:.1f} steps/second\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Understanding the Results\n", "\n", "You might notice that the speedup isn't perfectly linear with the number of CPUs. Several factors contribute to this:\n", "\n", "1. **Straggler Effect**: In synchronous mode, all environments must complete before the next batch begins. Slower environments hold back faster ones.\n", "2. **Communication Overhead**: Inter-process communication has costs, especially for fast environments.\n", "3. 
**Environment Complexity**: For simple environments like CartPole, the overhead may outweigh the benefits.\n", "\n", "> **Important**: `SubprocVectorEnv` should only be used when environment execution is slow. For simple, fast environments like CartPole, `DummyVectorEnv` (or even raw Gymnasium environments) can be more efficient because they avoid both the straggler effect and inter-process communication overhead." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Types of Vectorized Environments\n", "\n", "Tianshou provides several vectorized environment implementations, each optimized for different scenarios:\n", "\n", "### 1. DummyVectorEnv\n", "**Pseudo-parallel simulation using a for-loop**\n", "- Best for: Simple/fast environments, debugging\n", "- Pros: No overhead, deterministic execution\n", "- Cons: No actual parallelization\n", "\n", "### 2. SubprocVectorEnv\n", "**Multiple processes for true parallel simulation**\n", "- Best for: Most parallel simulation scenarios\n", "- Pros: True parallelization, good balance\n", "- Cons: Inter-process communication overhead\n", "\n", "### 3. ShmemVectorEnv\n", "**Shared memory optimization of SubprocVectorEnv**\n", "- Best for: Environments with large observations (e.g., images)\n", "- Pros: Reduced memory footprint, faster for large states\n", "- Cons: More complex implementation\n", "\n", "### 4. RayVectorEnv\n", "**Ray-based distributed simulation**\n", "- Best for: Cluster computing with multiple machines\n", "- Pros: Scales to multiple machines\n", "- Cons: Requires Ray installation and setup\n", "\n", "All these classes share the same API through their base class `BaseVectorEnv`, making it easy to switch between them." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic Usage\n", "\n", "### Creating a Vectorized Environment" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Standard Gymnasium environment\n", "gym_env = gym.make(\"CartPole-v1\")\n", "\n", "\n", "# Tianshou vectorized environment\n", "def create_cartpole_env() -> gym.Env:\n", " return gym.make(\"CartPole-v1\")\n", "\n", "\n", "# Create 5 parallel environments\n", "vector_env = DummyVectorEnv([create_cartpole_env for _ in range(5)])\n", "\n", "print(f\"Created vectorized environment with {vector_env.env_num} environments\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Environment Interaction\n", "\n", "The key difference from standard Gymnasium is that actions, observations, and rewards are all vectorized:" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Standard Gymnasium: reset() returns a single observation\n", "print(\"Standard Gymnasium reset:\")\n", "single_obs, info = gym_env.reset()\n", "print(f\" Shape: {single_obs.shape}\")\n", "print(f\" Value: {single_obs}\")\n", "\n", "print(\"\\n\" + \"=\" * 50 + \"\\n\")\n", "\n", "# Vectorized environment: reset() returns stacked observations\n", "print(\"Vectorized environment reset:\")\n", "vector_obs, info = vector_env.reset()\n", "print(f\" Shape: {vector_obs.shape}\")\n", "print(f\" Value:\\n{vector_obs}\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Taking Vectorized Steps" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Take random actions in all environments\n", "actions = np.random.choice(2, size=vector_env.env_num)\n", "obs, rew, terminated, truncated, info = vector_env.step(actions)\n", "\n", "print(f\"Actions taken: {actions}\")\n", "print(f\"Rewards received: {rew}\")\n", "print(f\"Terminated flags: {terminated}\")\n", "print(\"Info\", info)" ], "outputs": [], "execution_count": null 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "### Selective Environment Execution\n", "\n", "You can interact with specific environments using the `id` parameter:" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Execute only environments 0, 1, and 3\n", "selected_actions = np.random.choice(2, size=3)\n", "obs, rew, terminated, truncated, info = vector_env.step(selected_actions, id=[0, 1, 3])\n", "\n", "print(\"Executed actions in environments [0, 1, 3]\")\n", "print(f\"Received {len(rew)} results\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parallel Sampling: Synchronous vs Asynchronous\n", "\n", "### Synchronous Mode (Default)\n", "\n", "By default, vectorized environments operate synchronously: a step completes only after **all** environments finish their step. This works well when all environments take roughly the same time per step.\n", "\n", "### Asynchronous Mode\n", "\n", "When environment step times vary significantly (e.g., 90% of steps take 1s, but 10% take 10s), asynchronous mode can help. It allows faster environments to continue without waiting for slower ones.\n", "\n", "
\n", "
\n", "Comparison of synchronous and asynchronous vectorized environments
\n", "(Steps with the same color are processed together)\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Enabling Asynchronous Mode\n", "\n", "Use the `wait_num` or `timeout` parameters (or both):" ] }, { "cell_type": "code", "metadata": {}, "source": [ "from functools import partial\n", "\n", "\n", "# Create environments with varying step times\n", "class SlowEnv(gym.Env):\n", " \"\"\"Environment with variable step duration.\"\"\"\n", "\n", " def __init__(self, sleep_time):\n", " self.sleep_time = sleep_time\n", " self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4,))\n", " self.action_space = gym.spaces.Discrete(2)\n", " super().__init__()\n", "\n", " def reset(self, seed=None, options=None):\n", " super().reset(seed=seed)\n", " return np.random.rand(4), {}\n", "\n", " def step(self, action):\n", " time.sleep(self.sleep_time) # Simulate slow computation\n", " return np.random.rand(4), 0.0, False, False, {}\n", "\n", "\n", "# Create async vectorized environment\n", "env_fns = [partial(SlowEnv, sleep_time=0.01 * i) for i in [1, 2, 3, 4]]\n", "async_env = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.1)\n", "\n", "print(\"Asynchronous environment created\")\n", "print(\" wait_num=3: Returns after 3 environments complete\")\n", "print(\" timeout=0.1: Or after 0.1 seconds, whichever comes first\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How Async Parameters Work\n", "\n", "- **`wait_num`**: Minimum number of environments to wait for (e.g., `wait_num=3` means each step returns results from at least 3 environments)\n", "- **`timeout`**: Maximum time to wait in seconds (acts as a dynamic `wait_num`—returns whatever is ready after timeout)\n", "- If no environment finishes within the timeout, the system waits until at least one completes\n", "\n", "> **Warning**: Asynchronous collectors can cause exceptions when used as `test_collector` in trainers. Always use synchronous mode for test collectors." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## EnvPool Integration\n", "\n", "[EnvPool](https://github.com/sail-sg/envpool/) is a C++-based vectorized environment library that provides significant performance improvements over Python-based solutions for many of the standard environments. Tianshou fully supports EnvPool with minimal code changes.\n", "\n", "### Why EnvPool?\n", "\n", "- **Performance**: 10x-100x faster than standard vectorized environments for supported environments\n", "- **Memory Efficient**: Optimized memory usage through shared buffers\n", "- **Drop-in Replacement**: Nearly identical API to Tianshou's vectorized environments\n", "\n", "### Supported Environments\n", "\n", "EnvPool currently supports:\n", "- Atari games\n", "- MuJoCo physics simulations\n", "- VizDoom 3D environments\n", "- Classic control environments\n", "- Toy text environments" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using EnvPool\n", "\n", "First, install EnvPool:\n", "\n", "```bash\n", "pip install envpool\n", "```\n", "\n", "Then use it directly with Tianshou:\n", "\n", "```python\n", "import envpool\n", "\n", "# Create EnvPool vectorized environment\n", "envs = envpool.make_gymnasium(\"CartPole-v1\", num_envs=10)\n", "\n", "print(f\"Created EnvPool environment with {envs.spec.config.num_envs} environments\")\n", "print(\"Ready to use with Tianshou collectors!\")\n", "\n", "# Use directly with Tianshou\n", "collector = Collector(algorithm, envs, buffer)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### EnvPool Examples\n", "\n", "For complete examples of using EnvPool with Tianshou:\n", "- [Atari with EnvPool](https://github.com/thu-ml/tianshou/tree/master/examples/atari#envpool)\n", "- [MuJoCo with EnvPool](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco#envpool)\n", "- [VizDoom with EnvPool](https://github.com/thu-ml/tianshou/tree/master/examples/vizdoom#envpool)\n", "- [More EnvPool 
Examples](https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom Environments and State Representations\n", "\n", "Tianshou works seamlessly with custom environments as long as they follow the Gymnasium API. Let's explore how to handle different state representations." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Required Gymnasium API\n", "\n", "Your custom environment must implement:\n", "\n", "```python\n", "class MyEnv(gym.Env):\n", " def reset(self, seed=None, options=None) -> Tuple[observation, info]:\n", " \"\"\"Reset environment to initial state.\"\"\"\n", " pass\n", " \n", " def step(self, action) -> Tuple[observation, reward, terminated, truncated, info]:\n", " \"\"\"Execute one step in the environment.\"\"\"\n", " pass\n", " \n", " def seed(self, seed: int) -> List[int]:\n", " \"\"\"Set random seed.\"\"\"\n", " pass\n", " \n", " def render(self, mode='human') -> Any:\n", " \"\"\"Render the environment.\"\"\"\n", " pass\n", " \n", " def close(self) -> None:\n", " \"\"\"Clean up resources.\"\"\"\n", " pass\n", " \n", " # Required spaces\n", " observation_space: gym.Space\n", " action_space: gym.Space\n", "```\n", "\n", "> **Important**: Make sure your `seed()` method is implemented correctly:\n", "> ```python\n", "> def seed(self, seed):\n", "> np.random.seed(seed)\n", "> # Also seed other random generators used in your environment\n", "> ```\n", "> Without proper seeding, parallel environments may produce identical outputs!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dictionary Observations\n", "\n", "Many environments return observations as dictionaries rather than simple arrays. 
Tianshou's `Batch` class handles this elegantly.\n", "\n", "Example with the FetchReach environment:" ] }, { "cell_type": "code", "metadata": {}, "source": [ "from tianshou.data import Batch, ReplayBuffer\n", "\n", "# Example: Creating a mock observation similar to FetchReach\n", "observation = {\n", " \"observation\": np.array([1.34, 0.75, 0.53, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),\n", " \"achieved_goal\": np.array([1.34, 0.75, 0.53]),\n", " \"desired_goal\": np.array([1.24, 0.78, 0.63]),\n", "}\n", "\n", "# Store in replay buffer\n", "buffer = ReplayBuffer(size=10)\n", "buffer.add(Batch(obs=observation, act=0, rew=0.0, terminated=False, truncated=False))\n", "\n", "print(\"Stored observation structure:\")\n", "print(buffer.obs)" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Accessing Dictionary Observations\n", "\n", "When sampling from the buffer, you can access nested dictionary values in multiple ways:" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Sample a batch\n", "batch, indices = buffer.sample(batch_size=1)\n", "\n", "print(\"Batch keys:\", list(batch.keys()))\n", "print(\"\\nAccessing nested observation:\")\n", "\n", "# Recommended way: access through batch first\n", "print(\"batch.obs.desired_goal[0]:\", batch.obs.desired_goal[0])\n", "\n", "# Alternative ways (not recommended)\n", "print(\"batch.obs[0].desired_goal:\", batch.obs[0].desired_goal)\n", "print(\"batch[0].obs.desired_goal:\", batch[0].obs.desired_goal)" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using Dictionary Observations in Networks\n", "\n", "When designing networks for environments with dictionary observations:" ] }, { "cell_type": "code", "metadata": {}, "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "\n", "class CustomNetwork(nn.Module):\n", " \"\"\"Network that processes dictionary observations.\"\"\"\n", "\n", " def __init__(self, 
obs_dim, goal_dim, hidden_dim, action_dim):\n", " super().__init__()\n", "\n", " # Separate processing for different observation components\n", " self.obs_encoder = nn.Linear(obs_dim, hidden_dim)\n", " self.goal_encoder = nn.Linear(goal_dim * 2, hidden_dim) # achieved + desired\n", "\n", " # Combined processing\n", " self.fc = nn.Sequential(\n", " nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, action_dim)\n", " )\n", "\n", " def forward(self, obs_batch, **kwargs):\n", " # Extract components from the batch\n", " observation = obs_batch.observation\n", " achieved_goal = obs_batch.achieved_goal\n", " desired_goal = obs_batch.desired_goal\n", "\n", " # Process each component\n", " obs_feat = self.obs_encoder(observation)\n", " goal_feat = self.goal_encoder(torch.cat([achieved_goal, desired_goal], dim=-1))\n", "\n", " # Combine and output\n", " combined = torch.cat([obs_feat, goal_feat], dim=-1)\n", " return self.fc(combined)\n", "\n", "\n", "# Example usage\n", "net = CustomNetwork(obs_dim=10, goal_dim=3, hidden_dim=64, action_dim=4)\n", "print(\"Network created for dictionary observations\")\n", "print(\" Input: observation (10D) + achieved_goal (3D) + desired_goal (3D)\")\n", "print(\" Output: actions (4D)\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Custom Object States\n", "\n", "For more complex state representations (e.g., graphs, custom objects), Tianshou stores references in numpy arrays. 
However, you must ensure deep copies to avoid state aliasing:" ] }, { "cell_type": "code", "metadata": {}, "source": [ "import copy\n", "\n", "import networkx as nx\n", "\n", "\n", "class GraphEnv(gym.Env):\n", " \"\"\"Example environment with graph-based states.\"\"\"\n", "\n", " def __init__(self):\n", " super().__init__()\n", " self.graph = nx.Graph()\n", " self.action_space = gym.spaces.Discrete(5)\n", " self.observation_space = gym.spaces.Box(low=0, high=1, shape=(10,)) # for compatibility\n", "\n", " def reset(self, seed=None, options=None):\n", " super().reset(seed=seed)\n", " self.graph = nx.erdos_renyi_graph(10, 0.3)\n", " # IMPORTANT: Return deep copy to avoid reference issues\n", " return copy.deepcopy(self.graph), {}\n", "\n", " def step(self, action):\n", " # Modify graph based on action\n", " if action < 4 and len(self.graph.nodes) > 0:\n", " nodes = list(self.graph.nodes)\n", " if len(nodes) >= 2:\n", " self.graph.add_edge(nodes[0], nodes[1])\n", "\n", " # IMPORTANT: Return deep copy\n", " return copy.deepcopy(self.graph), 0.0, False, False, {}\n", "\n", "\n", "# Test storing graph objects\n", "graph_buffer = ReplayBuffer(size=5)\n", "env = GraphEnv()\n", "obs, _ = env.reset()\n", "graph_buffer.add(Batch(obs=obs, act=0, rew=0.0, terminated=False, truncated=False))\n", "\n", "print(\"Graph objects stored in buffer:\")\n", "print(graph_buffer.obs)" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "> **Important**: When using custom objects as states:\n", "> 1. Always return `copy.deepcopy(state)` in both `reset()` and `step()`\n", "> 2. Ensure the object is numpy-compatible: `np.array([your_object])` should not result in an empty array\n", "> 3. 
The object may be stored as a shallow copy in the buffer—deep copying prevents state aliasing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Best Practices Summary\n", "\n", "### Choosing the Right Environment Wrapper\n", "\n", "| Scenario | Recommended Wrapper | Why |\n", "|----------|-------------------|-----|\n", "| Simple/fast environments | `DummyVectorEnv` or raw Gym | Minimal overhead |\n", "| Most parallel scenarios | `SubprocVectorEnv` | Good balance of speed and simplicity |\n", "| Large observations (images) | `ShmemVectorEnv` | Optimized memory usage |\n", "| Multi-machine clusters | `RayVectorEnv` | Distributed computing support |\n", "| Maximum performance | EnvPool | C++-based, 10x-100x speedup |\n", "\n", "### Performance Tips\n", "\n", "1. **Profile First**: Measure whether environment or training is your bottleneck before optimizing\n", "2. **Start Simple**: Begin with `DummyVectorEnv` for debugging, then upgrade to parallel versions\n", "3. **Use EnvPool**: If your environment is supported, EnvPool offers the best performance\n", "4. **Async for Variable Times**: Use asynchronous mode only when environment step times vary significantly\n", "5. 
**Proper Seeding**: Always implement the `seed()` method correctly in custom environments\n", "\n", "### Common Pitfalls\n", "\n", "- ❌ Using `SubprocVectorEnv` for fast environments → Use `DummyVectorEnv` instead\n", "- ❌ Forgetting to deep-copy custom states → States will be aliased in the buffer\n", "- ❌ Not implementing `seed()` properly → Parallel environments produce identical results\n", "- ❌ Using async collectors for testing → Causes exceptions in trainers\n", "- ❌ Assuming linear speedup → Account for communication overhead and straggler effects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Further Reading\n", "\n", "- **Tianshou Documentation**: [Environment API Reference](https://tianshou.org/en/master/03_api/env/venvs.html)\n", "- **EnvPool**: [Official Documentation](https://envpool.readthedocs.io/)\n", "- **Gymnasium**: [Environment Creation Tutorial](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/)\n", "- **Ray**: [Distributed RL with Ray](https://docs.ray.io/en/latest/rllib/index.html)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: docs/02_deep_dives/L4_GAE.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "QJ5krjrcbuiA" }, "source": [ "# Generalized Advantage Estimation\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "UPVl5LBEWJ0t" }, "source": [ "## How to compute GAE on your own?\n", "(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n", "\n", "In 
terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. However, it is still important that we know the details behind this.\n", "\n", "To compute GAE advantage, the usage of `self.compute_episodic_return()` may go like:" ] }, { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "D34GlVvPNz08", "outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215" }, "source": [ "```python\n", "batch, indices = dummy_buffer.sample(0) # 0 means sampling all the data from the buffer\n", "returns, advantage = Algorithm.compute_episodic_return(\n", " batch=batch,\n", " buffer=dummy_buffer,\n", " indices=indices,\n", " v_s_=np.zeros(10),\n", " v_s=np.zeros(10),\n", " gamma=1.0,\n", " gae_lambda=1.0,\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. However, the way the returns are computed here might be a bit misleading. In fact, the last episode is unfinished, but its last step saved in the batch is treated as a terminal state, since it assumes that there are no future rewards. The episode is not terminated yet, it is truncated, so the agent could still get rewards in the future. Terminated and truncated episodes should indeed be treated differently.\n", "The return of a step is the (discounted) sum of the future rewards from that step until the end of the episode. 
\n", "\\begin{equation}\n", "R_{t}=\\sum_{t}^{T} \\gamma^{t} r_{t}\n", "\\end{equation}\n", "Thus, at the last step of a terminated episode the return is equal to the reward at that state, since there are no future states.\n", "\\begin{equation}\n", "R_{T,terminated}=r_{T}\n", "\\end{equation}\n", "\n", "However, if the episode was truncated the return at the last step is usually better represented by the estimated value of that state, which is the expected return from that state onwards.\n", "\\begin{align*}\n", "R_{T,truncated}=V^{\\pi}\\left(s_{T}\\right) \\quad & \\text{or} \\quad R_{T,truncated}=Q^{\\pi}(s_{T},a_{T})\n", "\\end{align*}\n", "Moreover, if the next state was also observed (but not its reward), then an even better estimate would be the reward of the last step plus the discounted value of the next state.\n", "\\begin{align*}\n", "R_{T,truncated}=r_T+\\gamma V^{\\pi}\\left(s_{T+1}\\right)\n", "\\end{align*}" ] }, { "cell_type": "markdown", "metadata": { "id": "h_5Dt6XwQLXV" }, "source": [ "\n", "As we know, we need to estimate the value function of every observation to compute GAE advantage. So in `v_s` is the value of `batch.obs`, and in `v_s_` is the value of `batch.obs_next`. This is usually computed by:\n", "\n", "`v_s = critic(batch.obs)`,\n", "\n", "`v_s_ = critic(batch.obs_next)`,\n", "\n", "where both `v_s` and `v_s_` are 10 dimensional arrays and `critic` is usually a neural network.\n", "\n", "After we've got all those values, GAE can be computed following the equation below." 
] }, { "cell_type": "markdown", "metadata": { "id": "ooHNIICGUO19" }, "source": [ "\\begin{aligned}\n", "\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)}: =& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n", "\\end{aligned}\n", "\n", "where\n", "\n", "\\begin{equation}\n", "\\delta_{t}^{V} \\quad=-V\\left(s_{t}\\right)+r_{t}+\\gamma V\\left(s_{t+1}\\right)\n", "\\end{equation}\n" ] }, { "cell_type": "markdown", "metadata": { "id": "eV6XZaouU7EV" }, "source": [ "Unfortunately, if you follow this equation, which is taken from the paper, you probably will get a slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation." ] }, { "cell_type": "markdown", "metadata": { "id": "FCxD9gNNVYbd" }, "source": [ "**First** is that Gym always returns you a `obs_next` even if this is already the last step. The value of this timestep is exactly 0 and you should not let the neural network estimate it." ] }, { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rNZNUNgQVvRJ", "outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d" }, "source": [ "```python\n", "# Assume v_s_ is got by calling critic(batch.obs_next)\n", "v_s_ = np.ones(10)\n", "v_s_ *= ~batch.done\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "2EtMi18QWXTN" }, "source": [ "After the fix above, we will perhaps get a more accurate estimate.\n", "\n", "**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag. However, in the buffer above, the last (10th) step is not marked by done=True, because the collecting has not finished. We must know all those unfinished steps so that we know when to stop bootstrapping.\n", "\n", "Luckily, this can be done under the assistance of buffer because buffers in Tianshou not only store data, but also help you manage data trajectories." 
] }, { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "saluvX4JU6bC", "outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5" }, "source": [ "```python\n", "unfinished_indexes = dummy_buffer.unfinished_index()\n", "done_indexes = np.where(batch.done)[0]\n", "stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "qp6vVE4dYWv1" }, "source": [ "**Thirdly**, there are some special indexes which are marked by done flag, however its value for obs_next should not be zero. It is again because done does not differentiate between terminated and truncated. These steps are usually those at the last step of an episode, but this episode stops not because the agent can no longer get any rewards (value=0), but because the episode is too long so we have to truncate it. These kind of steps are always marked with `info['TimeLimit.truncated']=True` in Gym." ] }, { "cell_type": "markdown", "metadata": { "id": "tWkqXRJfZTvV" }, "source": [ "As a result, we need to rewrite the equation above\n", "\n", "`v_s_ *= ~batch.done`" ] }, { "cell_type": "markdown", "metadata": { "id": "kms-QtxKZe-M" }, "source": [ "to\n", "\n", "```\n", "mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n", "v_s_ *= mask\n", "\n", "```\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "u_aPPoKraBu6" }, "source": [ "## Summary\n", "If you already felt bored by now, simply remember that Tianshou can help handle all these little details so that you can focus on the algorithm itself. Just call `Algorithm.compute_episodic_return()`.\n", "\n", "If you still feel interested, we would recommend you check Appendix C in this [paper](https://arxiv.org/abs/2107.14171v2) and implementation of `Algorithm.value_mask()` and `Algorithm.compute_episodic_return()` for details." ] }, { "cell_type": "markdown", "metadata": { "id": "2cPnUXRBWKD9" }, "source": [ "
\n", "\n", "
\n", "
\n", "\n", "
" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: docs/02_deep_dives/L5_Collector.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "M98bqxdMsTXK" }, "source": [ "# Collector\n", "\n", "The Collector serves as the orchestration layer between the policy (agent) and the environment in Tianshou's architecture. It manages the interaction loop, persists collected experiences to a replay buffer, and computes episode-level statistics. This module is fundamental to both training data collection and policy evaluation workflows." ] }, { "cell_type": "markdown", "metadata": { "id": "OX5cayLv4Ziu" }, "source": [ "## Core Applications\n", "\n", "The Collector supports two primary use cases in reinforcement learning experiments:\n", "1. **Training**: Collecting interaction data for policy optimization\n", "2. **Evaluation**: Assessing policy performance without learning" ] }, { "cell_type": "markdown", "metadata": { "id": "Z6XKbj28u8Ze" }, "source": [ "### Policy Evaluation\n", "\n", "Periodic policy evaluation is essential in deep reinforcement learning (DRL) experiments to monitor training progress and assess generalization. The Collector provides a standardized interface for this purpose.\n", "\n", "**Setup**: A Collector requires two components:\n", "1. An environment (or vectorized environment for parallelization)\n", "2. 
A policy instance to evaluate" ] }, { "cell_type": "code", "metadata": { "editable": true, "id": "w8t9ubO7u69J", "slideshow": { "slide_type": "" }, "tags": [ "hide-cell", "remove-output" ], "ExecuteTime": { "end_time": "2025-10-26T21:59:25.914405Z", "start_time": "2025-10-26T21:59:22.196044Z" } }, "source": [ "import gymnasium as gym\n", "import torch\n", "\n", "from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy" ], "outputs": [], "execution_count": 1 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-10-26T21:59:30.621207Z", "start_time": "2025-10-26T21:59:25.922401Z" } }, "source": [ "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import DiscreteActor\n", "\n", "# Initialize single environment for configuration\n", "env = gym.make(\"CartPole-v1\")\n", "\n", "# Create vectorized test environments (2 parallel environments)\n", "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", "\n", "# Configure neural network architecture\n", "assert env.observation_space.shape is not None # for mypy\n", "preprocess_net = Net(\n", " state_shape=env.observation_space.shape,\n", " hidden_sizes=[\n", " 16,\n", " ],\n", ")\n", "\n", "# Initialize discrete action actor network\n", "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", "actor = DiscreteActor(preprocess_net=preprocess_net, action_shape=env.action_space.n)\n", "\n", "# Create policy with categorical action distribution\n", "policy = ProbabilisticActorPolicy(\n", " actor=actor,\n", " dist_fn=torch.distributions.Categorical,\n", " action_space=env.action_space,\n", " action_scaling=False,\n", ")\n", "\n", "# Initialize collector for evaluation\n", "test_collector = Collector[CollectStats](policy, test_envs)" ], "outputs": [], "execution_count": 2 }, { "cell_type": 
"markdown", "metadata": { "id": "wmt8vuwpzQdR" }, "source": [ "### Evaluating Untrained Policy Performance\n", "\n", "We now evaluate the randomly initialized policy across 9 episodes to establish a baseline performance metric:" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9SuT6MClyjyH", "outputId": "1e48f13b-c1fe-4fc2-ca1b-669485efdcae", "ExecuteTime": { "end_time": "2025-10-26T21:59:31.362074Z", "start_time": "2025-10-26T21:59:30.752198Z" } }, "source": [ "# Collect 9 complete episodes with environment reset\n", "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n", "\n", "collect_result.pprint_asdict()" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CollectStats\n", "----------------------------------------\n", "{ 'collect_speed': 288.36823267420584,\n", " 'collect_time': 0.5860562324523926,\n", " 'lens': array([15, 22, 29, 22, 8, 16, 28, 10, 19]),\n", " 'lens_stat': { 'max': 29.0,\n", " 'mean': 18.77777777777778,\n", " 'min': 8.0,\n", " 'std': 6.876332643007022},\n", " 'n_collected_episodes': 9,\n", " 'n_collected_steps': 169,\n", " 'pred_dist_std_array': array([[0.49482444],\n", " [0.49513358],\n", " [0.491721 ],\n", " [0.49804375],\n", " [0.48936436],\n", " [0.49519676],\n", " [0.49186328],\n", " [0.4981152 ],\n", " [0.49512368],\n", " [0.49527684],\n", " [0.49800068],\n", " [0.4982014 ],\n", " [0.49516457],\n", " [0.4953748 ],\n", " [0.49805665],\n", " [0.49269873],\n", " [0.49946678],\n", " [0.49553478],\n", " [0.49997097],\n", " [0.49289048],\n", " [0.4998387 ],\n", " [0.4957225 ],\n", " [0.49908724],\n", " [0.4930759 ],\n", " [0.49776313],\n", " [0.49590224],\n", " [0.49913287],\n", " [0.4986142 ],\n", " [0.4998805 ],\n", " [0.49605882],\n", " [0.49581137],\n", " [0.49861884],\n", " [0.4922438 ],\n", " [0.4962572 ],\n", " [0.49569792],\n", " [0.49863097],\n", " [0.4982123 ],\n", " [0.49961406],\n", " [0.49553847],\n", " [0.4999985 ],\n", " 
[0.49808985],\n", " [0.4997094 ],\n", " [0.49964666],\n", " [0.4987858 ],\n", " [0.49796286],\n", " [0.4948797 ],\n", " [0.49960598],\n", " [0.4916098 ],\n", " [0.4999896 ],\n", " [0.49003887],\n", " [0.4997966 ],\n", " [0.48927104],\n", " [0.4999768 ],\n", " [0.4899478 ],\n", " [0.49948972],\n", " [0.49140957],\n", " [0.4978501 ],\n", " [0.49466696],\n", " [0.49509352],\n", " [0.49118617],\n", " [0.49797186],\n", " [0.49447665],\n", " [0.49950802],\n", " [0.49740306],\n", " [0.498081 ],\n", " [0.49935713],\n", " [0.49534237],\n", " [0.49994358],\n", " [0.49823537],\n", " [0.499905 ],\n", " [0.4955164 ],\n", " [0.49991024],\n", " [0.49839276],\n", " [0.4999328 ],\n", " [0.49570385],\n", " [0.4993451 ],\n", " [0.49855497],\n", " [0.49995714],\n", " [0.4995841 ],\n", " [0.49939576],\n", " [0.49999622],\n", " [0.49824795],\n", " [0.49972966],\n", " [0.49653304],\n", " [0.49880832],\n", " [0.49425703],\n", " [0.49974373],\n", " [0.49662435],\n", " [0.49473634],\n", " [0.49472553],\n", " [0.49773476],\n", " [0.49140546],\n", " [0.49935693],\n", " [0.48954245],\n", " [0.4999408 ],\n", " [0.491403 ],\n", " [0.49988943],\n", " [0.49483237],\n", " [0.49920663],\n", " [0.49134007],\n", " [0.49793968],\n", " [0.4894678 ],\n", " [0.49924025],\n", " [0.4912263 ],\n", " [0.4945435 ],\n", " [0.49469063],\n", " [0.49759832],\n", " [0.49754107],\n", " [0.49466005],\n", " [0.49943802],\n", " [0.4977191 ],\n", " [0.49995443],\n", " [0.49479046],\n", " [0.49937534],\n", " [0.49785116],\n", " [0.49731255],\n", " [0.49934685],\n", " [0.4993554 ],\n", " [0.49798217],\n", " [0.4999266 ],\n", " [0.4993439 ],\n", " [0.49931702],\n", " [0.49815634],\n", " [0.49991363],\n", " [0.4993506 ],\n", " [0.49928144],\n", " [0.49821213],\n", " [0.4973895 ],\n", " [0.49938264],\n", " [0.4992856 ],\n", " [0.4999623 ],\n", " [0.49991205],\n", " [0.49940434],\n", " [0.49991933],\n", " [0.49825713],\n", " [0.49990463],\n", " [0.49554875],\n", " [0.49924377],\n", " [0.49196848],\n", " [0.49991465],\n", " 
[0.48965713],\n", " [0.49991086],\n", " [0.4888782 ],\n", " [0.49921995],\n", " [0.48808664],\n", " [0.49516302],\n", " [0.48725367],\n", " [0.49179506],\n", " [0.4879356 ],\n", " [0.4952572 ],\n", " [0.48861024],\n", " [0.49187768],\n", " [0.48927858],\n", " [0.4953288 ],\n", " [0.48839873],\n", " [0.49193203],\n", " [0.49538046],\n", " [0.49808696],\n", " [0.49537748],\n", " [0.49810043],\n", " [0.4953903 ],\n", " [0.4981276 ],\n", " [0.49956635],\n", " [0.49998853],\n", " [0.49978945],\n", " [0.49897715],\n", " [0.4975953 ],\n", " [0.49903452],\n", " [0.49765074]], dtype=float32),\n", " 'pred_dist_std_array_stat': { 0: { 'max': 0.4999985098838806,\n", " 'mean': 0.4965951144695282,\n", " 'min': 0.48725366592407227,\n", " 'std': 0.003376598935574293}},\n", " 'returns': array([15., 22., 29., 22., 8., 16., 28., 10., 19.]),\n", " 'returns_stat': { 'max': 29.0,\n", " 'mean': 18.77777777777778,\n", " 'min': 8.0,\n", " 'std': 6.876332643007022}}\n" ] } ], "execution_count": 3 }, { "cell_type": "markdown", "metadata": { "id": "zX9AQY0M0R3C" }, "source": [ "### Baseline Comparison: Random Policy\n", "\n", "To contextualize the initialized policy's performance, we establish a random action baseline:" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UEcs8P8P0RLt", "outputId": "85f02f9d-b79b-48b2-99c6-36a1602f0884", "ExecuteTime": { "end_time": "2025-10-26T21:59:31.431099Z", "start_time": "2025-10-26T21:59:31.371074Z" } }, "source": [ "# Evaluate random policy by sampling actions uniformly from action space\n", "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n", "\n", "collect_result.pprint_asdict()" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CollectStats\n", "----------------------------------------\n", "{ 'collect_speed': 4407.5322624798,\n", " 'collect_time': 0.053998470306396484,\n", " 'lens': array([11, 13, 15, 29, 15, 12, 15, 30, 98]),\n", " 
'lens_stat': { 'max': 98.0,\n", " 'mean': 26.444444444444443,\n", " 'min': 11.0,\n", " 'std': 26.16236105175657},\n", " 'n_collected_episodes': 9,\n", " 'n_collected_steps': 238,\n", " 'pred_dist_std_array': None,\n", " 'pred_dist_std_array_stat': None,\n", " 'returns': array([11., 13., 15., 29., 15., 12., 15., 30., 98.]),\n", " 'returns_stat': { 'max': 98.0,\n", " 'mean': 26.444444444444443,\n", " 'min': 11.0,\n", " 'std': 26.16236105175657}}\n" ] } ], "execution_count": 4 }, { "cell_type": "markdown", "metadata": { "id": "sKQRTiG10ljU" }, "source": [ "**Observation**: The randomly initialized policy performs comparably to (or worse than) uniform random actions prior to training. This is expected behavior, as the network weights lack task-specific optimization." ] }, { "cell_type": "markdown", "metadata": { "id": "8RKmHIoG1A1k" }, "source": [ "### Training Data Collection\n", "\n", "During the training phase, the Collector manages experience gathering and automatic storage in a replay buffer. This enables the experience replay mechanism fundamental to off-policy algorithms." 
] }, { "cell_type": "code", "metadata": { "editable": true, "id": "CB9XB9bF1YPC", "slideshow": { "slide_type": "" }, "tags": [], "ExecuteTime": { "end_time": "2025-10-26T21:59:31.452144Z", "start_time": "2025-10-26T21:59:31.444096Z" } }, "source": [ "# Configuration for parallel training data collection\n", "train_env_num = 4\n", "buffer_size = 100\n", "\n", "# Initialize vectorized training environments\n", "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n", "\n", "# Create replay buffer compatible with vectorized environments\n", "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", "\n", "# Initialize training collector with buffer integration\n", "training_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" ], "outputs": [], "execution_count": 5 }, { "cell_type": "markdown", "metadata": { "id": "rWKDazA42IUQ" }, "source": [ "### Step-Based Collection\n", "\n", "The Collector supports both step-based and episode-based collection modes. Here we demonstrate step-based collection, which is commonly used in training loops with fixed update frequencies.\n", "\n", "**Note**: When using vectorized environments, the actual number of collected steps may exceed the requested amount to maintain synchronization across parallel environments." 
] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-fUtQOnM2Yi1", "outputId": "dceee987-433e-4b75-ed9e-823c20a9e1c2", "ExecuteTime": { "end_time": "2025-10-26T21:59:31.501487Z", "start_time": "2025-10-26T21:59:31.459140Z" } }, "source": [ "# Reset collector and buffer to clean state\n", "training_collector.reset()\n", "replayBuffer.reset()\n", "\n", "print(f\"Replay buffer before collecting is empty, and has length={len(replayBuffer)} \\n\")\n", "\n", "# Collect 50 environment steps\n", "n_step = 50\n", "collect_result = training_collector.collect(n_step=n_step)\n", "\n", "print(\n", " f\"Replay buffer after collecting {n_step} steps has length={len(replayBuffer)}.\\n\"\n", " f\"The actual count may exceed n_step when it is not a multiple of train_env_num \\n\"\n", " f\"due to vectorization synchronization requirements.\\n\",\n", ")\n", "\n", "collect_result.pprint_asdict()" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Replay buffer before collecting is empty, and has length=0 \n", "\n", "Replay buffer after collecting 50 steps has length=52.\n", "The actual count may exceed n_step when it is not a multiple of train_env_num \n", "due to vectorization synchronization requirements.\n", "\n", "CollectStats\n", "----------------------------------------\n", "{ 'collect_speed': 1529.5011711244197,\n", " 'collect_time': 0.03399801254272461,\n", " 'lens': array([], dtype=int32),\n", " 'lens_stat': None,\n", " 'n_collected_episodes': 0,\n", " 'n_collected_steps': 52,\n", " 'pred_dist_std_array': array([[0.4944575 ],\n", " [0.49571753],\n", " [0.49482644],\n", " [0.49571693],\n", " [0.49746 ],\n", " [0.49228 ],\n", " [0.491648 ],\n", " [0.49237084],\n", " [0.49931562],\n", " [0.48953396],\n", " [0.4949102 ],\n", " [0.49022076],\n", " [0.49992043],\n", " [0.4921799 ],\n", " [0.49171764],\n", " [0.4894729 ],\n", " [0.4992769 ],\n", " [0.48948848],\n", " [0.49497682],\n", " [0.48870105],\n", " 
[0.49763048],\n", " [0.49201292],\n", " [0.49787888],\n", " [0.4893877 ],\n", " [0.49927947],\n", " [0.4955971 ],\n", " [0.49943653],\n", " [0.49005648],\n", " [0.49780723],\n", " [0.49179533],\n", " [0.49995926],\n", " [0.49153325],\n", " [0.49928913],\n", " [0.48941523],\n", " [0.49986592],\n", " [0.49499276],\n", " [0.4999287 ],\n", " [0.49152908],\n", " [0.4991583 ],\n", " [0.49093276],\n", " [0.4998997 ],\n", " [0.48936346],\n", " [0.4978821 ],\n", " [0.49442956],\n", " [0.49992698],\n", " [0.49117777],\n", " [0.49921465],\n", " [0.49751103],\n", " [0.4992887 ],\n", " [0.4893143 ],\n", " [0.49991187],\n", " [0.4992216 ]], dtype=float32),\n", " 'pred_dist_std_array_stat': { 0: { 'max': 0.499959260225296,\n", " 'mean': 0.49497732520103455,\n", " 'min': 0.4887010455131531,\n", " 'std': 0.003929081838577986}},\n", " 'returns': array([], dtype=float64),\n", " 'returns_stat': None}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "F:\\Users\\Dominik Jain\\Dev\\AI\\tianshou\\tianshou\\data\\collector.py:537: UserWarning: n_step=50 is not a multiple of (self.env_num=4), which may cause extra transitions being collected into the buffer.\n", " warnings.warn(\n" ] } ], "execution_count": 6 }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Buffer Sampling Verification\n", "\n", "Verify that collected experiences are properly stored and can be sampled for training:" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-10-26T21:59:31.517583Z", "start_time": "2025-10-26T21:59:31.509483Z" } }, "source": [ "# Sample mini-batch of 10 transitions from buffer\n", "replayBuffer.sample(10)" ], "outputs": [ { "data": { "text/plain": [ "(Batch(\n", " obs: array([[-7.59119692e-04, -3.54404569e-01, 8.15278068e-02,\n", " 6.34967446e-01],\n", " [ 2.03953441e-02, -5.46947002e-01, 4.59121428e-02,\n", " 8.69558692e-01],\n", " [-5.53812869e-02, -3.63834441e-01, 1.84285983e-01,\n", " 8.54350269e-01],\n", " [ 5.94463721e-02, -3.39802876e-02, 
-5.61027192e-02,\n", " -2.05838066e-02],\n", " [ 1.70439295e-02, -3.58715117e-01, 2.22064722e-02,\n", " 6.39448643e-01],\n", " [ 1.51256351e-02, 2.27344140e-01, 1.95531528e-02,\n", " -2.54039675e-01],\n", " [-7.69001395e-02, -7.54580617e-01, 1.79230303e-01,\n", " 1.36748278e+00],\n", " [-3.51171643e-02, -1.14145672e+00, 1.09657384e-01,\n", " 1.86768615e+00],\n", " [ 2.10114848e-02, 3.47817928e-01, -1.05900057e-01,\n", " -6.93330288e-01],\n", " [-1.53460149e-02, 5.40259123e-01, -4.36654910e-02,\n", " -9.24050748e-01]], dtype=float32),\n", " act: array([1, 1, 0, 1, 0, 0, 1, 1, 0, 1], dtype=int64),\n", " rew: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n", " terminated: array([False, False, False, False, False, False, False, False, False,\n", " False]),\n", " truncated: array([False, False, False, False, False, False, False, False, False,\n", " False]),\n", " done: array([False, False, False, False, False, False, False, False, False,\n", " False]),\n", " obs_next: array([[-0.00784721, -0.16050874, 0.09422715, 0.36903235],\n", " [ 0.0094564 , -0.3524787 , 0.06330331, 0.59165704],\n", " [-0.06265797, -0.5609251 , 0.20137298, 1.1988543 ],\n", " [ 0.05876676, 0.16189948, -0.05651439, -0.33042672],\n", " [ 0.00986963, -0.5541395 , 0.03499544, 0.93904114],\n", " [ 0.01967252, 0.03194853, 0.01447236, 0.04474595],\n", " [-0.09199175, -0.5620968 , 0.20657995, 1.135794 ],\n", " [-0.05794629, -0.94769216, 0.1470111 , 1.6109598 ],\n", " [ 0.02796784, 0.15431203, -0.11976666, -0.435774 ],\n", " [-0.00454083, 0.73594284, -0.06214651, -1.2301302 ]],\n", " dtype=float32),\n", " info: Batch(\n", " env_id: array([0, 0, 0, 1, 2, 2, 2, 2, 3, 3]),\n", " ),\n", " policy: Batch(),\n", " ),\n", " array([ 6, 3, 12, 31, 56, 53, 62, 60, 81, 78]))" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 7 }, { "cell_type": "markdown", "metadata": { "id": "8NP7lOBU3-VS" }, "source": [ "## Advanced Topics\n", "\n", "### Asynchronous Collection\n", 
"\n", "The standard `Collector` implementation may collect more steps than requested when using vectorized environments. In the example above, requesting 50 steps resulted in 52 steps (the smallest multiple of 4 that is ≥50).\n", "\n", "For scenarios requiring precise step counts, Tianshou provides the `AsyncCollector`, which enables exact step collection at the cost of additional implementation complexity. This is particularly relevant for:\n", "- Strict reproducibility requirements\n", "- Algorithms sensitive to exact batch sizes\n", "- Fine-grained control over data collection\n", "\n", "Consult the [AsyncCollector documentation](https://tianshou.org/en/master/03_api/data/collector.html#tianshou.data.collector.AsyncCollector) for implementation details and usage patterns." ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: docs/02_deep_dives/L6_MARL.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multi-Agent Reinforcement Learning (MARL)\n", "\n", "This tutorial demonstrates how to use Tianshou for multi-agent reinforcement learning scenarios. We'll explore different MARL paradigms and implement a practical example using the Tic-Tac-Toe game." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MARL Paradigms\n", "\n", "Tianshou supports three fundamental types of multi-agent reinforcement learning paradigms:\n", "\n", "1. **Simultaneous move**: All agents take their actions at each timestep simultaneously (e.g., MOBA games)\n", "2. 
**Cyclic move**: Agents take actions sequentially in turns (e.g., Go)\n", "3. **Conditional move**: The environment conditionally selects which agent acts at each timestep (e.g., [Pig Game](https://en.wikipedia.org/wiki/Pig_(dice_game)))\n", "\n", "Our approach addresses these multi-agent RL problems by converting them into traditional single-agent RL formulations." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Converting MARL to Single-Agent RL\n", "\n", "### Simultaneous Move\n", "\n", "For simultaneous-move scenarios, the solution is straightforward: we add an extra `num_agents` dimension to the state, action, and reward tensors. No other modifications are necessary.\n", "\n", "### Cyclic and Conditional Move\n", "\n", "Both cyclic and conditional move scenarios can be unified into a single framework. At each timestep, the environment selects an agent identified by `agent_id` to act. Since multiple agents are typically wrapped into a single object (the \"abstract agent\"), we pass the `agent_id` to this abstract agent, which then delegates the action to the appropriate specific agent.\n", "\n", "Additionally, in multi-agent RL, the set of legal actions often varies across timesteps (as in Go). Therefore, the environment must also provide a legal action mask to the abstract agent. This mask is a boolean array where `True` indicates available actions and `False` indicates illegal actions at the current timestep.\n", "\n", "
\n", "
\n", "The abstract agent framework for multi-agent RL\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Unified Formulation\n", "\n", "This architecture leads to the following formulation of multi-agent RL:\n", "\n", "```python\n", "act = policy(state, agent_id, mask)\n", "(next_state, next_agent_id, next_mask), reward = env.step(act)\n", "```\n", "\n", "By constructing an augmented state `state_ = (state, agent_id, mask)`, we can reduce this to the standard single-agent RL formulation:\n", "\n", "```python\n", "act = policy(state_)\n", "next_state_, reward = env.step(act)\n", "```\n", "\n", "Following this principle, we'll implement a Q-learning algorithm to play [Tic-Tac-Toe](https://en.wikipedia.org/wiki/Tic-tac-toe) against a random opponent." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PettingZoo Integration\n", "\n", "Tianshou is fully compatible with [PettingZoo](https://pettingzoo.farama.org/) environments for multi-agent RL. While Tianshou doesn't directly provide specialized MARL facilities, it offers a flexible framework that can be adapted to various MARL scenarios.\n", "\n", "For comprehensive tutorials on using Tianshou with PettingZoo, refer to:\n", "\n", "* [Beginner Tutorial](https://pettingzoo.farama.org/tutorials/tianshou/beginner/)\n", "* [Intermediate Tutorial](https://pettingzoo.farama.org/tutorials/tianshou/intermediate/)\n", "* [Advanced Tutorial](https://pettingzoo.farama.org/tutorials/tianshou/advanced/)\n", "\n", "In this tutorial, we'll demonstrate how to use Tianshou in a multi-agent setting where only one agent is trained while the other uses a fixed random policy. You can then use this as a blueprint to replace the random policy with another trainable agent.\n", "\n", "Specifically, we'll train an agent to play Tic-Tac-Toe against a random opponent:\n", "\n", "
\n", "
\n", "Tic-Tac-Toe game board\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exploring the Tic-Tac-Toe Environment\n", "\n", "The complete scripts are located in `test/pettingzoo/`. Tianshou provides the `PettingZooEnv` wrapper class that can wrap any PettingZoo environment. Let's explore the 3×3 Tic-Tac-Toe environment provided by PettingZoo." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pettingzoo.classic import tictactoe_v3 # the Tic-Tac-Toe environment\n", "\n", "from tianshou.env import PettingZooEnv # wrapper for PettingZoo environments\n", "\n", "# Initialize the environment\n", "# The board has 3 rows and 3 columns (9 positions total)\n", "# Players place 'X' and 'O' alternately on the board\n", "# The first player to get 3 consecutive marks wins\n", "env = PettingZooEnv(tictactoe_v3.env(render_mode=\"human\"))\n", "obs = env.reset()\n", "env.render() # render the empty board" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output shows an empty 3×3 board:\n", "\n", "```\n", "board (step 0):\n", " | |\n", " - | - | -\n", "_____|_____|_____\n", " | |\n", " - | - | -\n", "_____|_____|_____\n", " | |\n", " - | - | -\n", " | |\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Examine the observation structure\n", "print(obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Understanding the Observation Space\n", "\n", "The observation returned by the environment is a dictionary with three keys:\n", "\n", "- **`agent_id`**: The identifier of the currently acting agent (e.g., `'player_1'` or `'player_2'`)\n", "\n", "- **`obs`**: The actual environment observation. 
For Tic-Tac-Toe, this is a numpy array with shape `(3, 3, 2)`:\n", " - For `player_1`: The first 3×3 plane represents X placements, the second plane represents O placements\n", " - For `player_2`: The planes are swapped (O in first plane, X in second)\n", " - Each cell contains either 0 (empty/not placed) or 1 (mark placed)\n", "\n", "- **`mask`**: A boolean array indicating legal actions at the current timestep. For Tic-Tac-Toe, PettingZoo numbers the cells column by column (indices 0, 1, 2 run down the left column), so index `i` corresponds to row `i % 3` and column `i // 3` on the board. If `mask[i] == True`, the player can place their mark at that position. Initially, all positions are available, so all mask values are `True`.\n", "\n", "> **Note**: The mask representation is flexible and works for both discrete and continuous action spaces. While we use a boolean array here, you could also use action spaces like `gymnasium.spaces.Discrete` or `gymnasium.spaces.Box` to represent available actions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Playing a Few Steps\n", "\n", "Let's play a couple of moves to understand the environment dynamics better."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Take an action (place mark at position 0 - top-left corner)\n", "action = 0 # action can be an integer or a numpy array with one element\n", "obs, reward, done, truncated, info = env.step(action) # follows the Gymnasium API\n", "\n", "print(\"Observation after first move:\")\n", "print(obs)\n", "\n", "# Examine the reward structure\n", "# Reward has two items (one for each player): 1 for win, -1 for loss, 0 otherwise\n", "print(f\"\\nReward: {reward}\")\n", "\n", "# Check if the game is over\n", "print(f\"Done: {done}\")\n", "\n", "# Info is typically an empty dict in Tic-Tac-Toe but may contain useful information in other environments\n", "print(f\"Info: {info}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that after the first move:\n", "- The `agent_id` switches to `'player_2'`\n", "- The observation array shows the X placement in the first position\n", "- The mask now has `False` at index 0 (that position is occupied)\n", "- The reward is `[0, 0]` (no winner yet)\n", "- The game continues (`done = False`)" ] }, { "cell_type": "markdown", "metadata": {}, "source": "Note: If we continue playing, the game terminates when only one empty position remains, rather than when the board is completely full. This is because a player with only one available position has no meaningful choice." }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random Agents\n", "\n", "Now that we understand the environment, let's start by watching two random agents play against each other.\n", "\n", "Tianshou provides built-in classes for multi-agent learning. 
The key components are:\n", "\n", "- **`MARLRandomDiscreteMaskedOffPolicyAlgorithm`**: An algorithm whose policy randomly selects one of the currently legal (unmasked) actions\n", "- **`MultiAgentOffPolicyAlgorithm`**: Manages multiple agent algorithms and delegates actions to the appropriate agent based on `agent_id`\n", "\n", "
\n", "
\n", "The relationship between MultiAgentPolicyManager and individual agent policies\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tianshou.algorithm.multiagent.marl import MultiAgentOffPolicyAlgorithm\n", "from tianshou.algorithm.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm\n", "from tianshou.data import Collector\n", "from tianshou.env import DummyVectorEnv\n", "\n", "# Create a multi-agent algorithm with two random agents\n", "policy = MultiAgentOffPolicyAlgorithm(\n", " algorithms=[\n", " MARLRandomDiscreteMaskedOffPolicyAlgorithm(action_space=env.action_space),\n", " MARLRandomDiscreteMaskedOffPolicyAlgorithm(action_space=env.action_space),\n", " ],\n", " env=env,\n", ")\n", "\n", "# Vectorize the environment for the collector\n", "env = DummyVectorEnv([lambda: env])\n", "\n", "# Create a collector to gather trajectories\n", "collector = Collector(policy, env)\n", "\n", "# Collect and visualize one episode\n", "result = collector.collect(n_episode=1, render=0.1, reset_before_collect=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You'll see the game progress step by step. Here's an example of the final moves:\n", "\n", "```\n", " | |\n", " X | X | -\n", "_____|_____|_____\n", " | |\n", " X | O | -\n", "_____|_____|_____\n", " | |\n", " O | - | -\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " X | X | -\n", "_____|_____|_____\n", " | |\n", " X | O | -\n", "_____|_____|_____\n", " | |\n", " O | - | O\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " X | X | X\n", "_____|_____|_____\n", " | |\n", " X | O | -\n", "_____|_____|_____\n", " | |\n", " O | - | O\n", " | |\n", "```\n", "\n", "Random agents perform poorly. In the game above, although agent 2 eventually wins, a smart agent 1 would have won immediately by placing an X at position (1, 1) (center of middle row)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training an Agent Against a Random Opponent\n", "\n", "Now let's train an intelligent agent! 
We'll use Deep Q-Network (DQN) to learn optimal play against a random opponent." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Imports and Setup\n", "\n", "First, let's import all necessary modules:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "from copy import deepcopy\n", "from functools import partial\n", "\n", "import gymnasium\n", "import torch\n", "from pettingzoo.classic import tictactoe_v3\n", "from torch.utils.tensorboard import SummaryWriter\n", "\n", "from tianshou.algorithm import (\n", " DQN,\n", " Algorithm,\n", " MARLRandomDiscreteMaskedOffPolicyAlgorithm,\n", " MultiAgentOffPolicyAlgorithm,\n", ")\n", "from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm\n", "from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy\n", "from tianshou.algorithm.optim import AdamOptimizerFactory, OptimizerFactory\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.data.stats import InfoStats\n", "from tianshou.env import DummyVectorEnv\n", "from tianshou.env.pettingzoo_env import PettingZooEnv\n", "from tianshou.trainer import OffPolicyTrainerParams\n", "from tianshou.utils import TensorboardLogger\n", "from tianshou.utils.net.common import Net" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Hyperparameters\n", "\n", "Let's define the hyperparameters for our training experiment directly (no argparse needed in notebooks!):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define hyperparameters\n", "class Args:\n", " seed = 1626\n", " eps_test = 0.05\n", " eps_train = 0.1\n", " buffer_size = 20000\n", " lr = 1e-4\n", " gamma = 0.9 # A smaller gamma favors earlier wins\n", " n_step = 3\n", " target_update_freq = 320\n", " epoch = 50\n", " epoch_num_steps = 1000\n", " collection_step_num_env_steps = 10\n", " update_per_step = 0.1\n", " batch_size = 64\n", 
" hidden_sizes = [128, 128, 128, 128] # noqa: RUF012\n", " num_train_envs = 10\n", " num_test_envs = 10\n", " logdir = \"log\"\n", " render = 0.1\n", " win_rate = 0.6 # Target winning rate (optimal policy can get ~0.7)\n", " watch = False # Set to True to skip training and watch pre-trained models\n", " agent_id = 2 # The learned agent plays as player 2\n", " resume_path = \"\" # Path to pre-trained agent .pth file\n", " opponent_path = \"\" # Path to pre-trained opponent .pth file\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " model_save_path = None # Will be set in save_best_fn\n", "\n", "\n", "args = Args()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Agent Setup\n", "\n", "The `get_agents` function creates and configures our agents:\n", "\n", "- **Neural Network**: We use `Net`, a multi-layer perceptron with ReLU activations\n", "- **Learning Algorithm**: A `DiscreteQLearningPolicy` combined with `DQN` for Q-learning updates\n", "- **Opponent**: Either a `MARLRandomDiscreteMaskedOffPolicyAlgorithm` that randomly chooses legal actions, or a pre-trained agent for self-play\n", "\n", "Both agents are managed by `MultiAgentOffPolicyAlgorithm`, which:\n", "- Calls the correct agent based on `agent_id` in the observation\n", "- Dispatches data to each agent according to their `agent_id`\n", "- Makes each agent perceive the environment as a single-agent problem\n", "\n", "
\n", "
\n", "How MultiAgentOffPolicyAlgorithm coordinates agent algorithms\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_env(render_mode: str | None = None) -> PettingZooEnv:\n", " return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))\n", "\n", "\n", "def get_agents(\n", " args,\n", " agent_learn: OffPolicyAlgorithm | None = None,\n", " agent_opponent: OffPolicyAlgorithm | None = None,\n", " optim: OptimizerFactory | None = None,\n", ") -> tuple[MultiAgentOffPolicyAlgorithm, torch.optim.Optimizer | None, list]:\n", " \"\"\"Create or load agents for training.\"\"\"\n", " env = get_env()\n", " observation_space = (\n", " env.observation_space.spaces[\"observation\"]\n", " if isinstance(env.observation_space, gymnasium.spaces.Dict)\n", " else env.observation_space\n", " )\n", " args.state_shape = observation_space.shape or int(observation_space.n)\n", " args.action_shape = env.action_space.shape or int(env.action_space.n)\n", "\n", " if agent_learn is None:\n", " # Create the neural network model\n", " net = Net(\n", " state_shape=args.state_shape,\n", " action_shape=args.action_shape,\n", " hidden_sizes=args.hidden_sizes,\n", " ).to(args.device)\n", "\n", " if optim is None:\n", " optim = AdamOptimizerFactory(lr=args.lr)\n", "\n", " # Create Q-learning policy for the learning agent\n", " algorithm = DiscreteQLearningPolicy(\n", " model=net,\n", " action_space=env.action_space,\n", " eps_training=args.eps_train,\n", " eps_inference=args.eps_test,\n", " )\n", "\n", " # Wrap in DQN algorithm\n", " agent_learn = DQN(\n", " policy=algorithm,\n", " optim=optim,\n", " n_step_return_horizon=args.n_step,\n", " gamma=args.gamma,\n", " target_update_freq=args.target_update_freq,\n", " )\n", "\n", " if args.resume_path:\n", " agent_learn.load_state_dict(torch.load(args.resume_path))\n", "\n", " if agent_opponent is None:\n", " if args.opponent_path:\n", " # Load a pre-trained opponent for self-play\n", " agent_opponent = deepcopy(agent_learn)\n", " 
agent_opponent.load_state_dict(torch.load(args.opponent_path))\n", " else:\n", " # Use a random opponent\n", " agent_opponent = MARLRandomDiscreteMaskedOffPolicyAlgorithm(\n", " action_space=env.action_space\n", " )\n", "\n", " # Arrange agents based on which player position the learning agent takes\n", " if args.agent_id == 1:\n", " agents = [agent_learn, agent_opponent]\n", " else:\n", " agents = [agent_opponent, agent_learn]\n", "\n", " ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env)\n", " return ma_algorithm, optim, env.agents" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training Loop\n", "\n", "The training procedure follows the standard Tianshou workflow, similar to single-agent DQN training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_agent(\n", " args,\n", " agent_learn: OffPolicyAlgorithm | None = None,\n", " agent_opponent: OffPolicyAlgorithm | None = None,\n", " optim: OptimizerFactory | None = None,\n", ") -> tuple[InfoStats, OffPolicyAlgorithm]:\n", " \"\"\"Train the agent using DQN.\"\"\"\n", " # ======== Environment Setup =========\n", " train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)])\n", " test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)])\n", "\n", " # Set random seeds for reproducibility\n", " np.random.seed(args.seed)\n", " torch.manual_seed(args.seed)\n", " train_envs.seed(args.seed)\n", " test_envs.seed(args.seed)\n", "\n", " # ======== Agent Setup =========\n", " marl_algorithm, optim, agents = get_agents(\n", " args,\n", " agent_learn=agent_learn,\n", " agent_opponent=agent_opponent,\n", " optim=optim,\n", " )\n", "\n", " # ======== Collector Setup =========\n", " training_collector = Collector[CollectStats](\n", " marl_algorithm,\n", " train_envs,\n", " VectorReplayBuffer(args.buffer_size, len(train_envs)),\n", " exploration_noise=True,\n", " )\n", " test_collector = 
Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True)\n", "\n", " # Collect initial random samples\n", " training_collector.reset()\n", " training_collector.collect(n_step=args.batch_size * args.num_train_envs)\n", "\n", " # ======== Logging Setup =========\n", " log_path = os.path.join(args.logdir, \"tic_tac_toe\", \"dqn\")\n", " writer = SummaryWriter(log_path)\n", " writer.add_text(\"args\", str(args))\n", " logger = TensorboardLogger(writer)\n", "\n", " player_agent_id = agents[args.agent_id - 1]\n", "\n", " # ======== Callback Functions =========\n", " def save_best_fn(policy: Algorithm) -> None:\n", " \"\"\"Save the best performing policy.\"\"\"\n", " if hasattr(args, \"model_save_path\") and args.model_save_path:\n", " model_save_path = args.model_save_path\n", " else:\n", " model_save_path = os.path.join(args.logdir, \"tic_tac_toe\", \"dqn\", \"policy.pth\")\n", " torch.save(policy.get_algorithm(player_agent_id).state_dict(), model_save_path)\n", "\n", " def stop_fn(mean_rewards: float) -> bool:\n", " \"\"\"Stop training when target win rate is achieved.\"\"\"\n", " return mean_rewards >= args.win_rate\n", "\n", " def reward_metric(rews: np.ndarray) -> np.ndarray:\n", " \"\"\"Extract the reward for our learning agent.\"\"\"\n", " return rews[:, args.agent_id - 1]\n", "\n", " # ======== Trainer =========\n", " result = marl_algorithm.run_training(\n", " OffPolicyTrainerParams(\n", " training_collector=training_collector,\n", " test_collector=test_collector,\n", " max_epochs=args.epoch,\n", " epoch_num_steps=args.epoch_num_steps,\n", " collection_step_num_env_steps=args.collection_step_num_env_steps,\n", " test_step_num_episodes=args.num_test_envs,\n", " batch_size=args.batch_size,\n", " stop_fn=stop_fn,\n", " save_best_fn=save_best_fn,\n", " update_step_num_gradient_steps_per_sample=args.update_per_step,\n", " logger=logger,\n", " test_in_training=False,\n", " multi_agent_return_reduction=reward_metric,\n", " show_progress=False,\n", " 
)\n", " )\n", "\n", " return result, marl_algorithm.get_algorithm(player_agent_id)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluation Function\n", "\n", "This function allows us to watch a trained agent play:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def watch(\n", " args,\n", " agent_learn: OffPolicyAlgorithm | None = None,\n", " agent_opponent: OffPolicyAlgorithm | None = None,\n", ") -> None:\n", " \"\"\"Watch a pre-trained agent play.\"\"\"\n", " env = DummyVectorEnv([partial(get_env, render_mode=\"human\")])\n", " policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)\n", " collector = Collector[CollectStats](policy, env, exploration_noise=True)\n", " result = collector.collect(n_episode=1, render=args.render, reset_before_collect=True)\n", " result.pprint_asdict()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Running the Training\n", "\n", "Now let's train the agent and watch it play!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Train the agent\n", "result, agent = train_agent(args)\n", "\n", "# Watch the trained agent play\n", "watch(args, agent)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training Results\n", "\n", "After training for less than a minute, you'll see the agent play against the random opponent. Here's an example game:\n", "\n", "
\n", "Example: Trained Agent vs Random Opponent\n", "\n", "```\n", " | |\n", " - | - | -\n", "_____|_____|_____\n", " | |\n", " - | - | X\n", "_____|_____|_____\n", " | |\n", " - | - | -\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " - | - | -\n", "_____|_____|_____\n", " | |\n", " - | O | X\n", "_____|_____|_____\n", " | |\n", " - | - | -\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " - | - | -\n", "_____|_____|_____\n", " | |\n", " X | O | X\n", "_____|_____|_____\n", " | |\n", " - | - | -\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " - | O | -\n", "_____|_____|_____\n", " | |\n", " X | O | X\n", "_____|_____|_____\n", " | |\n", " - | - | -\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " - | O | -\n", "_____|_____|_____\n", " | |\n", " X | O | X\n", "_____|_____|_____\n", " | |\n", " - | X | -\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " O | O | -\n", "_____|_____|_____\n", " | |\n", " X | O | X\n", "_____|_____|_____\n", " | |\n", " - | X | -\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " O | O | X\n", "_____|_____|_____\n", " | |\n", " X | O | X\n", "_____|_____|_____\n", " | |\n", " - | X | -\n", " | |\n", "```\n", "\n", "```\n", " | |\n", " O | O | X\n", "_____|_____|_____\n", " | |\n", " X | O | X\n", "_____|_____|_____\n", " | |\n", " - | X | O\n", " | |\n", "```\n", "\n", "Final reward: 1.0, length: 8.0\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that our trained agent plays as player 2 (O) and wins! The agent has learned the game rules through trial and error, understanding that three consecutive O marks lead to victory." ] }, { "cell_type": "markdown", "metadata": {}, "source": "It is easily possible to make the trained agent play against itself. Try this as an exercise!" }, { "cell_type": "markdown", "metadata": {}, "source": [ "While the trained agent plays well against a random opponent, it's still far from perfect play. The next step would be to implement self-play training, similar to AlphaZero, where the agent continuously improves by playing against increasingly stronger versions of itself." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "In this tutorial, we demonstrated how to use Tianshou for training a single agent in a multi-agent reinforcement learning setting. Key takeaways:\n", "\n", "1. **MARL Paradigms**: Tianshou supports simultaneous, cyclic, and conditional move scenarios\n", "2. **Abstraction**: Multi-agent problems can be converted to single-agent RL through clever state augmentation\n", "3. **PettingZoo Integration**: Seamless compatibility with PettingZoo environments via `PettingZooEnv`\n", "4. **Algorithm Management**: `MultiAgentOffPolicyAlgorithm` handles agent coordination and data distribution\n", "5. **Flexible Framework**: Easy to extend from single-agent training to more complex multi-agent scenarios\n", "\n", "Tianshou provides a flexible and intuitive framework for reinforcement learning. Experiment with different architectures, training regimes, and opponent strategies to build even more capable agents!" 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: docs/04_benchmarks/benchmarks.rst ================================================ Benchmarks ========== Tianshou's algorithm implementations achieve state-of-the-art results on standard benchmarks. An efficient parallel implementation for evaluating algorithms on MuJoCo or Atari is provided in ``benchmark/run_benchmark.py``; it can easily be adapted for custom benchmarks as well. Since the benchmarking code is part of the repository, the reported results are completely reproducible. The evaluation code uses Tianshou's integration with the `rliable `_ framework, which supports best practices for trustworthy RL evaluation. Each experiment is conducted with 5 random seeds; we report the interquartile mean (IQM) and 95% confidence intervals over these seeds. .. raw:: html


================================================ FILE: docs/05_developer_guide/developer_guide.md ================================================ # Developer Guide This section addresses developers of Tianshou, providing information for casual contributors and maintainers alike. ## Python Virtual Environment Tianshou is built and managed by [poetry](https://python-poetry.org/). The development environment uses Python 3.11. To install all relevant requirements (as well as Tianshou itself in editable mode), you can simply call poetry install --with dev ```{important} Depending on your setup, you may need to create and activate an empty virtual environment using the right Python version beforehand. For instance, to do this with conda, use: conda create -n tianshou python=3.11 conda activate tianshou ``` ## Code Style and Auto-Formatting When editing code in Tianshou, strive for **local consistency**, i.e. adhere to the style already present in the codebase. Tianshou uses auto-formatting for consistency. To apply it, call poe format To check whether your formatting is compliant without applying the auto-formatter, call poe lint ## Type Checking We use [mypy](https://github.com/python/mypy/) to perform static type analysis. To check typing, run poe type-check ## Tests ### Running the Test Suite Locally Tianshou uses pytest. Tests are located in `./test`. To run the full set of tests locally, run poe test ### Determinism Tests We implemented **determinism tests** for Tianshou's algorithms, which allow us to determine whether algorithms still compute exactly the same results even after large refactorings. These tests are applied by 1. creating a behavior snapshot in the old code branch before the changes and then 2. running the respective determinism test in the new branch to ensure that the behavior is the same.
Unfortunately, full determinism is difficult to achieve across different platforms and even across different machines using the same platform and Python environment. Therefore, these tests are not carried out in the CI pipeline. Instead, it is up to the developer to run them locally and check the results whenever a change is made to the codebase that could affect algorithm behavior. Technically, the two steps are handled by setting static flags in class `AlgorithmDeterminismTest` and then running either the full test suite or a specific determinism test (`test_*_determinism`, e.g. `test_ddpg_determinism`) in the two branches to be compared. 1. On the old branch: (Temporarily) set `ENABLED=True` and `FORCE_SNAPSHOT_UPDATE=True` and run the test(s). 2. On the new branch: (Temporarily) set `ENABLED=True` and `FORCE_SNAPSHOT_UPDATE=False` and run the test(s). 3. Inspect the test results; find a summary in `determinism_tests.log` ### Tests in CI (GitHub Actions) CI tests will extensively test Tianshou's functionality in multiple environments. In particular, we test * on Ubuntu (full functionality tested) * **py_pinned**: using the pinned development environment (Python 3.11, known versions of all dependencies) * **py_latest**: using a more recent Python version with the newest set of compatible dependencies (automatically resolved) * on Windows and macOS (core functionality tested) #### Principle of Maximum Compatibility The idea behind testing with dynamically resolved dependencies is that we want to maximize the applicability of Tianshou: For important dependencies that could conflict with environments used by our users, **we do not restrict the version of a dependency unless there is a known incompatibility.** If incompatibilities should arise (e.g.
by the "py_latest" test failing), we either * resolve them by making the code compatible with both old and new versions OR * add an upper bound to our dependency declarations (excluding the incompatible versions) and release a new version of Tianshou to make these exclusions explicit. ## High-Level API The high-level API provides a declarative, user-friendly interface for setting up reinforcement learning experiments. From a library developer's perspective, it is important that this API be clearly structured and maintainable. This section explains the architectural principles and how to extend the API to support new algorithms. ### Core Abstractions The high-level API is built around a clear separation of concerns: **Parameter Classes** are dataclasses (inheriting from `Params`) that represent algorithm-specific configuration. They capture hyperparameters in a high-level, user-friendly form. Because the high-level interface must abstract away from low-level details, parameters may need transformation before being passed to policy classes. This is handled via `ParamTransformer` instances, which successively transform the parameter dictionary representation. To maintain clarity and reduce coupling, parameter transformers are co-located with the parameters they affect. The system uses inheritance and mixins extensively to reduce duplication while maintaining flexibility. **Factories** embody the principle of declarative configuration. Because object creation may depend on other objects that don't yet exist at configuration time (e.g., neural networks depend on environment properties), the API transitions from objects to factories. 
Key factory types include: - `EnvFactory` for creating training, test, and watch environments - `AgentFactory` as the central factory that creates policies, trainers, and collectors - Various specialized factories for optimizers, actors, critics, noise, distributions, learning rate schedulers, and policy wrappers **Algorithm Factories** (subclasses of `AlgorithmFactory`) are the core components responsible for orchestrating the creation of all algorithm-specific objects. They handle the creation of neural network architectures, apply parameter transformations, instantiate policies, and create trainers with appropriate collectors. To support a new algorithm, this is the primary extension point. **Experiment Builders** (subclasses of `ExperimentBuilder`) provide the user-facing interface following the builder pattern. They contain sensible defaults while allowing customization through fluent `with_*` methods. Builder mixins provide composable functionality for common patterns (e.g., actor/critic configuration), avoiding code duplication across algorithm-specific implementations. ### Supporting a New Algorithm Extending the high-level API to support a new algorithm involves creating three main components: **Parameter Class**: Define a dataclass in `tianshou/highlevel/params/algorithm_params.py` that inherits from appropriate base classes and mixins. The choice of base class depends on the algorithm's architecture (actor-critic, single network, etc.) and learning paradigm (on-policy, off-policy). Override `_get_param_transformers()` to specify how high-level parameters should be transformed for the low-level policy API. Common transformers handle optimizer creation, noise instantiation, and environment-dependent parameter resolution. **Algorithm Factory**: Implement a subclass of `AlgorithmFactory` in `tianshou/highlevel/algorithm.py`. 
In most cases, inherit from existing base factories like `ActorCriticOnPolicyAlgorithmFactory`, `ActorCriticOffPolicyAlgorithmFactory`, or `DiscreteCriticOnlyOffPolicyAlgorithmFactory`, which handle common creation patterns. The primary requirement is implementing `_get_algorithm_class()` to return the appropriate algorithm class. For algorithms with non-standard requirements, override `_create_algorithm()`, `_create_kwargs()`, etc. to customize the instantiation logic. **Experiment Builder**: Add a builder class in `tianshou/highlevel/experiment.py` that inherits from `OnPolicyExperimentBuilder` or `OffPolicyExperimentBuilder` along with appropriate mixins. The mixins provide standard functionality for configuring actors and critics (single critic, dual critics, critic ensembles, parameter sharing patterns, etc.). The main responsibility is implementing `_create_algorithm_factory()` to instantiate the algorithm factory with appropriate parameters and network factories. Optionally provide `with_*` methods for algorithm-specific configuration. Export the new classes in `tianshou/highlevel/__init__.py` to make them available to users. ### Design Principles The architecture follows several key principles: **Separation of Concerns**: Configuration is cleanly separated from implementation. The transformation system bridges these layers while maintaining independence. **Declarative Configuration**: Factories enable a declarative style where experiments are defined by what should be created rather than imperative steps. This makes experiments easily serializable and reproducible. **Composition and Inheritance**: Mixins and inheritance reduce code duplication. Common functionality is factored into reusable components while maintaining flexibility for algorithm-specific requirements. **Progressive Disclosure**: The API provides sensible defaults for simple use cases while allowing deep customization when needed. 
Users can progress from simple configurations to advanced setups without fighting the abstractions. **Co-location**: Related code is kept together. Parameter transformers are defined near the parameters they transform, maintaining clarity about dependencies and making the codebase easier to navigate. **Type Safety**: Extensive use of generics and type hints ensures that type checkers can catch configuration errors at development time rather than runtime. ## Documentation Documentation is in the `docs/` directory, using Markdown (`.md`), reStructuredText (`.rst`) and notebook files. `index.rst` is the main page. API references are automatically generated by [Sphinx](http://www.sphinx-doc.org/en/stable/) according to the outlines under `docs/api/` and should be updated whenever the corresponding code changes. To compile the documentation into web pages, run poe doc-build The generated web pages can subsequently be found in `docs/_build` and can be viewed with any browser. ### Verifications We have several automated verification methods for documentation: 1. pydocstyle (as part of ruff): tests all docstrings under `tianshou/`; 2. doc8 (as part of ruff): tests reStructuredText format; 3. sphinx spelling and test: checks whether any error/warning occurs when generating the front-end HTML documentation. ## Creating a Release To release a new version on PyPI, * set the version to be released in `tianshou/__init__.py` and in `pyproject.toml`, creating a commit * tag the commit with the version (using the format `v1.2.3`) * push the commit (`git push`) and the tag (`git push --tags`) * create a new release on GitHub based on the tag; this will trigger the release job for PyPI. In the past, we provided releases to conda-forge as well, but this is currently not maintained. ================================================ FILE: docs/06_contributors/contributors.rst ================================================ Contributors ============ We always welcome contributions to help make Tianshou better!
Tianshou was originally created by the `THU-ML Group `_ at Tsinghua University. Today, it is backed by the `appliedAI Institute for Europe `_, a non-profit organization committed to making Tianshou the go-to resource for reinforcement learning research and development, guaranteeing its long-term maintenance and support. The original creator Jiayi Weng (`Trinkle23897 `_) continues to be involved in Tianshou development. The current core developers, who are behind the v1.0 and v2.0 releases of Tianshou, are: * Dominik Jain (`opcode81 `_) * Michael Panchenko (`MischaPanch `_) An incomplete list of early contributors is: * Alexis Duburcq (`duburcqa `_) * Kaichao You (`youkaichao `_) * Huayu Chen (`ChenDRAG `_) * Yi Su (`nuance1979 `_) You can find more information about contributors `here `_. ================================================ FILE: docs/_config.yml ================================================ # Book settings # Learn more at https://jupyterbook.org/customize/config.html ####################################################################################### # A default configuration that will be loaded for all jupyter books # Users are expected to override these values in their own `_config.yml` file. # This is also the "master list" of all allowed keys and values. ####################################################################################### # Book settings title : Tianshou Documentation # The title of the book. Will be placed in the left navbar. author : Tianshou contributors # The author of the book copyright : "2020, Tianshou contributors." # Copyright year to be placed in the footer logo : _static/images/tianshou-logo.png # A path to the book logo # Patterns to skip when building the book. Can be glob-style (e.g. 
"*skip.ipynb") exclude_patterns : ['**.ipynb_checkpoints', '.DS_Store', 'Thumbs.db', '_build', 'jupyter_execute', '.jupyter_cache', '.pytest_cache', 'docs/autogen_rst.py', 'docs/create_toc.py'] # Auto-exclude files not in the toc only_build_toc_files : false ####################################################################################### # Execution settings execute: execute_notebooks : cache # Whether to execute notebooks at build time. Must be one of ("auto", "force", "cache", "off") cache : "" # A path to the jupyter cache that will be used to store execution artifacts. Defaults to `_build/.jupyter_cache/` exclude_patterns : [] # A list of patterns to *skip* in execution (e.g. a notebook that takes a really long time) timeout : -1 # The maximum time (in seconds) each notebook cell is allowed to run. run_in_temp : false # If `True`, then a temporary directory will be created and used as the command working directory (cwd), # otherwise the notebook's parent directory will be the cwd. allow_errors : false # If `False`, when a code cell raises an error the execution is stopped, otherwise all cells are always run. stderr_output : show # One of 'show', 'remove', 'remove-warn', 'warn', 'error', 'severe' ####################################################################################### # Parse and render settings parse: myst_enable_extensions: # default extensions to enable in the myst parser. 
See https://myst-parser.readthedocs.io/en/latest/using/syntax-optional.html - amsmath - colon_fence # - deflist - dollarmath - html_admonition # - html_image - linkify # - replacements # - smartquotes - substitution - tasklist myst_url_schemes: [ mailto, http, https ] # URI schemes that will be recognised as external URLs in Markdown links myst_dmath_double_inline: true # Allow display math ($$) within an inline context ####################################################################################### # HTML-specific settings html: favicon : "_static/images/tianshou-favicon.png" # A path to a favicon image use_edit_page_button : false # Whether to add an "edit this page" button to pages. If `true`, repository information in repository: must be filled in use_repository_button : false # Whether to add a link to your repository button use_issues_button : false # Whether to add an "open an issue" button use_multitoc_numbering : true # Continuous numbering across parts/chapters extra_footer : "" google_analytics_id : "" # A GA id that can be used to track book views. home_page_in_navbar : true # Whether to include your home page in the left Navigation Bar baseurl : "https://tianshou.readthedocs.io/en/master/" analytics: comments: hypothesis : false utterances : false announcement : "" # A banner announcement at the top of the site. 
####################################################################################### # LaTeX-specific settings latex: latex_engine : pdflatex # one of 'pdflatex', 'xelatex' (recommended for unicode), 'luatex', 'platex', 'uplatex' use_jupyterbook_latex : true # use sphinx-jupyterbook-latex for pdf builds as default targetname : book.tex # Add a bibtex file so that we can create citations bibtex_bibfiles: - refs.bib ####################################################################################### # Launch button settings launch_buttons: notebook_interface : classic # The interface interactive links will activate ["classic", "jupyterlab"] binderhub_url : "" # The URL of the BinderHub (e.g., https://mybinder.org) jupyterhub_url : "" # The URL of the JupyterHub (e.g., https://datahub.berkeley.edu) thebe : false # Add a thebe button to pages (requires the repository to run on Binder) colab_url : "https://colab.research.google.com" repository: url : https://github.com/thu-ml/tianshou # The URL to your book's repository path_to_book : docs # A path to your book's folder, relative to the repository root. branch : master # Which branch of the repository should be used when creating links ####################################################################################### # Advanced and power-user settings sphinx: extra_extensions : - sphinx.ext.autodoc - sphinx.ext.viewcode - sphinx_toolbox.more_autodoc.sourcelink - sphinxcontrib.spelling - sphinxcontrib.mermaid local_extensions : # A list of local extensions to load by sphinx specified by "name: path" items recursive_update : false # A boolean indicating whether to overwrite the Sphinx config (true) or recursively update (false) config : # key-value pairs to directly over-ride the Sphinx configuration autodoc_typehints_format: "short" autodoc_member_order: "bysource" autodoc_mock_imports: # mock imports for optional dependencies (e.g. 
dependencies of atari/atari_wrapper) - cv2 autoclass_content: "both" autodoc_default_options: show-inheritance: True html_js_files: # We have to list them explicitly because they need to be loaded in a specific order - js/vega@5.js - js/vega-lite@5.js - js/vega-embed@5.js autodoc_show_sourcelink: True add_module_names: False github_username: thu-ml github_repository: tianshou python_use_unqualified_type_names: True nb_mime_priority_overrides: [ [ 'html', 'application/vnd.jupyter.widget-view+json', 10 ], [ 'html', 'application/javascript', 20 ], [ 'html', 'text/html', 30 ], [ 'html', 'text/latex', 40 ], [ 'html', 'image/svg+xml', 50 ], [ 'html', 'image/png', 60 ], [ 'html', 'image/jpeg', 70 ], [ 'html', 'text/markdown', 80 ], [ 'html', 'text/plain', 90 ], [ 'spelling', 'application/vnd.jupyter.widget-view+json', 10 ], [ 'spelling', 'application/javascript', 20 ], [ 'spelling', 'text/html', 30 ], [ 'spelling', 'text/latex', 40 ], [ 'spelling', 'image/svg+xml', 50 ], [ 'spelling', 'image/png', 60 ], [ 'spelling', 'image/jpeg', 70 ], [ 'spelling', 'text/markdown', 80 ], [ 'spelling', 'text/plain', 90 ], ] mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js myst_fence_as_directive: ["mermaid"] mathjax3_config: loader: { load: [ '[tex]/configmacros' ] } tex: packages: { '[+]': [ 'configmacros' ] } macros: vect: ["{\\mathbf{\\boldsymbol{#1}} }", 1] E: "{\\mathbb{E}}" P: "{\\mathbb{P}}" R: "{\\mathbb{R}}" abs: ["{\\left| #1 \\right|}", 1] simpl: ["{\\Delta^{#1} }", 1] amax: "{\\text{argmax}}" ================================================ FILE: docs/_static/css/style.css ================================================ body { font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; } /* Default header fonts are ugly */ h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; } /* Use white for docs background */ .wy-side-nav-search { 
background-color: #fff; } .wy-nav-content { max-width: 1200px !important; } .wy-nav-content-wrap, .wy-menu li.current > a { background-color: #fff; } .wy-side-nav-search>a img.logo { width: 80%; margin-top: 10px; } @media screen and (min-width: 1400px) { .wy-nav-content-wrap { background-color: #fff; } .wy-nav-content { background-color: #fff; } } /* Fixes for mobile */ .wy-nav-top { background-color: #fff; /*background-image: url('../images/tianshou-logo.png');*/ background-repeat: no-repeat; background-position: center; padding: 0; margin: 0.4045em 0.809em; color: #333; } .wy-nav-top > a { display: none; } @media (min-width: 960px) { .bd-page-width { max-width: none !important; } } @media screen and (max-width: 768px) { .wy-side-nav-search>a img.logo { height: 60px; } } /* Ensure that the logo above the search box scales properly */ .wy-side-nav-search a { display: block; } /* Keep multiple constructor signatures on separate lines. */ .rst-content dl:not(.docutils) dt { display: table; } /* Use our blue (#4692BC, the same color as links) for literals */ .rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { color: #4692BC; } .rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, .rst-content code.xref, a .rst-content tt, a .rst-content code { color: #404040; } /* Change link colors (except for the menu) */ a { color: #4692BC; } a:hover { color: #4692BC; } a:visited { color: #4692BC; } .wy-menu a { color: #b3b3b3; } .wy-menu a:hover { color: #b3b3b3; } /* Default footer text is quite big */ footer { font-size: 80%; } footer .rst-footer-buttons { font-size: 125%; /* revert footer settings - 1/80% = 125% */ } footer p { font-size: 100%; } .ethical-rtd { display: none; } .ethical-fixedfooter { display: none; } .ethical-content { display: none; } /* For hidden headers that appear in TOC tree */ /* see http://stackoverflow.com/a/32363545/3343043 */ .rst-content .hidden-section { display: none; } nav
.hidden-section { display: inherit; } .wy-side-nav-search>div.version { color: #000; } ================================================ FILE: docs/_static/js/benchmark.js ================================================ var mujoco_envs = [ "Ant-v4", ]; var atari_envs = [ // "PongNoFrameskip-v4", // "BreakoutNoFrameskip-v4", // "EnduroNoFrameskip-v4", // "QbertNoFrameskip-v4", // "MsPacmanNoFrameskip-v4", // "SeaquestNoFrameskip-v4", // "SpaceInvadersNoFrameskip-v4", ]; function getDataSource(selectEnv, dirName) { return { // Paths are relative to the only file using this script, which is docs/04_benchmarks/benchmarks.rst $schema: "https://vega.github.io/schema/vega-lite/v5.json", data: { url: "../_static/js/" + dirName + "/benchmark/" + selectEnv + "/results.json" }, mark: "line", height: 400, width: 800, params: [{name: "Range", value: 1000000, bind: {input: "range", min: 10000, max: 10000000}}], transform: [ {calculate: "datum.iqm_confidence_interval[0]", as: "iqm_confidence_lower"}, {calculate: "datum.iqm_confidence_interval[1]", as: "iqm_confidence_upper"}, {calculate: "datum.iqm", as: "tooltip_str"}, {filter: "datum.env_step <= Range"}, ], encoding: { color: {"field": "agent", "type": "nominal"}, x: {field: "env_step", type: "quantitative", title: "Env step"}, }, layer: [{ "encoding": { "opacity": {"value": 0.3}, "y": { "field": "iqm_confidence_lower", "type": "quantitative", }, "y2": {"field": "iqm_confidence_upper"}, tooltip: [ {field: "env_step", type: "quantitative", title: "Env step"}, {field: "agent", type: "nominal"}, {field: "tooltip_str", type: "nominal", title: "Return"}, ] }, "mark": "area" }, { "encoding": { "y": { "title": "Return Interquartile Mean (5 seeds)", "field": "iqm", "type": "quantitative" } }, "mark": "line" }] }; } function showMujocoResults(elem) { const selectEnv = elem.value || mujoco_envs[0]; const dataSource = getDataSource(selectEnv, "mujoco"); vegaEmbed("#vis-mujoco", dataSource); } function showAtariResults(elem) { const 
selectEnv = elem.value || atari_envs[0]; const dataSource = getDataSource(selectEnv, "atari"); vegaEmbed("#vis-atari", dataSource); } document.addEventListener('DOMContentLoaded', function() { var envMujocoSelect = $("#env-mujoco"); if (envMujocoSelect.length) { $.each(mujoco_envs, function(idx, env) {envMujocoSelect.append($("").val(env).html(env));}) showMujocoResults(envMujocoSelect); } var envAtariSelect = $("#env-atari"); if (envAtariSelect.length) { $.each(atari_envs, function(idx, env) {envAtariSelect.append($("").val(env).html(env));}) showAtariResults(envAtariSelect); } }); ================================================ FILE: docs/_static/js/copybutton.js ================================================ document.addEventListener('DOMContentLoaded', function() { /* Add a [>>>] button on the top-right corner of code samples to hide * the >>> and ... prompts and the output and thus make the code * copyable. */ var div = $('.highlight-python .highlight,' + '.highlight-python3 .highlight,' + '.highlight-pycon .highlight,' + '.highlight-default .highlight'); var pre = div.find('pre'); // get the styles from the current theme pre.parent().parent().css('position', 'relative'); var hide_text = 'Hide the prompts and output'; var show_text = 'Show the prompts and output'; var border_width = pre.css('border-top-width'); var border_style = pre.css('border-top-style'); var border_color = pre.css('border-top-color'); var button_styles = { 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 'border-color': border_color, 'border-style': border_style, 'border-width': border_width, 'color': border_color, 'text-size': '75%', 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 'border-radius': '0 3px 0 0' } // create and add the button to all the code blocks that contain >>> div.each(function(index) { var jthis = $(this); if (jthis.find('.gp').length > 0) { var button = $('>>>'); button.css(button_styles) button.attr('title', 
hide_text); button.data('hidden', 'false'); jthis.prepend(button); } // tracebacks (.gt) contain bare text elements that need to be // wrapped in a span to work with .nextUntil() (see later) jthis.find('pre:has(.gt)').contents().filter(function() { return ((this.nodeType == 3) && (this.data.trim().length > 0)); }).wrap(''); }); // define the behavior of the button when it's clicked $('.copybutton').click(function(e){ e.preventDefault(); var button = $(this); if (button.data('hidden') === 'false') { // hide the code output button.parent().find('.go, .gp, .gt').hide(); button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); button.css('text-decoration', 'line-through'); button.attr('title', show_text); button.data('hidden', 'true'); } else { // show the code output button.parent().find('.go, .gp, .gt').show(); button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); button.css('text-decoration', 'none'); button.attr('title', hide_text); button.data('hidden', 'false'); } }); }); ================================================ FILE: docs/_static/js/mujoco/benchmark/Ant-v4/results.json ================================================ 
[{"env_step":0.0,"rew":-42.10716438293457,"rew_std":24.039944270941408,"iqm":-40.97881825764974,"iqm_confidence_interval":[-70.32092030843098,-16.246893564860027],"agent":"NPG"},{"env_step":30720.0,"rew":-28.9559645652771,"rew_std":24.61339003090938,"iqm":-31.551718711853027,"iqm_confidence_interval":[-56.042564392089844,0.8885383605957031],"agent":"NPG"},{"env_step":61440.0,"rew":-11.55083270072937,"rew_std":6.110236122564615,"iqm":-12.476377169291178,"iqm_confidence_interval":[-17.99208927154541,-3.9522757530212402],"agent":"NPG"},{"env_step":92160.0,"rew":18.439939069747926,"rew_std":10.787288878835925,"iqm":18.934740384419758,"iqm_confidence_interval":[5.5553921063741045,30.87356185913086],"agent":"NPG"},{"env_step":122880.0,"rew":45.153709411621094,"rew_std":10.71529835147344,"iqm":42.32269287109375,"iqm_confidence_interval":[35.65855153401693,59.28111139933268],"agent":"NPG"},{"env_step":153600.0,"rew":88.77803573608398,"rew_std":23.054480042346643,"iqm":89.32733408610027,"iqm_confidence_interval":[60.482452392578125,112.37182362874348],"agent":"NPG"},{"env_step":184320.0,"rew":121.12091217041015,"rew_std":30.036537018727188,"iqm":120.6772689819336,"iqm_confidence_interval":[85.26046752929688,155.5569814046224],"agent":"NPG"},{"env_step":215040.0,"rew":165.61034240722657,"rew_std":32.24095152542376,"iqm":168.1877187093099,"iqm_confidence_interval":[126.84097290039062,202.26611836751303],"agent":"NPG"},{"env_step":245760.0,"rew":264.8183166503906,"rew_std":68.57585730650109,"iqm":258.05474853515625,"iqm_confidence_interval":[192.57537333170572,347.3933512369792],"agent":"NPG"},{"env_step":276480.0,"rew":316.14655151367185,"rew_std":61.54133584779444,"iqm":329.4031473795573,"iqm_confidence_interval":[237.4843953450521,373.4185078938802],"agent":"NPG"},{"env_step":307200.0,"rew":394.7198486328125,"rew_std":37.27962477000651,"iqm":394.5103759765625,"iqm_confidence_interval":[352.14459228515625,439.00913492838544],"agent":"NPG"},{"env_step":337920.0,"rew":526.76078
49121093,"rew_std":89.47713614019793,"iqm":527.4289042154948,"iqm_confidence_interval":[423.1936747233073,624.7489217122396],"agent":"NPG"},{"env_step":368640.0,"rew":579.3040954589844,"rew_std":115.02891711430003,"iqm":588.0213419596354,"iqm_confidence_interval":[440.58685302734375,703.8764444986979],"agent":"NPG"},{"env_step":399360.0,"rew":724.5651000976562,"rew_std":105.69306907584757,"iqm":722.1583048502604,"iqm_confidence_interval":[604.6950073242188,845.4350992838541],"agent":"NPG"},{"env_step":430080.0,"rew":863.718896484375,"rew_std":188.1599330780623,"iqm":876.3805948893229,"iqm_confidence_interval":[638.7650349934896,1069.964619954427],"agent":"NPG"},{"env_step":460800.0,"rew":1041.2549926757813,"rew_std":166.71506155109884,"iqm":1064.1121215820312,"iqm_confidence_interval":[832.3571980794271,1219.2620849609375],"agent":"NPG"},{"env_step":491520.0,"rew":1216.867138671875,"rew_std":161.53753506344896,"iqm":1205.9603678385417,"iqm_confidence_interval":[1039.2998860677083,1407.0541178385417],"agent":"NPG"},{"env_step":522240.0,"rew":1292.862158203125,"rew_std":136.2897089330508,"iqm":1266.5582682291667,"iqm_confidence_interval":[1160.0815022786458,1464.893798828125],"agent":"NPG"},{"env_step":552960.0,"rew":1349.3367919921875,"rew_std":216.29374512926137,"iqm":1342.1337076822917,"iqm_confidence_interval":[1095.9574788411458,1596.9417724609375],"agent":"NPG"},{"env_step":583680.0,"rew":1286.1487548828125,"rew_std":325.0510298184137,"iqm":1309.6995442708333,"iqm_confidence_interval":[923.5931396484375,1622.830078125],"agent":"NPG"},{"env_step":614400.0,"rew":1501.306494140625,"rew_std":135.64654587687215,"iqm":1486.9634602864583,"iqm_confidence_interval":[1367.4110107421875,1674.4379069010417],"agent":"NPG"},{"env_step":645120.0,"rew":1566.505615234375,"rew_std":207.3775273975989,"iqm":1595.12890625,"iqm_confidence_interval":[1309.4273681640625,1787.2928873697917],"agent":"NPG"},{"env_step":675840.0,"rew":1648.5074951171875,"rew_std":167.47017729280728,"iqm":1
674.1391194661458,"iqm_confidence_interval":[1451.5147705078125,1829.447021484375],"agent":"NPG"},{"env_step":706560.0,"rew":1826.64931640625,"rew_std":210.24840626670206,"iqm":1826.7261962890625,"iqm_confidence_interval":[1598.8306477864583,2082.1640625],"agent":"NPG"},{"env_step":737280.0,"rew":1870.0795654296876,"rew_std":116.74630498193775,"iqm":1893.3362630208333,"iqm_confidence_interval":[1723.8335367838542,1989.8927408854167],"agent":"NPG"},{"env_step":768000.0,"rew":1872.2861328125,"rew_std":85.56317841655876,"iqm":1872.4123942057292,"iqm_confidence_interval":[1775.9134521484375,1975.1356201171875],"agent":"NPG"},{"env_step":798720.0,"rew":2006.606005859375,"rew_std":258.3459636539935,"iqm":1976.9451904296875,"iqm_confidence_interval":[1724.650146484375,2327.4354654947915],"agent":"NPG"},{"env_step":829440.0,"rew":1947.3549560546876,"rew_std":218.90451416506937,"iqm":1959.3906656901042,"iqm_confidence_interval":[1693.5713704427083,2196.7334798177085],"agent":"NPG"},{"env_step":860160.0,"rew":2106.194873046875,"rew_std":117.60013614027503,"iqm":2067.5992838541665,"iqm_confidence_interval":[2008.2084147135417,2256.8001302083335],"agent":"NPG"},{"env_step":890880.0,"rew":2071.628271484375,"rew_std":242.5118564930354,"iqm":2149.12744140625,"iqm_confidence_interval":[1755.46728515625,2272.0733235677085],"agent":"NPG"},{"env_step":921600.0,"rew":2366.916162109375,"rew_std":289.17324371255546,"iqm":2325.9986979166665,"iqm_confidence_interval":[2074.3031412760415,2734.9005533854165],"agent":"NPG"},{"env_step":952320.0,"rew":2205.8991455078126,"rew_std":422.5754930755271,"iqm":2282.3590494791665,"iqm_confidence_interval":[1670.6670735677083,2625.806884765625],"agent":"NPG"},{"env_step":983040.0,"rew":2261.9572265625,"rew_std":509.6706066309041,"iqm":2190.5692138671875,"iqm_confidence_interval":[1760.44482421875,2919.6507161458335],"agent":"NPG"},{"env_step":1013760.0,"rew":2433.71572265625,"rew_std":493.58342825899086,"iqm":2341.9806315104165,"iqm_confidence_interval
":[1942.987060546875,3055.232666015625],"agent":"NPG"},{"env_step":1044480.0,"rew":2279.4240234375,"rew_std":646.6016713206892,"iqm":2132.0210774739585,"iqm_confidence_interval":[1656.076904296875,3117.5598958333335],"agent":"NPG"},{"env_step":1075200.0,"rew":2621.333935546875,"rew_std":588.2199189530081,"iqm":2415.3409016927085,"iqm_confidence_interval":[2146.8382161458335,3389.3638509114585],"agent":"NPG"},{"env_step":1105920.0,"rew":2538.0572998046873,"rew_std":516.4605893509663,"iqm":2412.5494791666665,"iqm_confidence_interval":[2085.3553059895835,3158.64501953125],"agent":"NPG"},{"env_step":1136640.0,"rew":2696.96259765625,"rew_std":536.0556880955188,"iqm":2612.5350748697915,"iqm_confidence_interval":[2147.63671875,3383.4264322916665],"agent":"NPG"},{"env_step":1167360.0,"rew":2600.9685546875,"rew_std":502.4029487892295,"iqm":2644.7255859375,"iqm_confidence_interval":[2001.6358235677083,3170.20849609375],"agent":"NPG"},{"env_step":1198080.0,"rew":2765.19111328125,"rew_std":577.5845895357381,"iqm":2693.3756510416665,"iqm_confidence_interval":[2225.9715169270835,3483.7542317708335],"agent":"NPG"},{"env_step":1228800.0,"rew":2881.88994140625,"rew_std":527.5027862353074,"iqm":2920.961669921875,"iqm_confidence_interval":[2219.39794921875,3442.47705078125],"agent":"NPG"},{"env_step":1259520.0,"rew":2948.0620849609377,"rew_std":594.5163611818551,"iqm":3007.4317220052085,"iqm_confidence_interval":[2209.598876953125,3598.4187825520835],"agent":"NPG"},{"env_step":1290240.0,"rew":3064.065625,"rew_std":440.83851262024234,"iqm":2960.6610514322915,"iqm_confidence_interval":[2634.0765787760415,3625.16015625],"agent":"NPG"},{"env_step":1320960.0,"rew":3280.083984375,"rew_std":707.1134563719127,"iqm":3348.0303548177085,"iqm_confidence_interval":[2432.7224934895835,4057.8509114583335],"agent":"NPG"},{"env_step":1351680.0,"rew":3205.251513671875,"rew_std":316.46324853306436,"iqm":3163.8677571614585,"iqm_confidence_interval":[2869.4271647135415,3602.5740559895835],"agent":"NPG"},{
"env_step":1382400.0,"rew":3444.421728515625,"rew_std":436.2355752828923,"iqm":3369.95654296875,"iqm_confidence_interval":[3007.6968587239585,4004.7176920572915],"agent":"NPG"},{"env_step":1413120.0,"rew":3011.940234375,"rew_std":634.0937754760221,"iqm":3051.2601725260415,"iqm_confidence_interval":[2281.888427734375,3751.8732096354165],"agent":"NPG"},{"env_step":1443840.0,"rew":3466.91298828125,"rew_std":572.6590330179324,"iqm":3594.772216796875,"iqm_confidence_interval":[2714.4599609375,3963.66162109375],"agent":"NPG"},{"env_step":1474560.0,"rew":3302.862255859375,"rew_std":762.3598545058242,"iqm":3239.7962239583335,"iqm_confidence_interval":[2556.1107584635415,4242.556315104167],"agent":"NPG"},{"env_step":1505280.0,"rew":3478.742822265625,"rew_std":557.1591110802749,"iqm":3495.2234700520835,"iqm_confidence_interval":[2808.22705078125,4099.184000651042],"agent":"NPG"},{"env_step":1536000.0,"rew":3490.6912109375,"rew_std":444.31353563177555,"iqm":3609.8375651041665,"iqm_confidence_interval":[2909.4974772135415,3896.072998046875],"agent":"NPG"},{"env_step":1566720.0,"rew":3646.065234375,"rew_std":773.6771731564615,"iqm":3672.367431640625,"iqm_confidence_interval":[2757.34033203125,4560.120768229167],"agent":"NPG"},{"env_step":1597440.0,"rew":3487.64560546875,"rew_std":466.1504258357491,"iqm":3650.9379069010415,"iqm_confidence_interval":[2892.069580078125,3872.5625],"agent":"NPG"},{"env_step":1628160.0,"rew":3813.238232421875,"rew_std":195.45548449408994,"iqm":3848.3933919270835,"iqm_confidence_interval":[3557.7486979166665,3992.6292317708335],"agent":"NPG"},{"env_step":1658880.0,"rew":3634.64384765625,"rew_std":357.5938916670007,"iqm":3626.664306640625,"iqm_confidence_interval":[3249.5238444010415,4047.3922526041665],"agent":"NPG"},{"env_step":1689600.0,"rew":3831.650048828125,"rew_std":373.6238249300375,"iqm":3942.2557779947915,"iqm_confidence_interval":[3367.9471842447915,4165.78466796875],"agent":"NPG"},{"env_step":1720320.0,"rew":3441.63681640625,"rew_std":107.20
122493250447,"iqm":3469.4847005208335,"iqm_confidence_interval":[3299.667724609375,3530.85693359375],"agent":"NPG"},{"env_step":1751040.0,"rew":4085.60537109375,"rew_std":380.79098883303703,"iqm":4027.6840006510415,"iqm_confidence_interval":[3706.8218587239585,4541.901041666667],"agent":"NPG"},{"env_step":1781760.0,"rew":3908.0486328125,"rew_std":531.1436678021352,"iqm":3868.1513671875,"iqm_confidence_interval":[3365.0298665364585,4572.987141927083],"agent":"NPG"},{"env_step":1812480.0,"rew":3977.9677734375,"rew_std":334.31211529476235,"iqm":3977.7456868489585,"iqm_confidence_interval":[3583.3868001302085,4351.19580078125],"agent":"NPG"},{"env_step":1843200.0,"rew":3542.310986328125,"rew_std":385.2524129298876,"iqm":3677.203857421875,"iqm_confidence_interval":[3076.446533203125,3828.6275227864585],"agent":"NPG"},{"env_step":1873920.0,"rew":4163.43583984375,"rew_std":492.598368484635,"iqm":4214.400065104167,"iqm_confidence_interval":[3540.4664713541665,4677.500813802083],"agent":"NPG"},{"env_step":1904640.0,"rew":4181.5763671875,"rew_std":554.0300039532967,"iqm":4131.946533203125,"iqm_confidence_interval":[3620.850341796875,4883.752278645833],"agent":"NPG"},{"env_step":1935360.0,"rew":4026.918310546875,"rew_std":583.5711637398964,"iqm":4088.448486328125,"iqm_confidence_interval":[3314.7400716145835,4676.281412760417],"agent":"NPG"},{"env_step":1966080.0,"rew":4025.4564453125,"rew_std":454.36050405051066,"iqm":3980.6463216145835,"iqm_confidence_interval":[3539.2442220052085,4587.694498697917],"agent":"NPG"},{"env_step":1996800.0,"rew":4082.299462890625,"rew_std":567.8325714042335,"iqm":4076.89306640625,"iqm_confidence_interval":[3463.0868326822915,4747.28076171875],"agent":"NPG"},{"env_step":2027520.0,"rew":3936.75859375,"rew_std":526.8825948319501,"iqm":3821.395263671875,"iqm_confidence_interval":[3427.847412109375,4580.795654296875],"agent":"NPG"},{"env_step":2058240.0,"rew":4397.854345703125,"rew_std":587.2665622699866,"iqm":4586.83935546875,"iqm_confidence_interva
l":[3680.1486002604165,4879.48291015625],"agent":"NPG"},{"env_step":2088960.0,"rew":4056.78583984375,"rew_std":715.0343288525298,"iqm":4241.515950520833,"iqm_confidence_interval":[3110.4615885416665,4684.402180989583],"agent":"NPG"},{"env_step":2119680.0,"rew":4201.72041015625,"rew_std":434.0396009798354,"iqm":4308.550944010417,"iqm_confidence_interval":[3637.28857421875,4594.300618489583],"agent":"NPG"},{"env_step":2150400.0,"rew":3781.59541015625,"rew_std":346.4414147044312,"iqm":3839.78369140625,"iqm_confidence_interval":[3346.6486002604165,4140.158121744792],"agent":"NPG"},{"env_step":2181120.0,"rew":3943.3373046875,"rew_std":610.6881082446964,"iqm":3962.8224283854165,"iqm_confidence_interval":[3243.5960286458335,4605.189534505208],"agent":"NPG"},{"env_step":2211840.0,"rew":4266.37197265625,"rew_std":681.0246834918399,"iqm":4187.938720703125,"iqm_confidence_interval":[3565.35400390625,5115.400065104167],"agent":"NPG"},{"env_step":2242560.0,"rew":4126.19501953125,"rew_std":487.938527027499,"iqm":4053.7718098958335,"iqm_confidence_interval":[3646.608642578125,4739.072591145833],"agent":"NPG"},{"env_step":2273280.0,"rew":4228.608203125,"rew_std":549.7404756899181,"iqm":4230.651936848958,"iqm_confidence_interval":[3596.397216796875,4873.546712239583],"agent":"NPG"},{"env_step":2304000.0,"rew":4057.852783203125,"rew_std":602.2027305602329,"iqm":4093.2787272135415,"iqm_confidence_interval":[3350.7344563802085,4752.5224609375],"agent":"NPG"},{"env_step":2334720.0,"rew":4452.448046875,"rew_std":531.1544726008665,"iqm":4385.053548177083,"iqm_confidence_interval":[3901.7164713541665,5128.13818359375],"agent":"NPG"},{"env_step":2365440.0,"rew":4271.966796875,"rew_std":437.37979559110437,"iqm":4285.06005859375,"iqm_confidence_interval":[3740.1287434895835,4744.523600260417],"agent":"NPG"},{"env_step":2396160.0,"rew":4347.164404296875,"rew_std":223.5063369913802,"iqm":4431.816080729167,"iqm_confidence_interval":[4060.6329752604165,4509.202473958333],"agent":"NPG"},{"env_step
":2426880.0,"rew":4481.2849609375,"rew_std":346.00707012552084,"iqm":4513.728515625,"iqm_confidence_interval":[4048.8810221354165,4857.212076822917],"agent":"NPG"},{"env_step":2457600.0,"rew":4564.82314453125,"rew_std":637.0825921129947,"iqm":4590.755533854167,"iqm_confidence_interval":[3819.2635091145835,5305.782552083333],"agent":"NPG"},{"env_step":2488320.0,"rew":4520.0427734375,"rew_std":436.4336347118552,"iqm":4590.888834635417,"iqm_confidence_interval":[3962.0810546875,4967.877278645833],"agent":"NPG"},{"env_step":2519040.0,"rew":4359.85390625,"rew_std":287.77420296772794,"iqm":4469.613932291667,"iqm_confidence_interval":[4007.3859049479165,4579.832682291667],"agent":"NPG"},{"env_step":2549760.0,"rew":4693.39501953125,"rew_std":396.4874171852096,"iqm":4781.868326822917,"iqm_confidence_interval":[4193.547037760417,5072.042643229167],"agent":"NPG"},{"env_step":2580480.0,"rew":4582.3765625,"rew_std":421.0727810737856,"iqm":4619.725748697917,"iqm_confidence_interval":[4058.74072265625,5022.57373046875],"agent":"NPG"},{"env_step":2611200.0,"rew":4703.40751953125,"rew_std":210.00483804774578,"iqm":4754.280924479167,"iqm_confidence_interval":[4431.056315104167,4899.427571614583],"agent":"NPG"},{"env_step":2641920.0,"rew":4900.25234375,"rew_std":290.4812495566195,"iqm":4872.10693359375,"iqm_confidence_interval":[4618.145182291667,5264.572916666667],"agent":"NPG"},{"env_step":2672640.0,"rew":4474.5326171875,"rew_std":269.2334315633679,"iqm":4443.0693359375,"iqm_confidence_interval":[4191.284016927083,4815.879231770833],"agent":"NPG"},{"env_step":2703360.0,"rew":4658.04375,"rew_std":336.6230138689072,"iqm":4545.173502604167,"iqm_confidence_interval":[4387.645182291667,5095.66064453125],"agent":"NPG"},{"env_step":2734080.0,"rew":4878.24658203125,"rew_std":334.79437839842825,"iqm":4920.015950520833,"iqm_confidence_interval":[4472.2958984375,5213.7880859375],"agent":"NPG"},{"env_step":2764800.0,"rew":4760.90791015625,"rew_std":214.13738081026182,"iqm":4727.073567708333,"iq
m_confidence_interval":[4549.66845703125,5038.178548177083],"agent":"NPG"},{"env_step":2795520.0,"rew":4684.60634765625,"rew_std":77.41656898240524,"iqm":4706.542317708333,"iqm_confidence_interval":[4591.1083984375,4749.69140625],"agent":"NPG"},{"env_step":2826240.0,"rew":4977.3126953125,"rew_std":182.15891965403532,"iqm":4958.58740234375,"iqm_confidence_interval":[4791.443684895833,5209.349609375],"agent":"NPG"},{"env_step":2856960.0,"rew":4734.32392578125,"rew_std":102.18784300235029,"iqm":4715.5224609375,"iqm_confidence_interval":[4632.307942708333,4864.966959635417],"agent":"NPG"},{"env_step":2887680.0,"rew":5007.94814453125,"rew_std":206.2397802762669,"iqm":5018.766438802083,"iqm_confidence_interval":[4766.834309895833,5246.975911458333],"agent":"NPG"},{"env_step":2918400.0,"rew":5097.96357421875,"rew_std":163.5823288450074,"iqm":5032.53125,"iqm_confidence_interval":[4977.635091145833,5299.5048828125],"agent":"NPG"},{"env_step":2949120.0,"rew":4788.7689453125,"rew_std":376.3747806878339,"iqm":4787.894368489583,"iqm_confidence_interval":[4337.253580729167,5212.905110677083],"agent":"NPG"},{"env_step":2979840.0,"rew":4699.702587890625,"rew_std":583.2143349211142,"iqm":4777.147786458333,"iqm_confidence_interval":[3975.2060546875,5330.618489583333],"agent":"NPG"},{"env_step":3010560.0,"rew":4728.6564453125,"rew_std":402.81542872103023,"iqm":4677.973795572917,"iqm_confidence_interval":[4308.068196614583,5241.167154947917],"agent":"NPG"},{"env_step":3041280.0,"rew":4987.98134765625,"rew_std":409.05004072355825,"iqm":4909.922526041667,"iqm_confidence_interval":[4619.532063802083,5518.312662760417],"agent":"NPG"},{"env_step":3072000.0,"rew":4764.5708984375,"rew_std":448.03125410300396,"iqm":4817.55908203125,"iqm_confidence_interval":[4203.057291666667,5225.935709635417],"agent":"NPG"},{"env_step":0.0,"rew":-42.10716438293457,"rew_std":24.039944270941408,"iqm":-40.97881825764974,"iqm_confidence_interval":[-70.32092030843098,-16.246893564860027],"agent":"PPO"},{"env_step
":30720.0,"rew":-37.372220039367676,"rew_std":19.689550303988273,"iqm":-35.171875,"iqm_confidence_interval":[-62.41522725423177,-17.062697728474934],"agent":"PPO"},{"env_step":61440.0,"rew":-21.24895796775818,"rew_std":14.76330539847627,"iqm":-22.682526270548504,"iqm_confidence_interval":[-36.67624537150065,-2.578148365020752],"agent":"PPO"},{"env_step":92160.0,"rew":-13.753775215148925,"rew_std":14.13574584193426,"iqm":-19.474328994750977,"iqm_confidence_interval":[-23.108850479125977,3.3546009063720703],"agent":"PPO"},{"env_step":122880.0,"rew":-12.556952285766602,"rew_std":13.641631192489111,"iqm":-8.645663897196451,"iqm_confidence_interval":[-29.560285568237305,-0.892268180847168],"agent":"PPO"},{"env_step":153600.0,"rew":5.548850393295288,"rew_std":5.419250191615881,"iqm":6.114233175913493,"iqm_confidence_interval":[-0.7145605087280273,10.98777691523234],"agent":"PPO"},{"env_step":184320.0,"rew":4.924838566780091,"rew_std":8.117651806112379,"iqm":6.176069458325704,"iqm_confidence_interval":[-5.165180961290996,13.583230336507162],"agent":"PPO"},{"env_step":215040.0,"rew":17.765496063232423,"rew_std":13.144212017376775,"iqm":20.197986602783203,"iqm_confidence_interval":[1.0707556406656902,30.401254653930664],"agent":"PPO"},{"env_step":245760.0,"rew":31.95562801361084,"rew_std":12.791377589350782,"iqm":35.37232271830241,"iqm_confidence_interval":[15.888467152913412,43.269569396972656],"agent":"PPO"},{"env_step":276480.0,"rew":63.5789077758789,"rew_std":14.202631661044354,"iqm":66.58460362752278,"iqm_confidence_interval":[44.90095901489258,76.55831654866536],"agent":"PPO"},{"env_step":307200.0,"rew":86.53251419067382,"rew_std":29.695438915186116,"iqm":80.42616017659505,"iqm_confidence_interval":[55.82790883382162,122.85542551676433],"agent":"PPO"},{"env_step":337920.0,"rew":102.9419448852539,"rew_std":36.432540278237525,"iqm":93.47763570149739,"iqm_confidence_interval":[70.79665120442708,150.19486490885416],"agent":"PPO"},{"env_step":368640.0,"rew":138.240145874023
43,"rew_std":17.957499408815412,"iqm":137.02452596028647,"iqm_confidence_interval":[118.35025533040364,158.88599141438803],"agent":"PPO"},{"env_step":399360.0,"rew":166.96195373535156,"rew_std":10.492143797386662,"iqm":165.95062255859375,"iqm_confidence_interval":[156.25365702311197,179.8980712890625],"agent":"PPO"},{"env_step":430080.0,"rew":188.4481964111328,"rew_std":21.446438325420985,"iqm":185.77357482910156,"iqm_confidence_interval":[165.51288350423178,215.12000528971353],"agent":"PPO"},{"env_step":460800.0,"rew":254.69054565429687,"rew_std":40.37535841417855,"iqm":246.8503214518229,"iqm_confidence_interval":[216.7704823811849,307.5547281901042],"agent":"PPO"},{"env_step":491520.0,"rew":313.9736083984375,"rew_std":13.84002768360938,"iqm":313.45168050130206,"iqm_confidence_interval":[298.9747823079427,330.33074951171875],"agent":"PPO"},{"env_step":522240.0,"rew":321.7760498046875,"rew_std":27.74167703701409,"iqm":319.7954610188802,"iqm_confidence_interval":[292.18919881184894,356.52332560221356],"agent":"PPO"},{"env_step":552960.0,"rew":377.2718872070312,"rew_std":82.12267171002435,"iqm":350.90550740559894,"iqm_confidence_interval":[306.2198486328125,479.6610412597656],"agent":"PPO"},{"env_step":583680.0,"rew":390.71156616210936,"rew_std":35.88038655768206,"iqm":389.2766825358073,"iqm_confidence_interval":[350.6089782714844,434.0635986328125],"agent":"PPO"},{"env_step":614400.0,"rew":470.87006225585935,"rew_std":100.31065042455326,"iqm":447.6538899739583,"iqm_confidence_interval":[374.72276814778644,600.7935587565104],"agent":"PPO"},{"env_step":645120.0,"rew":522.5170715332031,"rew_std":101.98108030878707,"iqm":549.3511149088541,"iqm_confidence_interval":[398.22825113932294,611.7529296875],"agent":"PPO"},{"env_step":675840.0,"rew":539.6492065429687,"rew_std":127.94532386643006,"iqm":551.1155395507812,"iqm_confidence_interval":[388.29457600911456,673.1277262369791],"agent":"PPO"},{"env_step":706560.0,"rew":600.2794616699218,"rew_std":82.50651436358652,"iqm":606.
5103149414062,"iqm_confidence_interval":[508.68638102213544,695.8967692057291],"agent":"PPO"},{"env_step":737280.0,"rew":742.03623046875,"rew_std":90.94995649924844,"iqm":761.3958740234375,"iqm_confidence_interval":[631.2799479166666,832.5506184895834],"agent":"PPO"},{"env_step":768000.0,"rew":615.0272155761719,"rew_std":58.96221259019708,"iqm":632.1763916015625,"iqm_confidence_interval":[541.9233805338541,666.8138631184896],"agent":"PPO"},{"env_step":798720.0,"rew":657.0923034667969,"rew_std":205.12095729630033,"iqm":611.0579833984375,"iqm_confidence_interval":[448.13185628255206,900.5207722981771],"agent":"PPO"},{"env_step":829440.0,"rew":808.9402954101563,"rew_std":93.11438452829933,"iqm":812.4151611328125,"iqm_confidence_interval":[706.8053385416666,906.13330078125],"agent":"PPO"},{"env_step":860160.0,"rew":690.5447570800782,"rew_std":178.52351903516816,"iqm":656.3587036132812,"iqm_confidence_interval":[508.01845296223956,916.5430094401041],"agent":"PPO"},{"env_step":890880.0,"rew":824.253515625,"rew_std":147.86348202648495,"iqm":773.6602986653646,"iqm_confidence_interval":[698.3820393880209,1007.6150512695312],"agent":"PPO"},{"env_step":921600.0,"rew":1026.9066162109375,"rew_std":139.86281711892363,"iqm":971.0055745442709,"iqm_confidence_interval":[931.3086954752604,1202.254659016927],"agent":"PPO"},{"env_step":952320.0,"rew":983.5424194335938,"rew_std":139.64380326887348,"iqm":989.9229329427084,"iqm_confidence_interval":[820.3156941731771,1147.0453287760417],"agent":"PPO"},{"env_step":983040.0,"rew":1006.61904296875,"rew_std":130.47452131573982,"iqm":983.8915812174479,"iqm_confidence_interval":[880.1797892252604,1169.6405843098958],"agent":"PPO"},{"env_step":1013760.0,"rew":1023.1817016601562,"rew_std":229.86015328346903,"iqm":970.7048950195312,"iqm_confidence_interval":[800.6198527018229,1319.8238118489583],"agent":"PPO"},{"env_step":1044480.0,"rew":1235.4329956054687,"rew_std":308.5856035365616,"iqm":1193.8307698567708,"iqm_confidence_interval":[917.09940592
44791,1627.1339111328125],"agent":"PPO"},{"env_step":1075200.0,"rew":1283.767919921875,"rew_std":369.94054944625407,"iqm":1235.2249755859375,"iqm_confidence_interval":[910.7933349609375,1758.7515462239583],"agent":"PPO"},{"env_step":1105920.0,"rew":1306.4467895507812,"rew_std":304.9427535917461,"iqm":1280.5807291666667,"iqm_confidence_interval":[970.6359049479166,1681.99755859375],"agent":"PPO"},{"env_step":1136640.0,"rew":1421.9716064453125,"rew_std":408.54132443351,"iqm":1296.5053304036458,"iqm_confidence_interval":[1087.0143229166667,1953.9322102864583],"agent":"PPO"},{"env_step":1167360.0,"rew":1590.911376953125,"rew_std":294.2090689637845,"iqm":1670.9098307291667,"iqm_confidence_interval":[1203.5606689453125,1852.9765218098958],"agent":"PPO"},{"env_step":1198080.0,"rew":1585.925048828125,"rew_std":281.95354699539524,"iqm":1636.9428304036458,"iqm_confidence_interval":[1220.1016438802083,1863.3033447265625],"agent":"PPO"},{"env_step":1228800.0,"rew":1488.6505126953125,"rew_std":242.07114369351723,"iqm":1458.1148681640625,"iqm_confidence_interval":[1236.4732259114583,1792.5924886067708],"agent":"PPO"},{"env_step":1259520.0,"rew":1674.939990234375,"rew_std":158.53917724624395,"iqm":1687.1951090494792,"iqm_confidence_interval":[1475.0849202473958,1840.28564453125],"agent":"PPO"},{"env_step":1290240.0,"rew":1613.6635986328124,"rew_std":226.93365135714689,"iqm":1641.7682291666667,"iqm_confidence_interval":[1333.0029296875,1854.9034016927083],"agent":"PPO"},{"env_step":1320960.0,"rew":1690.7473876953125,"rew_std":228.711704729816,"iqm":1671.047119140625,"iqm_confidence_interval":[1444.6800537109375,1975.0709228515625],"agent":"PPO"},{"env_step":1351680.0,"rew":1637.27607421875,"rew_std":326.39842414211853,"iqm":1689.7603759765625,"iqm_confidence_interval":[1214.1621500651042,1938.1739501953125],"agent":"PPO"},{"env_step":1382400.0,"rew":1875.2035888671876,"rew_std":117.24078150507262,"iqm":1870.533447265625,"iqm_confidence_interval":[1749.0623779296875,2015.33943684895
83],"agent":"PPO"},{"env_step":1413120.0,"rew":1630.64677734375,"rew_std":278.33546365831046,"iqm":1666.410888671875,"iqm_confidence_interval":[1284.4798583984375,1922.7447102864583],"agent":"PPO"},{"env_step":1443840.0,"rew":2303.1189697265627,"rew_std":332.04121880276756,"iqm":2275.5123697916665,"iqm_confidence_interval":[1921.7642415364583,2683.3470052083335],"agent":"PPO"},{"env_step":1474560.0,"rew":1948.187353515625,"rew_std":450.3484385458742,"iqm":1980.5458984375,"iqm_confidence_interval":[1415.1756998697917,2445.75146484375],"agent":"PPO"},{"env_step":1505280.0,"rew":1660.3700439453125,"rew_std":81.58748049080477,"iqm":1657.2394612630208,"iqm_confidence_interval":[1576.0192057291667,1758.386474609375],"agent":"PPO"},{"env_step":1536000.0,"rew":2080.3249755859374,"rew_std":516.867492908338,"iqm":2198.355753580729,"iqm_confidence_interval":[1411.3249918619792,2566.95361328125],"agent":"PPO"},{"env_step":1566720.0,"rew":1932.98095703125,"rew_std":497.55494819068974,"iqm":1932.7194010416667,"iqm_confidence_interval":[1333.7371419270833,2499.1190592447915],"agent":"PPO"},{"env_step":1597440.0,"rew":2143.2340576171873,"rew_std":475.77974848071915,"iqm":2227.9281412760415,"iqm_confidence_interval":[1549.1127115885417,2640.3284505208335],"agent":"PPO"},{"env_step":1628160.0,"rew":2226.6865478515624,"rew_std":246.70248568804817,"iqm":2290.8931477864585,"iqm_confidence_interval":[1912.80712890625,2454.1283365885415],"agent":"PPO"},{"env_step":1658880.0,"rew":2434.57939453125,"rew_std":281.12032818486637,"iqm":2458.473388671875,"iqm_confidence_interval":[2097.693359375,2749.7991536458335],"agent":"PPO"},{"env_step":1689600.0,"rew":2254.104052734375,"rew_std":364.90063328970285,"iqm":2279.374267578125,"iqm_confidence_interval":[1812.7180989583333,2615.252197265625],"agent":"PPO"},{"env_step":1720320.0,"rew":2224.58935546875,"rew_std":301.9480649936515,"iqm":2230.579793294271,"iqm_confidence_interval":[1857.5064697265625,2564.9047037760415],"agent":"PPO"},{"env_step":17
51040.0,"rew":2474.04462890625,"rew_std":413.29056303724695,"iqm":2463.1443684895835,"iqm_confidence_interval":[1996.7662760416667,2942.8408203125],"agent":"PPO"},{"env_step":1781760.0,"rew":2450.4150634765624,"rew_std":406.0064709999227,"iqm":2600.332763671875,"iqm_confidence_interval":[1916.1021321614583,2758.1326497395835],"agent":"PPO"},{"env_step":1812480.0,"rew":2615.113134765625,"rew_std":379.6628530641846,"iqm":2674.5680338541665,"iqm_confidence_interval":[2129.9462076822915,2991.6184895833335],"agent":"PPO"},{"env_step":1843200.0,"rew":2482.0716552734375,"rew_std":328.13810380085454,"iqm":2535.2766927083335,"iqm_confidence_interval":[2060.6763509114585,2815.9703776041665],"agent":"PPO"},{"env_step":1873920.0,"rew":2287.769189453125,"rew_std":273.181875960619,"iqm":2222.917724609375,"iqm_confidence_interval":[2027.02685546875,2637.3173014322915],"agent":"PPO"},{"env_step":1904640.0,"rew":2599.824365234375,"rew_std":228.66302607578032,"iqm":2573.5817057291665,"iqm_confidence_interval":[2358.0563151041665,2879.1910807291665],"agent":"PPO"},{"env_step":1935360.0,"rew":2541.87373046875,"rew_std":289.5021642360021,"iqm":2539.570556640625,"iqm_confidence_interval":[2208.8688151041665,2887.46533203125],"agent":"PPO"},{"env_step":1966080.0,"rew":2933.87001953125,"rew_std":365.7239536954544,"iqm":2887.910400390625,"iqm_confidence_interval":[2533.5679524739585,3381.1993001302085],"agent":"PPO"},{"env_step":1996800.0,"rew":2579.30927734375,"rew_std":454.20133503110765,"iqm":2520.7469889322915,"iqm_confidence_interval":[2141.8518880208335,3161.8784993489585],"agent":"PPO"},{"env_step":2027520.0,"rew":3025.467724609375,"rew_std":270.24624348081664,"iqm":3012.11279296875,"iqm_confidence_interval":[2737.0088704427085,3361.9505208333335],"agent":"PPO"},{"env_step":2058240.0,"rew":2684.42509765625,"rew_std":306.8236635672918,"iqm":2601.1487630208335,"iqm_confidence_interval":[2417.588134765625,3090.6845703125],"agent":"PPO"},{"env_step":2088960.0,"rew":2737.178173828125,"rew
_std":260.7824921193103,"iqm":2758.1848958333335,"iqm_confidence_interval":[2414.929931640625,3024.1651204427085],"agent":"PPO"},{"env_step":2119680.0,"rew":2952.967919921875,"rew_std":372.3186731081476,"iqm":2970.3858235677085,"iqm_confidence_interval":[2496.8993326822915,3367.85791015625],"agent":"PPO"},{"env_step":2150400.0,"rew":2704.452392578125,"rew_std":463.98607495935727,"iqm":2612.8490397135415,"iqm_confidence_interval":[2269.6512858072915,3307.7020670572915],"agent":"PPO"},{"env_step":2181120.0,"rew":2952.671826171875,"rew_std":263.0708852230181,"iqm":3056.8150227864585,"iqm_confidence_interval":[2625.5342610677085,3149.9842122395835],"agent":"PPO"},{"env_step":2211840.0,"rew":2808.459326171875,"rew_std":238.9752354834229,"iqm":2726.9951985677085,"iqm_confidence_interval":[2604.6825358072915,3101.40673828125],"agent":"PPO"},{"env_step":2242560.0,"rew":3009.982470703125,"rew_std":74.9151516668644,"iqm":3012.4922688802085,"iqm_confidence_interval":[2919.9193522135415,3096.140625],"agent":"PPO"},{"env_step":2273280.0,"rew":2761.7705078125,"rew_std":360.90498272036405,"iqm":2695.267822265625,"iqm_confidence_interval":[2459.7607421875,3223.03955078125],"agent":"PPO"},{"env_step":2304000.0,"rew":3065.429833984375,"rew_std":208.58126876046154,"iqm":3072.5851236979165,"iqm_confidence_interval":[2825.865478515625,3308.800048828125],"agent":"PPO"},{"env_step":2334720.0,"rew":2497.6119140625,"rew_std":168.72569472467632,"iqm":2471.8560384114585,"iqm_confidence_interval":[2331.022705078125,2695.2312825520835],"agent":"PPO"},{"env_step":2365440.0,"rew":3119.83740234375,"rew_std":394.03295535098914,"iqm":3144.8118489583335,"iqm_confidence_interval":[2627.7822265625,3542.5327962239585],"agent":"PPO"},{"env_step":2396160.0,"rew":3456.946533203125,"rew_std":292.95038801934356,"iqm":3519.7478841145835,"iqm_confidence_interval":[3081.8028157552085,3747.7261555989585],"agent":"PPO"},{"env_step":2426880.0,"rew":2753.8327392578126,"rew_std":506.1009013588715,"iqm":2723.01977539
0625,"iqm_confidence_interval":[2205.009765625,3349.3614908854165],"agent":"PPO"},{"env_step":2457600.0,"rew":2759.991162109375,"rew_std":499.19339471250754,"iqm":2571.3480631510415,"iqm_confidence_interval":[2386.9571126302085,3371.4786783854165],"agent":"PPO"},{"env_step":2488320.0,"rew":3062.648388671875,"rew_std":378.59710060804554,"iqm":3073.1571451822915,"iqm_confidence_interval":[2607.640625,3467.515869140625],"agent":"PPO"},{"env_step":2519040.0,"rew":3111.23828125,"rew_std":287.9575490617171,"iqm":3045.4600423177085,"iqm_confidence_interval":[2838.7545572916665,3485.8863118489585],"agent":"PPO"},{"env_step":2549760.0,"rew":3244.82412109375,"rew_std":206.83376168878198,"iqm":3249.402587890625,"iqm_confidence_interval":[2996.6829427083335,3480.8203938802085],"agent":"PPO"},{"env_step":2580480.0,"rew":3003.4182861328127,"rew_std":664.6998400999946,"iqm":3071.9596354166665,"iqm_confidence_interval":[2204.2228190104165,3750.7945149739585],"agent":"PPO"},{"env_step":2611200.0,"rew":3196.064501953125,"rew_std":377.4709011433584,"iqm":3297.418212890625,"iqm_confidence_interval":[2721.66455078125,3556.1273600260415],"agent":"PPO"},{"env_step":2641920.0,"rew":2605.3752197265626,"rew_std":588.9940724955435,"iqm":2536.89013671875,"iqm_confidence_interval":[2026.8653157552083,3360.588623046875],"agent":"PPO"},{"env_step":2672640.0,"rew":2879.83369140625,"rew_std":273.4504838936323,"iqm":2911.3448079427085,"iqm_confidence_interval":[2535.531494140625,3139.8009440104165],"agent":"PPO"},{"env_step":2703360.0,"rew":3016.5560546875,"rew_std":229.28307374375265,"iqm":3019.5062662760415,"iqm_confidence_interval":[2774.4791666666665,3292.1170247395835],"agent":"PPO"},{"env_step":2734080.0,"rew":3156.365966796875,"rew_std":366.55058529979067,"iqm":3045.2474772135415,"iqm_confidence_interval":[2837.220947265625,3637.2925618489585],"agent":"PPO"},{"env_step":2764800.0,"rew":3243.116015625,"rew_std":318.76341846175626,"iqm":3214.4501953125,"iqm_confidence_interval":[2888.8876139322
915,3635.6490071614585],"agent":"PPO"},{"env_step":2795520.0,"rew":3162.737255859375,"rew_std":421.73512918702005,"iqm":3265.7046712239585,"iqm_confidence_interval":[2640.34912109375,3538.4329427083335],"agent":"PPO"},{"env_step":2826240.0,"rew":2813.2316650390626,"rew_std":538.8788688188539,"iqm":2993.232421875,"iqm_confidence_interval":[2138.80712890625,3238.02294921875],"agent":"PPO"},{"env_step":2856960.0,"rew":2723.7482421875,"rew_std":395.313112852787,"iqm":2772.0569661458335,"iqm_confidence_interval":[2232.9578450520835,3151.04736328125],"agent":"PPO"},{"env_step":2887680.0,"rew":3243.996533203125,"rew_std":438.84849169788566,"iqm":3380.9866536458335,"iqm_confidence_interval":[2680.8663736979165,3632.7811686197915],"agent":"PPO"},{"env_step":2918400.0,"rew":3352.18662109375,"rew_std":178.99601980908236,"iqm":3343.7360026041665,"iqm_confidence_interval":[3150.2766927083335,3570.91748046875],"agent":"PPO"},{"env_step":2949120.0,"rew":3044.200390625,"rew_std":366.54644565146214,"iqm":2977.0995279947915,"iqm_confidence_interval":[2705.872314453125,3523.64599609375],"agent":"PPO"},{"env_step":2979840.0,"rew":3209.932763671875,"rew_std":189.78177916093884,"iqm":3224.5018717447915,"iqm_confidence_interval":[2990.1775716145835,3427.7244466145835],"agent":"PPO"},{"env_step":3010560.0,"rew":3192.97626953125,"rew_std":673.4347294627223,"iqm":3317.2999674479165,"iqm_confidence_interval":[2343.0367024739585,3888.066650390625],"agent":"PPO"},{"env_step":3041280.0,"rew":2891.0755859375,"rew_std":264.96420492417445,"iqm":2896.5139973958335,"iqm_confidence_interval":[2567.085693359375,3168.7027180989585],"agent":"PPO"},{"env_step":3072000.0,"rew":3297.96611328125,"rew_std":252.4458971635053,"iqm":3263.8701985677085,"iqm_confidence_interval":[3042.9195149739585,3622.0187174479165],"agent":"PPO"},{"env_step":0.0,"rew":19.122140550613402,"rew_std":10.828158316117733,"iqm":19.053614298502605,"iqm_confidence_interval":[6.762896378835042,31.858115514119465],"agent":"REDQ"},{"env_st
ep":5000.0,"rew":-0.3874577522277832,"rew_std":20.51044985790059,"iqm":-3.4084525108337402,"iqm_confidence_interval":[-21.0011043548584,25.558574040730793],"agent":"REDQ"},{"env_step":10000.0,"rew":15.17558171749115,"rew_std":19.118557777633733,"iqm":7.209461688995361,"iqm_confidence_interval":[2.4170307318369546,39.32273292541504],"agent":"REDQ"},{"env_step":15000.0,"rew":18.561113548278808,"rew_std":26.933683712547285,"iqm":16.118902524312336,"iqm_confidence_interval":[-11.86785856882731,50.54333623250326],"agent":"REDQ"},{"env_step":20000.0,"rew":22.58060312271118,"rew_std":18.478246441131926,"iqm":22.68181037902832,"iqm_confidence_interval":[0.4586966832478841,42.824031829833984],"agent":"REDQ"},{"env_step":25000.0,"rew":26.83237566947937,"rew_std":22.788315952093132,"iqm":23.187583605448406,"iqm_confidence_interval":[2.362187226613363,53.76578903198242],"agent":"REDQ"},{"env_step":30000.0,"rew":26.75415654182434,"rew_std":14.504257444432223,"iqm":29.219275156656902,"iqm_confidence_interval":[8.868590831756592,42.14922841389974],"agent":"REDQ"},{"env_step":35000.0,"rew":40.10686206817627,"rew_std":28.254898689421985,"iqm":33.00629107157389,"iqm_confidence_interval":[14.200682322184244,75.1447016398112],"agent":"REDQ"},{"env_step":40000.0,"rew":29.391694402694704,"rew_std":26.29284086929632,"iqm":27.753348350524902,"iqm_confidence_interval":[2.39839760462443,62.38681411743164],"agent":"REDQ"},{"env_step":45000.0,"rew":34.10289249420166,"rew_std":14.799463210084305,"iqm":31.986361821492512,"iqm_confidence_interval":[19.76604715983073,51.56613032023112],"agent":"REDQ"},{"env_step":50000.0,"rew":36.32660884857178,"rew_std":15.576514826259867,"iqm":36.95574633280436,"iqm_confidence_interval":[19.14916229248047,54.83012898763021],"agent":"REDQ"},{"env_step":55000.0,"rew":42.851513671875,"rew_std":13.609421520237806,"iqm":43.765868504842125,"iqm_confidence_interval":[25.845146814982098,57.435323079427086],"agent":"REDQ"},{"env_step":60000.0,"rew":41.53638954162598,"rew
_std":19.636276908871174,"iqm":41.28042538960775,"iqm_confidence_interval":[17.94518979390462,62.46479288736979],"agent":"REDQ"},{"env_step":65000.0,"rew":50.33994483947754,"rew_std":24.545305592826463,"iqm":49.94168599446615,"iqm_confidence_interval":[22.127840677897137,79.39129384358723],"agent":"REDQ"},{"env_step":70000.0,"rew":44.47986755371094,"rew_std":4.600869240164357,"iqm":45.67780685424805,"iqm_confidence_interval":[38.634847005208336,48.49294408162435],"agent":"REDQ"},{"env_step":75000.0,"rew":48.29617462158203,"rew_std":16.110098364439374,"iqm":48.11051813761393,"iqm_confidence_interval":[30.79833221435547,66.8612912495931],"agent":"REDQ"},{"env_step":80000.0,"rew":55.402624893188474,"rew_std":24.430260082246257,"iqm":51.397787729899086,"iqm_confidence_interval":[29.97335433959961,86.28074391682942],"agent":"REDQ"},{"env_step":85000.0,"rew":44.108120346069335,"rew_std":11.068213006744026,"iqm":44.7865244547526,"iqm_confidence_interval":[30.729891459147137,56.39498519897461],"agent":"REDQ"},{"env_step":90000.0,"rew":42.49583158493042,"rew_std":31.699443718350526,"iqm":41.404434521993004,"iqm_confidence_interval":[9.195992151896158,80.47208913167317],"agent":"REDQ"},{"env_step":95000.0,"rew":65.2559928894043,"rew_std":19.92327549779489,"iqm":66.64696884155273,"iqm_confidence_interval":[41.3400510152181,88.01953379313152],"agent":"REDQ"},{"env_step":100000.0,"rew":58.94149398803711,"rew_std":12.552099052111915,"iqm":55.4300905863444,"iqm_confidence_interval":[48.69604619344076,75.6355463663737],"agent":"REDQ"},{"env_step":105000.0,"rew":47.423474884033205,"rew_std":18.442554244270397,"iqm":52.15555318196615,"iqm_confidence_interval":[23.097124735514324,62.573177337646484],"agent":"REDQ"},{"env_step":110000.0,"rew":46.11401519775391,"rew_std":17.63670714798056,"iqm":49.090494791666664,"iqm_confidence_interval":[23.976512908935547,64.5013033548991],"agent":"REDQ"},{"env_step":115000.0,"rew":48.510765838623044,"rew_std":14.95405674973619,"iqm":49.9563509623209
6,"iqm_confidence_interval":[29.695638020833332,63.88478088378906],"agent":"REDQ"},{"env_step":120000.0,"rew":46.53772888183594,"rew_std":7.871587588416568,"iqm":46.81902313232422,"iqm_confidence_interval":[36.88836924235026,54.4616813659668],"agent":"REDQ"},{"env_step":125000.0,"rew":39.374982833862305,"rew_std":12.732017198364703,"iqm":38.69713592529297,"iqm_confidence_interval":[26.014727274576824,55.24446360270182],"agent":"REDQ"},{"env_step":130000.0,"rew":33.05560150146484,"rew_std":22.92158223741304,"iqm":33.4929567972819,"iqm_confidence_interval":[5.686606089274089,59.452101389567055],"agent":"REDQ"},{"env_step":135000.0,"rew":63.105552673339844,"rew_std":16.518051063313987,"iqm":66.08929824829102,"iqm_confidence_interval":[41.579349517822266,78.98898824055989],"agent":"REDQ"},{"env_step":140000.0,"rew":46.85098991394043,"rew_std":18.930184847734054,"iqm":47.70589701334635,"iqm_confidence_interval":[25.453510284423828,69.16705067952473],"agent":"REDQ"},{"env_step":145000.0,"rew":52.60828285217285,"rew_std":20.0156007925802,"iqm":50.19569396972656,"iqm_confidence_interval":[32.482496897379555,77.17677942911784],"agent":"REDQ"},{"env_step":150000.0,"rew":32.87973213195801,"rew_std":5.752220982221004,"iqm":33.12152926127116,"iqm_confidence_interval":[25.864628473917644,38.533990224202476],"agent":"REDQ"},{"env_step":155000.0,"rew":44.07918701171875,"rew_std":8.609497115980329,"iqm":44.441888173421226,"iqm_confidence_interval":[33.86382802327474,54.07351048787435],"agent":"REDQ"},{"env_step":160000.0,"rew":30.693387031555176,"rew_std":10.843262637384841,"iqm":33.20183245340983,"iqm_confidence_interval":[17.02465057373047,41.24501164754232],"agent":"REDQ"},{"env_step":165000.0,"rew":33.355489349365236,"rew_std":5.196249955937854,"iqm":33.20615895589193,"iqm_confidence_interval":[27.692097981770832,39.44086456298828],"agent":"REDQ"},{"env_step":170000.0,"rew":35.94187335968017,"rew_std":15.523436180672805,"iqm":38.41239356994629,"iqm_confidence_interval":[16.75840
2506510418,52.45171864827474],"agent":"REDQ"},{"env_step":175000.0,"rew":41.43852634429932,"rew_std":27.498704792345396,"iqm":35.69605954488119,"iqm_confidence_interval":[14.622135162353516,74.53206888834636],"agent":"REDQ"},{"env_step":180000.0,"rew":34.13255910873413,"rew_std":25.307062223373727,"iqm":28.26945209503174,"iqm_confidence_interval":[8.89621353149414,65.08337910970052],"agent":"REDQ"},{"env_step":185000.0,"rew":39.39708557128906,"rew_std":31.013538301188465,"iqm":43.62894503275553,"iqm_confidence_interval":[1.4370594024658203,72.83478291829427],"agent":"REDQ"},{"env_step":190000.0,"rew":25.197536277770997,"rew_std":14.176854535443127,"iqm":23.52732563018799,"iqm_confidence_interval":[9.304034868876139,41.48284022013346],"agent":"REDQ"},{"env_step":195000.0,"rew":19.246288204193114,"rew_std":20.540901260792012,"iqm":9.52669350306193,"iqm_confidence_interval":[7.575188159942627,44.38832473754883],"agent":"REDQ"},{"env_step":200000.0,"rew":26.64177188873291,"rew_std":10.990681141501083,"iqm":24.55473454793294,"iqm_confidence_interval":[15.542388280232748,40.352203369140625],"agent":"REDQ"},{"env_step":205000.0,"rew":22.146988439559937,"rew_std":12.838547894626247,"iqm":22.134310404459637,"iqm_confidence_interval":[7.028153578440349,36.973257064819336],"agent":"REDQ"},{"env_step":210000.0,"rew":20.998230075836183,"rew_std":13.356703236839795,"iqm":24.047775904337566,"iqm_confidence_interval":[5.1614993413289385,32.8866761525472],"agent":"REDQ"},{"env_step":215000.0,"rew":27.793178844451905,"rew_std":13.918359181461092,"iqm":28.164426803588867,"iqm_confidence_interval":[12.030838330586752,43.87256622314453],"agent":"REDQ"},{"env_step":220000.0,"rew":30.735445976257324,"rew_std":14.063949149444934,"iqm":28.017695744832356,"iqm_confidence_interval":[16.682162602742512,48.69248708089193],"agent":"REDQ"},{"env_step":225000.0,"rew":38.60630149841309,"rew_std":31.994663408504838,"iqm":33.90487003326416,"iqm_confidence_interval":[7.282408714294434,79.1937077840169
2],"agent":"REDQ"},{"env_step":230000.0,"rew":19.250569534301757,"rew_std":24.54294644655565,"iqm":11.479660749435425,"iqm_confidence_interval":[-1.0587985515594482,51.295522689819336],"agent":"REDQ"},{"env_step":235000.0,"rew":30.77835578918457,"rew_std":12.581581321709292,"iqm":26.17158381144206,"iqm_confidence_interval":[20.701234181722004,46.75053914388021],"agent":"REDQ"},{"env_step":240000.0,"rew":15.26291389465332,"rew_std":9.251941738607126,"iqm":16.373140652974445,"iqm_confidence_interval":[4.278467178344727,25.60233434041341],"agent":"REDQ"},{"env_step":245000.0,"rew":29.499822044372557,"rew_std":23.60619473210108,"iqm":23.36173439025879,"iqm_confidence_interval":[8.617469787597656,59.65530014038086],"agent":"REDQ"},{"env_step":250000.0,"rew":36.33521785736084,"rew_std":21.546205271752644,"iqm":34.795641581217446,"iqm_confidence_interval":[16.443824768066406,62.70384216308594],"agent":"REDQ"},{"env_step":255000.0,"rew":18.074655723571777,"rew_std":9.11244349956379,"iqm":14.337118784586588,"iqm_confidence_interval":[11.932367960611979,29.727651596069336],"agent":"REDQ"},{"env_step":260000.0,"rew":27.723381662368773,"rew_std":37.81439982632864,"iqm":12.885271072387695,"iqm_confidence_interval":[1.7803092002868652,73.44462776184082],"agent":"REDQ"},{"env_step":265000.0,"rew":25.311579477787017,"rew_std":21.841700020220447,"iqm":20.46875,"iqm_confidence_interval":[5.30793559551239,51.78204854329427],"agent":"REDQ"},{"env_step":270000.0,"rew":16.696401643753052,"rew_std":16.042648501377865,"iqm":17.084654887517292,"iqm_confidence_interval":[-2.5589532057444253,34.69189961751302],"agent":"REDQ"},{"env_step":275000.0,"rew":21.177660942077637,"rew_std":14.303141693630103,"iqm":19.086255073547363,"iqm_confidence_interval":[5.847034772237142,38.21140734354655],"agent":"REDQ"},{"env_step":280000.0,"rew":23.45450601577759,"rew_std":19.752276983360957,"iqm":20.248744010925293,"iqm_confidence_interval":[5.931210835774739,48.72629928588867],"agent":"REDQ"},{"env_step":28
5000.0,"rew":32.70703010559082,"rew_std":13.501850589686523,"iqm":34.144697189331055,"iqm_confidence_interval":[16.019423802693684,47.23796463012695],"agent":"REDQ"},{"env_step":290000.0,"rew":15.175395011901855,"rew_std":14.741239047740697,"iqm":19.14916197458903,"iqm_confidence_interval":[-3.7742255528767905,28.7479674021403],"agent":"REDQ"},{"env_step":295000.0,"rew":26.005640029907227,"rew_std":38.27730697784234,"iqm":13.446321487426758,"iqm_confidence_interval":[-3.740265210469564,71.88369782765706],"agent":"REDQ"},{"env_step":300000.0,"rew":19.6172076523304,"rew_std":12.230573673359512,"iqm":21.172632535298664,"iqm_confidence_interval":[4.248912910620372,32.60005187988281],"agent":"REDQ"},{"env_step":305000.0,"rew":10.947058582305909,"rew_std":21.094597096111134,"iqm":7.029097716013591,"iqm_confidence_interval":[-9.622747580210367,38.37720934549967],"agent":"REDQ"},{"env_step":310000.0,"rew":17.46815927028656,"rew_std":10.885388360104855,"iqm":16.692235628763836,"iqm_confidence_interval":[5.680540323257446,29.90889040629069],"agent":"REDQ"},{"env_step":315000.0,"rew":5.793402194976807,"rew_std":9.936588845500665,"iqm":8.28311554590861,"iqm_confidence_interval":[-6.639059543609619,15.044461886088053],"agent":"REDQ"},{"env_step":320000.0,"rew":13.770581036806107,"rew_std":14.23274143825097,"iqm":10.514999677737555,"iqm_confidence_interval":[-0.614104300737381,31.041234334309895],"agent":"REDQ"},{"env_step":325000.0,"rew":18.799122714996336,"rew_std":10.376715250176975,"iqm":16.759232838948567,"iqm_confidence_interval":[8.241253852844238,31.976130803426106],"agent":"REDQ"},{"env_step":330000.0,"rew":16.406962966918947,"rew_std":22.157508760557633,"iqm":16.556129455566406,"iqm_confidence_interval":[-8.651859919230143,41.26849365234375],"agent":"REDQ"},{"env_step":335000.0,"rew":12.169895958900451,"rew_std":14.963715049441639,"iqm":9.281240185101828,"iqm_confidence_interval":[-2.3518089850743613,31.649239857991535],"agent":"REDQ"},{"env_step":340000.0,"rew":3.44204
6642303467,"rew_std":14.780112286893091,"iqm":-2.1294188499450684,"iqm_confidence_interval":[-7.394394556681315,22.499412536621094],"agent":"REDQ"},{"env_step":345000.0,"rew":4.95629916191101,"rew_std":13.910492011418887,"iqm":4.2144068876902265,"iqm_confidence_interval":[-11.02927009264628,21.360683759053547],"agent":"REDQ"},{"env_step":350000.0,"rew":18.853069972991943,"rew_std":19.73532570992836,"iqm":13.482548395792643,"iqm_confidence_interval":[1.857444127400716,43.77008628845215],"agent":"REDQ"},{"env_step":355000.0,"rew":22.750345325469972,"rew_std":16.966662232915105,"iqm":18.741106351216633,"iqm_confidence_interval":[7.5960343678792315,44.02186838785807],"agent":"REDQ"},{"env_step":360000.0,"rew":5.956750690937042,"rew_std":21.695415764944062,"iqm":7.111966788768768,"iqm_confidence_interval":[-18.873626073201496,31.441300710042317],"agent":"REDQ"},{"env_step":365000.0,"rew":20.61870470046997,"rew_std":21.678673031794197,"iqm":13.428698062896729,"iqm_confidence_interval":[2.7947897911071777,49.16689236958822],"agent":"REDQ"},{"env_step":370000.0,"rew":14.885484218597412,"rew_std":17.805360659590868,"iqm":12.111168702443441,"iqm_confidence_interval":[-4.504808266957601,35.830423990885414],"agent":"REDQ"},{"env_step":375000.0,"rew":32.646864891052246,"rew_std":20.373741121242226,"iqm":26.919445673624676,"iqm_confidence_interval":[14.16562016805013,58.93025588989258],"agent":"REDQ"},{"env_step":380000.0,"rew":0.800808048248291,"rew_std":8.66086065871918,"iqm":-0.5694886843363444,"iqm_confidence_interval":[-8.144091129302979,11.844081242879232],"agent":"REDQ"},{"env_step":385000.0,"rew":5.095739841461182,"rew_std":10.521219934499408,"iqm":7.06280533472697,"iqm_confidence_interval":[-7.770551681518555,15.9800812403361],"agent":"REDQ"},{"env_step":390000.0,"rew":10.66844825744629,"rew_std":13.29951939275367,"iqm":12.87478764851888,"iqm_confidence_interval":[-5.858527183532715,23.706661860148113],"agent":"REDQ"},{"env_step":395000.0,"rew":16.440044784545897,"rew_st
d":11.947485845534818,"iqm":13.981200218200684,"iqm_confidence_interval":[5.077644983927409,31.113087336222332],"agent":"REDQ"},{"env_step":400000.0,"rew":-0.9765789985656739,"rew_std":15.48268801545026,"iqm":-0.12583525975545248,"iqm_confidence_interval":[-19.707523345947266,14.964693705240885],"agent":"REDQ"},{"env_step":405000.0,"rew":6.249713802337647,"rew_std":7.012356161983059,"iqm":8.156244277954102,"iqm_confidence_interval":[-2.619986057281494,12.68600082397461],"agent":"REDQ"},{"env_step":410000.0,"rew":-0.9776759624481202,"rew_std":14.570706019678896,"iqm":0.0393820603688558,"iqm_confidence_interval":[-18.631032307942707,14.512978871663412],"agent":"REDQ"},{"env_step":415000.0,"rew":9.446916484832764,"rew_std":9.888179038609788,"iqm":10.515723069508871,"iqm_confidence_interval":[-2.8868762652079263,18.93733787536621],"agent":"REDQ"},{"env_step":420000.0,"rew":-0.6988149166107178,"rew_std":16.74480844346608,"iqm":-1.5075164635976155,"iqm_confidence_interval":[-19.361233075459797,19.76446533203125],"agent":"REDQ"},{"env_step":425000.0,"rew":-10.437113916873932,"rew_std":9.06114434251262,"iqm":-8.526923179626465,"iqm_confidence_interval":[-21.928112665812176,-1.3592464526494343],"agent":"REDQ"},{"env_step":430000.0,"rew":12.44153401851654,"rew_std":21.58998283354288,"iqm":6.713498870531718,"iqm_confidence_interval":[-6.644435087839763,40.96184539794922],"agent":"REDQ"},{"env_step":435000.0,"rew":13.494499111175537,"rew_std":6.296845758071655,"iqm":12.34780216217041,"iqm_confidence_interval":[7.303154309590657,21.508914947509766],"agent":"REDQ"},{"env_step":440000.0,"rew":9.318463969230653,"rew_std":16.37547903187382,"iqm":7.349913080533345,"iqm_confidence_interval":[-6.832603057225545,30.141359329223633],"agent":"REDQ"},{"env_step":445000.0,"rew":14.813998603820801,"rew_std":27.582136294485988,"iqm":11.343953132629395,"iqm_confidence_interval":[-15.433674812316895,48.29724884033203],"agent":"REDQ"},{"env_step":450000.0,"rew":4.246308040618897,"rew_std":10.752
59828037478,"iqm":4.867857774098714,"iqm_confidence_interval":[-7.741820812225342,15.575287024180094],"agent":"REDQ"},{"env_step":455000.0,"rew":10.270092296600343,"rew_std":21.201130261835395,"iqm":11.415112018585205,"iqm_confidence_interval":[-13.926047801971436,33.083889961242676],"agent":"REDQ"},{"env_step":460000.0,"rew":16.007566070556642,"rew_std":22.588210333120543,"iqm":12.718087514241537,"iqm_confidence_interval":[-6.672385851542155,43.067083994547524],"agent":"REDQ"},{"env_step":465000.0,"rew":17.37092866897583,"rew_std":17.09418538925163,"iqm":18.513753096262615,"iqm_confidence_interval":[-3.585215409596761,36.4017219543457],"agent":"REDQ"},{"env_step":470000.0,"rew":7.109980523586273,"rew_std":16.376672204163395,"iqm":7.444703638553619,"iqm_confidence_interval":[-12.121398071448008,25.28855323791504],"agent":"REDQ"},{"env_step":475000.0,"rew":10.492794704437255,"rew_std":22.391037243119204,"iqm":16.986163934071858,"iqm_confidence_interval":[-19.13445170720418,29.666130701700848],"agent":"REDQ"},{"env_step":480000.0,"rew":-11.270056629180909,"rew_std":21.05769175036746,"iqm":-12.146312236785889,"iqm_confidence_interval":[-34.96834373474121,13.961121718088785],"agent":"REDQ"},{"env_step":485000.0,"rew":17.99180948138237,"rew_std":25.44042909642577,"iqm":11.619555821021399,"iqm_confidence_interval":[-6.0503911674022675,50.00814310709635],"agent":"REDQ"},{"env_step":490000.0,"rew":5.295922470092774,"rew_std":33.95100987171947,"iqm":9.257699966430664,"iqm_confidence_interval":[-37.17231877644857,41.746816635131836],"agent":"REDQ"},{"env_step":495000.0,"rew":3.446676015853882,"rew_std":7.801368735135198,"iqm":2.087717135747274,"iqm_confidence_interval":[-4.688169797261556,13.158220609029135],"agent":"REDQ"},{"env_step":500000.0,"rew":8.002798652648925,"rew_std":18.659321671851902,"iqm":12.137170473734537,"iqm_confidence_interval":[-14.34347407023112,25.695288976033527],"agent":"REDQ"},{"env_step":0.0,"rew":30.388103103637697,"rew_std":6.788395302817334,"iqm":
29.850717544555664,"iqm_confidence_interval":[23.30682309468587,38.79926681518555],"agent":"Reinforce"},{"env_step":30720.0,"rew":41.393571853637695,"rew_std":13.971328545803654,"iqm":42.14703114827474,"iqm_confidence_interval":[24.269193013509113,56.667057037353516],"agent":"Reinforce"},{"env_step":61440.0,"rew":26.356314849853515,"rew_std":12.084152781955408,"iqm":26.866626739501953,"iqm_confidence_interval":[12.766315460205078,40.5825449625651],"agent":"Reinforce"},{"env_step":92160.0,"rew":31.004388427734376,"rew_std":17.585687276059105,"iqm":26.937780062357586,"iqm_confidence_interval":[13.437909126281738,52.94202423095703],"agent":"Reinforce"},{"env_step":122880.0,"rew":36.31029825210571,"rew_std":18.476436394027456,"iqm":40.446022033691406,"iqm_confidence_interval":[13.678818066914877,55.07053248087565],"agent":"Reinforce"},{"env_step":153600.0,"rew":56.68689727783203,"rew_std":20.55872155871622,"iqm":60.506306966145836,"iqm_confidence_interval":[29.85924784342448,76.37279001871745],"agent":"Reinforce"},{"env_step":184320.0,"rew":40.220344161987306,"rew_std":13.267794175764193,"iqm":41.17553265889486,"iqm_confidence_interval":[23.531400680541992,54.020825703938804],"agent":"Reinforce"},{"env_step":215040.0,"rew":62.15134239196777,"rew_std":18.107928204603546,"iqm":68.45537567138672,"iqm_confidence_interval":[39.44554773966471,76.57495880126953],"agent":"Reinforce"},{"env_step":245760.0,"rew":58.648136138916016,"rew_std":22.385100878109142,"iqm":55.41466522216797,"iqm_confidence_interval":[36.151807149251304,86.53092956542969],"agent":"Reinforce"},{"env_step":276480.0,"rew":57.12719497680664,"rew_std":11.675106309035622,"iqm":55.473453521728516,"iqm_confidence_interval":[44.32411575317383,71.13336181640625],"agent":"Reinforce"},{"env_step":307200.0,"rew":59.3137393951416,"rew_std":18.880961069375342,"iqm":61.800671895345054,"iqm_confidence_interval":[35.953857421875,79.8323237101237],"agent":"Reinforce"},{"env_step":337920.0,"rew":75.0232650756836,"rew_std":28
.034409418640404,"iqm":67.36776860555013,"iqm_confidence_interval":[50.934304555257164,111.92232259114583],"agent":"Reinforce"},{"env_step":368640.0,"rew":70.71477661132812,"rew_std":7.937664767412498,"iqm":69.91966247558594,"iqm_confidence_interval":[61.88165791829427,80.43016560872395],"agent":"Reinforce"},{"env_step":399360.0,"rew":75.26746444702148,"rew_std":23.88399289851664,"iqm":75.4661038716634,"iqm_confidence_interval":[47.51989618937174,103.60155741373698],"agent":"Reinforce"},{"env_step":430080.0,"rew":112.8280014038086,"rew_std":39.87061054787958,"iqm":101.47241719563802,"iqm_confidence_interval":[79.63001251220703,165.8230438232422],"agent":"Reinforce"},{"env_step":460800.0,"rew":97.8234634399414,"rew_std":23.66546101387088,"iqm":100.41691080729167,"iqm_confidence_interval":[67.86837259928386,122.61161295572917],"agent":"Reinforce"},{"env_step":491520.0,"rew":89.31181411743164,"rew_std":43.67601321953432,"iqm":72.80683771769206,"iqm_confidence_interval":[56.229427337646484,143.82876586914062],"agent":"Reinforce"},{"env_step":522240.0,"rew":79.20274124145507,"rew_std":23.178516509016514,"iqm":74.65371195475261,"iqm_confidence_interval":[55.99585723876953,108.97505950927734],"agent":"Reinforce"},{"env_step":552960.0,"rew":85.39522552490234,"rew_std":12.139876558939596,"iqm":84.75838979085286,"iqm_confidence_interval":[72.48362986246745,100.51192982991536],"agent":"Reinforce"},{"env_step":583680.0,"rew":109.23733825683594,"rew_std":25.486743575543077,"iqm":109.1050033569336,"iqm_confidence_interval":[81.64951833089192,140.30201212565103],"agent":"Reinforce"},{"env_step":614400.0,"rew":138.59989166259766,"rew_std":33.6939212817492,"iqm":142.34720357259116,"iqm_confidence_interval":[97.67273966471355,172.8327433268229],"agent":"Reinforce"},{"env_step":645120.0,"rew":121.40247039794922,"rew_std":49.794980618645425,"iqm":122.70049794514973,"iqm_confidence_interval":[60.47559356689453,176.4300994873047],"agent":"Reinforce"},{"env_step":675840.0,"rew":115.319853
21044922,"rew_std":19.179720367158627,"iqm":123.2266337076823,"iqm_confidence_interval":[91.66427357991536,128.65331013997397],"agent":"Reinforce"},{"env_step":706560.0,"rew":109.73521270751954,"rew_std":32.05617777423248,"iqm":112.18231201171875,"iqm_confidence_interval":[70.27650451660156,139.39081319173178],"agent":"Reinforce"},{"env_step":737280.0,"rew":138.98448944091797,"rew_std":31.167571793249614,"iqm":125.43122863769531,"iqm_confidence_interval":[119.23805491129558,176.8132781982422],"agent":"Reinforce"},{"env_step":768000.0,"rew":160.27299957275392,"rew_std":31.830329501119625,"iqm":162.18966166178384,"iqm_confidence_interval":[125.64125569661458,197.5405527750651],"agent":"Reinforce"},{"env_step":798720.0,"rew":127.82812042236328,"rew_std":31.185511138850533,"iqm":133.3681157430013,"iqm_confidence_interval":[90.76575978597005,161.08210245768228],"agent":"Reinforce"},{"env_step":829440.0,"rew":131.0335693359375,"rew_std":54.88868142273811,"iqm":117.62583923339844,"iqm_confidence_interval":[82.56664021809895,199.98969523111978],"agent":"Reinforce"},{"env_step":860160.0,"rew":174.5028549194336,"rew_std":45.755765517387324,"iqm":182.59103393554688,"iqm_confidence_interval":[118.81187438964844,221.41290283203125],"agent":"Reinforce"},{"env_step":890880.0,"rew":159.05300750732422,"rew_std":29.948995732290975,"iqm":165.14197794596353,"iqm_confidence_interval":[120.02970886230469,187.74015299479166],"agent":"Reinforce"},{"env_step":921600.0,"rew":189.1221496582031,"rew_std":37.90998388924396,"iqm":190.80534871419272,"iqm_confidence_interval":[142.2677256266276,230.343994140625],"agent":"Reinforce"},{"env_step":952320.0,"rew":200.0733184814453,"rew_std":34.72404034142182,"iqm":198.57618204752603,"iqm_confidence_interval":[158.88312276204428,238.5328572591146],"agent":"Reinforce"},{"env_step":983040.0,"rew":184.41802368164062,"rew_std":41.505971796954185,"iqm":178.31390889485678,"iqm_confidence_interval":[140.6153818766276,234.4051971435547],"agent":"Reinforce"},{"
env_step":1013760.0,"rew":217.9223663330078,"rew_std":30.55818587396373,"iqm":212.4499053955078,"iqm_confidence_interval":[185.58199055989584,255.27669779459634],"agent":"Reinforce"},{"env_step":1044480.0,"rew":188.50809783935546,"rew_std":45.63991159019269,"iqm":192.9031728108724,"iqm_confidence_interval":[134.54771931966147,236.86406453450522],"agent":"Reinforce"},{"env_step":1075200.0,"rew":172.07819671630858,"rew_std":52.15507095930203,"iqm":170.8068389892578,"iqm_confidence_interval":[113.67265319824219,235.25662231445312],"agent":"Reinforce"},{"env_step":1105920.0,"rew":202.97726135253907,"rew_std":36.1947818092481,"iqm":192.60591634114584,"iqm_confidence_interval":[172.39493815104166,247.6568857828776],"agent":"Reinforce"},{"env_step":1136640.0,"rew":155.8261291503906,"rew_std":15.666026150139057,"iqm":157.44652303059897,"iqm_confidence_interval":[136.67408243815103,173.25584920247397],"agent":"Reinforce"},{"env_step":1167360.0,"rew":203.0422790527344,"rew_std":49.88510628853617,"iqm":203.64251200358072,"iqm_confidence_interval":[143.92074584960938,260.19286092122394],"agent":"Reinforce"},{"env_step":1198080.0,"rew":244.92487182617188,"rew_std":60.045671210825894,"iqm":231.50265502929688,"iqm_confidence_interval":[189.49774169921875,323.7748209635417],"agent":"Reinforce"},{"env_step":1228800.0,"rew":205.7066192626953,"rew_std":62.55029791054683,"iqm":226.20610555013022,"iqm_confidence_interval":[123.72357686360677,258.00376383463544],"agent":"Reinforce"},{"env_step":1259520.0,"rew":250.18730773925782,"rew_std":65.97905425786584,"iqm":240.34812927246094,"iqm_confidence_interval":[178.1267344156901,328.5118916829427],"agent":"Reinforce"},{"env_step":1290240.0,"rew":224.54561767578124,"rew_std":46.52547907892198,"iqm":231.78507486979166,"iqm_confidence_interval":[165.31780497233072,267.29746500651044],"agent":"Reinforce"},{"env_step":1320960.0,"rew":218.5148956298828,"rew_std":47.08101810365439,"iqm":221.06966654459634,"iqm_confidence_interval":[165.040522257486
97,273.56219482421875],"agent":"Reinforce"},{"env_step":1351680.0,"rew":187.0691879272461,"rew_std":40.25982166153624,"iqm":190.78805541992188,"iqm_confidence_interval":[141.08477274576822,233.19166056315103],"agent":"Reinforce"},{"env_step":1382400.0,"rew":258.0915588378906,"rew_std":52.01694964975601,"iqm":247.74063618977866,"iqm_confidence_interval":[207.2652333577474,323.55576578776044],"agent":"Reinforce"},{"env_step":1413120.0,"rew":215.4867919921875,"rew_std":58.539436951587135,"iqm":221.1207021077474,"iqm_confidence_interval":[151.37629191080728,282.22046915690106],"agent":"Reinforce"},{"env_step":1443840.0,"rew":216.3734924316406,"rew_std":38.32389133621041,"iqm":212.92101033528647,"iqm_confidence_interval":[176.09254455566406,261.8755645751953],"agent":"Reinforce"},{"env_step":1474560.0,"rew":264.408349609375,"rew_std":38.30968749271509,"iqm":262.39134216308594,"iqm_confidence_interval":[221.9061533610026,311.08123779296875],"agent":"Reinforce"},{"env_step":1505280.0,"rew":283.94684143066405,"rew_std":104.28181017938446,"iqm":277.3927815755208,"iqm_confidence_interval":[163.11287434895834,405.19097900390625],"agent":"Reinforce"},{"env_step":1536000.0,"rew":213.6966125488281,"rew_std":21.125488073494697,"iqm":221.94438680013022,"iqm_confidence_interval":[187.61536153157553,229.0794881184896],"agent":"Reinforce"},{"env_step":1566720.0,"rew":240.69333190917968,"rew_std":43.58871731853648,"iqm":239.96679178873697,"iqm_confidence_interval":[189.83861287434897,289.31268310546875],"agent":"Reinforce"},{"env_step":1597440.0,"rew":227.7518280029297,"rew_std":43.6807819928471,"iqm":226.01578776041666,"iqm_confidence_interval":[178.9113566080729,280.186767578125],"agent":"Reinforce"},{"env_step":1628160.0,"rew":254.28907470703126,"rew_std":58.96974427894262,"iqm":265.43126424153644,"iqm_confidence_interval":[183.42256673177084,316.2568359375],"agent":"Reinforce"},{"env_step":1658880.0,"rew":279.0290191650391,"rew_std":29.8961086169682,"iqm":284.4513448079427,"iqm_con
fidence_interval":[240.81451416015625,309.4889424641927],"agent":"Reinforce"},{"env_step":1689600.0,"rew":280.4690307617187,"rew_std":69.44635836416715,"iqm":291.83249918619794,"iqm_confidence_interval":[195.19434611002603,353.3846028645833],"agent":"Reinforce"},{"env_step":1720320.0,"rew":285.4685363769531,"rew_std":45.42759739091963,"iqm":292.0075988769531,"iqm_confidence_interval":[229.6476847330729,333.6572977701823],"agent":"Reinforce"},{"env_step":1751040.0,"rew":266.4340545654297,"rew_std":47.66348445637528,"iqm":280.7711690266927,"iqm_confidence_interval":[205.85944620768228,306.71190388997394],"agent":"Reinforce"},{"env_step":1781760.0,"rew":234.5447265625,"rew_std":36.225980542194165,"iqm":237.1103769938151,"iqm_confidence_interval":[191.05297342936197,275.2825419108073],"agent":"Reinforce"},{"env_step":1812480.0,"rew":307.52418212890626,"rew_std":84.81374513240468,"iqm":275.64890543619794,"iqm_confidence_interval":[246.65121459960938,410.00844319661456],"agent":"Reinforce"},{"env_step":1843200.0,"rew":266.19762268066404,"rew_std":62.69389135691518,"iqm":265.62952677408856,"iqm_confidence_interval":[190.63823445638022,336.2323404947917],"agent":"Reinforce"},{"env_step":1873920.0,"rew":291.81089782714844,"rew_std":50.755500210592444,"iqm":272.91077677408856,"iqm_confidence_interval":[254.44548543294272,355.9960530598958],"agent":"Reinforce"},{"env_step":1904640.0,"rew":259.18329467773435,"rew_std":48.49455400892606,"iqm":253.11312866210938,"iqm_confidence_interval":[205.61703491210938,317.95094807942706],"agent":"Reinforce"},{"env_step":1935360.0,"rew":307.09999389648436,"rew_std":43.827630316738166,"iqm":313.4484151204427,"iqm_confidence_interval":[254.4388631184896,352.08115641276044],"agent":"Reinforce"},{"env_step":1966080.0,"rew":281.89708557128904,"rew_std":45.02522260976583,"iqm":289.0003356933594,"iqm_confidence_interval":[224.81385294596353,328.21592203776044],"agent":"Reinforce"},{"env_step":1996800.0,"rew":309.51084899902344,"rew_std":78.08004038
475458,"iqm":303.74212646484375,"iqm_confidence_interval":[223.89665730794272,398.45416259765625],"agent":"Reinforce"},{"env_step":2027520.0,"rew":333.3232360839844,"rew_std":35.671012939999144,"iqm":334.38214111328125,"iqm_confidence_interval":[289.5745035807292,369.2607421875],"agent":"Reinforce"},{"env_step":2058240.0,"rew":271.2226257324219,"rew_std":14.60967271548954,"iqm":273.5323740641276,"iqm_confidence_interval":[252.78433736165366,283.69317626953125],"agent":"Reinforce"},{"env_step":2088960.0,"rew":314.7043701171875,"rew_std":32.77643435528697,"iqm":314.237060546875,"iqm_confidence_interval":[276.87865193684894,353.4198506673177],"agent":"Reinforce"},{"env_step":2119680.0,"rew":290.65659790039064,"rew_std":50.538724138399026,"iqm":286.47316487630206,"iqm_confidence_interval":[235.33682250976562,353.5723063151042],"agent":"Reinforce"},{"env_step":2150400.0,"rew":324.43235473632814,"rew_std":34.699153547407164,"iqm":327.8403015136719,"iqm_confidence_interval":[281.2065124511719,359.95556640625],"agent":"Reinforce"},{"env_step":2181120.0,"rew":334.29482116699216,"rew_std":61.61826359031619,"iqm":343.5783386230469,"iqm_confidence_interval":[261.01324462890625,395.5978291829427],"agent":"Reinforce"},{"env_step":2211840.0,"rew":301.3869171142578,"rew_std":54.98068490657697,"iqm":296.2631429036458,"iqm_confidence_interval":[242.42849731445312,369.9410705566406],"agent":"Reinforce"},{"env_step":2242560.0,"rew":338.5574951171875,"rew_std":51.59526706535765,"iqm":328.11199951171875,"iqm_confidence_interval":[286.99342854817706,404.38866170247394],"agent":"Reinforce"},{"env_step":2273280.0,"rew":301.586181640625,"rew_std":16.842084248418107,"iqm":305.23716227213544,"iqm_confidence_interval":[281.1344909667969,318.3143310546875],"agent":"Reinforce"},{"env_step":2304000.0,"rew":323.445556640625,"rew_std":61.267217037347834,"iqm":332.5047098795573,"iqm_confidence_interval":[245.55340576171875,387.0192565917969],"agent":"Reinforce"},{"env_step":2334720.0,"rew":320.565777
58789064,"rew_std":35.492188799324026,"iqm":309.4940490722656,"iqm_confidence_interval":[289.9680887858073,367.02295939127606],"agent":"Reinforce"},{"env_step":2365440.0,"rew":335.6709899902344,"rew_std":66.80572709335965,"iqm":337.3677469889323,"iqm_confidence_interval":[255.53173828125,412.7243143717448],"agent":"Reinforce"},{"env_step":2396160.0,"rew":329.3257080078125,"rew_std":53.664796517586076,"iqm":331.4996032714844,"iqm_confidence_interval":[266.3409016927083,391.79189046223956],"agent":"Reinforce"},{"env_step":2426880.0,"rew":329.5802276611328,"rew_std":84.17181155441796,"iqm":352.0137125651042,"iqm_confidence_interval":[223.3980712890625,403.7108459472656],"agent":"Reinforce"},{"env_step":2457600.0,"rew":318.6984924316406,"rew_std":74.45687540237842,"iqm":315.1752421061198,"iqm_confidence_interval":[234.7754109700521,405.18434651692706],"agent":"Reinforce"},{"env_step":2488320.0,"rew":313.40746154785154,"rew_std":73.8331593572831,"iqm":319.9508361816406,"iqm_confidence_interval":[222.83462524414062,395.68932088216144],"agent":"Reinforce"},{"env_step":2519040.0,"rew":335.41685180664064,"rew_std":47.44474724052373,"iqm":317.7933044433594,"iqm_confidence_interval":[300.66258748372394,397.95128377278644],"agent":"Reinforce"},{"env_step":2549760.0,"rew":292.2950775146484,"rew_std":50.12930012184995,"iqm":307.0918782552083,"iqm_confidence_interval":[226.51502482096353,335.8219909667969],"agent":"Reinforce"},{"env_step":2580480.0,"rew":378.39520874023435,"rew_std":36.30654114289557,"iqm":374.59874471028644,"iqm_confidence_interval":[339.26439412434894,423.85561116536456],"agent":"Reinforce"},{"env_step":2611200.0,"rew":339.3572692871094,"rew_std":52.24209828973846,"iqm":342.9746602376302,"iqm_confidence_interval":[280.0288798014323,393.28118896484375],"agent":"Reinforce"},{"env_step":2641920.0,"rew":311.60890808105466,"rew_std":88.25142354618134,"iqm":342.87502034505206,"iqm_confidence_interval":[201.42610677083334,378.2011006673177],"agent":"Reinforce"},{"env_s
tep":2672640.0,"rew":353.88108215332034,"rew_std":82.5540304412778,"iqm":361.1837870279948,"iqm_confidence_interval":[252.4522501627604,432.9479064941406],"agent":"Reinforce"},{"env_step":2703360.0,"rew":331.2927978515625,"rew_std":33.18178391165257,"iqm":332.3551737467448,"iqm_confidence_interval":[291.2009684244792,363.9695332845052],"agent":"Reinforce"},{"env_step":2734080.0,"rew":372.3266265869141,"rew_std":78.96542655118607,"iqm":378.66721598307294,"iqm_confidence_interval":[274.5861307779948,457.87548828125],"agent":"Reinforce"},{"env_step":2764800.0,"rew":325.185791015625,"rew_std":50.082658331132286,"iqm":322.7134602864583,"iqm_confidence_interval":[266.5505676269531,381.635498046875],"agent":"Reinforce"},{"env_step":2795520.0,"rew":274.6520294189453,"rew_std":49.78978121357848,"iqm":271.5624643961589,"iqm_confidence_interval":[220.35921732584634,334.1891581217448],"agent":"Reinforce"},{"env_step":2826240.0,"rew":322.1300994873047,"rew_std":89.95667783686396,"iqm":298.76662190755206,"iqm_confidence_interval":[235.33484903971353,432.0729471842448],"agent":"Reinforce"},{"env_step":2856960.0,"rew":340.1513916015625,"rew_std":37.89333285730595,"iqm":345.91286214192706,"iqm_confidence_interval":[296.1921793619792,381.60411580403644],"agent":"Reinforce"},{"env_step":2887680.0,"rew":366.31533203125,"rew_std":55.248782096101024,"iqm":362.03123982747394,"iqm_confidence_interval":[303.2717692057292,431.9060872395833],"agent":"Reinforce"},{"env_step":2918400.0,"rew":372.2232238769531,"rew_std":84.5105446284943,"iqm":377.65968831380206,"iqm_confidence_interval":[268.6006571451823,463.58217366536456],"agent":"Reinforce"},{"env_step":2949120.0,"rew":309.97247314453125,"rew_std":41.89657001617316,"iqm":311.08339436848956,"iqm_confidence_interval":[259.52565511067706,356.1070963541667],"agent":"Reinforce"},{"env_step":2979840.0,"rew":353.62855224609376,"rew_std":44.15351430823454,"iqm":354.0978698730469,"iqm_confidence_interval":[301.82387288411456,405.1334737141927],"agent
":"Reinforce"},{"env_step":3010560.0,"rew":297.6916809082031,"rew_std":84.90951670827089,"iqm":291.50318400065106,"iqm_confidence_interval":[210.40422566731772,394.10882568359375],"agent":"Reinforce"},{"env_step":3041280.0,"rew":301.43951721191405,"rew_std":59.52449076004893,"iqm":282.8259633382161,"iqm_confidence_interval":[251.1397959391276,379.67467244466144],"agent":"Reinforce"},{"env_step":3072000.0,"rew":343.80126953125,"rew_std":65.90217959384564,"iqm":336.1062316894531,"iqm_confidence_interval":[274.6854248046875,427.3608703613281],"agent":"Reinforce"},{"env_step":0.0,"rew":-77.19025497436523,"rew_std":62.2254596836894,"iqm":-54.83956527709961,"iqm_confidence_interval":[-152.90450032552084,-29.95730209350586],"agent":"SAC"},{"env_step":5000.0,"rew":-35.89192981719971,"rew_std":17.41051894400482,"iqm":-36.26547304789225,"iqm_confidence_interval":[-55.806538899739586,-15.982662200927734],"agent":"SAC"},{"env_step":10000.0,"rew":-8.365598672628403,"rew_std":8.014129819343015,"iqm":-7.29702623685201,"iqm_confidence_interval":[-18.26535193125407,0.24046256144841513],"agent":"SAC"},{"env_step":15000.0,"rew":-20.333968925476075,"rew_std":31.329479140643837,"iqm":-21.02641010284424,"iqm_confidence_interval":[-57.44995625813802,14.767145792643229],"agent":"SAC"},{"env_step":20000.0,"rew":17.028606390953065,"rew_std":37.66061999840288,"iqm":8.315877874692282,"iqm_confidence_interval":[-17.764819820721943,63.04864088694254],"agent":"SAC"},{"env_step":25000.0,"rew":37.705559158325194,"rew_std":25.836473099852302,"iqm":32.60944366455078,"iqm_confidence_interval":[13.486283620198568,68.4135856628418],"agent":"SAC"},{"env_step":30000.0,"rew":75.73391571044922,"rew_std":22.290208907489014,"iqm":74.22437032063802,"iqm_confidence_interval":[52.38097127278646,102.75186920166016],"agent":"SAC"},{"env_step":35000.0,"rew":84.19118194580078,"rew_std":17.008699099731583,"iqm":85.2749532063802,"iqm_confidence_interval":[62.907440185546875,102.23167928059895],"agent":"SAC"},{"env_ste
p":40000.0,"rew":129.09686126708985,"rew_std":22.589794187279907,"iqm":129.61712137858072,"iqm_confidence_interval":[101.76744079589844,154.85226440429688],"agent":"SAC"},{"env_step":45000.0,"rew":152.65745239257814,"rew_std":48.46755208574535,"iqm":155.56519826253256,"iqm_confidence_interval":[93.70250701904297,206.66668192545572],"agent":"SAC"},{"env_step":50000.0,"rew":155.83685607910155,"rew_std":31.52167141078223,"iqm":143.8017832438151,"iqm_confidence_interval":[132.99841817220053,195.23765563964844],"agent":"SAC"},{"env_step":55000.0,"rew":182.55461730957032,"rew_std":43.99510158347367,"iqm":188.25503540039062,"iqm_confidence_interval":[126.17987569173177,224.71722412109375],"agent":"SAC"},{"env_step":60000.0,"rew":176.60393676757812,"rew_std":43.76801045676628,"iqm":160.3560536702474,"iqm_confidence_interval":[142.53045145670572,231.58539835611978],"agent":"SAC"},{"env_step":65000.0,"rew":199.31842346191405,"rew_std":58.64699520848267,"iqm":183.22557576497397,"iqm_confidence_interval":[147.38675435384116,272.1580505371094],"agent":"SAC"},{"env_step":70000.0,"rew":219.99323272705078,"rew_std":85.8307555926905,"iqm":203.48806762695312,"iqm_confidence_interval":[130.2459971110026,322.3173828125],"agent":"SAC"},{"env_step":75000.0,"rew":188.92798461914063,"rew_std":35.46983211589195,"iqm":192.1189727783203,"iqm_confidence_interval":[146.61749267578125,228.26592508951822],"agent":"SAC"},{"env_step":80000.0,"rew":286.16457214355466,"rew_std":85.8509054317401,"iqm":306.4248555501302,"iqm_confidence_interval":[183.60598754882812,372.40846761067706],"agent":"SAC"},{"env_step":85000.0,"rew":284.43455505371094,"rew_std":96.45630722641503,"iqm":294.31304423014325,"iqm_confidence_interval":[165.90079243977866,385.88283284505206],"agent":"SAC"},{"env_step":90000.0,"rew":250.1466552734375,"rew_std":42.79532797670044,"iqm":254.65687052408853,"iqm_confidence_interval":[200.3510538736979,295.53907267252606],"agent":"SAC"},{"env_step":95000.0,"rew":252.18509826660156,"rew_std"
:35.802201568385726,"iqm":250.29408264160156,"iqm_confidence_interval":[210.97714233398438,294.58123779296875],"agent":"SAC"},{"env_step":100000.0,"rew":286.2181793212891,"rew_std":84.1511572584835,"iqm":300.5720520019531,"iqm_confidence_interval":[181.56843058268228,368.3089294433594],"agent":"SAC"},{"env_step":105000.0,"rew":264.8419158935547,"rew_std":35.2395765966379,"iqm":275.5354715983073,"iqm_confidence_interval":[222.22441609700522,293.87743123372394],"agent":"SAC"},{"env_step":110000.0,"rew":513.6656555175781,"rew_std":218.65707318919647,"iqm":428.05491129557294,"iqm_confidence_interval":[357.8131408691406,797.7045084635416],"agent":"SAC"},{"env_step":115000.0,"rew":497.7557800292969,"rew_std":366.65431431565366,"iqm":323.65578206380206,"iqm_confidence_interval":[290.1732482910156,945.1743876139323],"agent":"SAC"},{"env_step":120000.0,"rew":476.4174530029297,"rew_std":283.98567750072647,"iqm":374.6880594889323,"iqm_confidence_interval":[262.7366231282552,837.0888977050781],"agent":"SAC"},{"env_step":125000.0,"rew":453.8511901855469,"rew_std":223.38435370933277,"iqm":371.7033182779948,"iqm_confidence_interval":[277.9562479654948,729.7964579264323],"agent":"SAC"},{"env_step":130000.0,"rew":548.385791015625,"rew_std":283.5735249021172,"iqm":504.7932637532552,"iqm_confidence_interval":[261.8063456217448,908.9465535481771],"agent":"SAC"},{"env_step":135000.0,"rew":561.4510925292968,"rew_std":308.97892296397185,"iqm":446.41258748372394,"iqm_confidence_interval":[319.62439982096356,958.6476236979166],"agent":"SAC"},{"env_step":140000.0,"rew":454.39158020019534,"rew_std":285.15522952077515,"iqm":392.1344451904297,"iqm_confidence_interval":[168.9661102294922,801.3722330729166],"agent":"SAC"},{"env_step":145000.0,"rew":604.955941772461,"rew_std":565.4580665055674,"iqm":358.6328938802083,"iqm_confidence_interval":[242.29569498697916,1315.4267679850261],"agent":"SAC"},{"env_step":150000.0,"rew":579.629736328125,"rew_std":174.94201468630686,"iqm":614.0927429199219,"iqm_
confidence_interval":[350.2121276855469,741.2681070963541],"agent":"SAC"},{"env_step":155000.0,"rew":801.4467224121094,"rew_std":331.04880898309625,"iqm":846.8262329101562,"iqm_confidence_interval":[391.0495198567708,1158.1105550130208],"agent":"SAC"},{"env_step":160000.0,"rew":779.7697937011719,"rew_std":433.51506755083886,"iqm":664.1381429036459,"iqm_confidence_interval":[394.3942057291667,1309.4696451822917],"agent":"SAC"},{"env_step":165000.0,"rew":1033.8955139160157,"rew_std":481.05420825123053,"iqm":1123.970926920573,"iqm_confidence_interval":[414.709716796875,1445.7089029947917],"agent":"SAC"},{"env_step":170000.0,"rew":800.0918518066406,"rew_std":329.7873973303998,"iqm":779.8632609049479,"iqm_confidence_interval":[441.96872965494794,1199.1294759114583],"agent":"SAC"},{"env_step":175000.0,"rew":915.3458435058594,"rew_std":477.158514471823,"iqm":856.32373046875,"iqm_confidence_interval":[407.51715087890625,1494.6025390625],"agent":"SAC"},{"env_step":180000.0,"rew":704.3885070800782,"rew_std":199.70826616083713,"iqm":752.0989583333334,"iqm_confidence_interval":[440.9189860026042,879.9899088541666],"agent":"SAC"},{"env_step":185000.0,"rew":686.2416809082031,"rew_std":325.6238141858492,"iqm":636.9073893229166,"iqm_confidence_interval":[371.8382568359375,1070.7754516601562],"agent":"SAC"},{"env_step":190000.0,"rew":1294.975811767578,"rew_std":770.3208203642287,"iqm":1214.5337524414062,"iqm_confidence_interval":[454.859130859375,2224.272705078125],"agent":"SAC"},{"env_step":195000.0,"rew":1870.000341796875,"rew_std":1260.2279022802952,"iqm":1780.2755533854167,"iqm_confidence_interval":[536.3902587890625,3447.1735026041665],"agent":"SAC"},{"env_step":200000.0,"rew":1622.0802001953125,"rew_std":623.6543352847739,"iqm":1583.7178141276042,"iqm_confidence_interval":[951.5055745442709,2395.7915852864585],"agent":"SAC"},{"env_step":205000.0,"rew":1854.9751220703124,"rew_std":1056.689688092062,"iqm":1635.9498291015625,"iqm_confidence_interval":[887.18212890625,3234.2255859
375],"agent":"SAC"},{"env_step":210000.0,"rew":1839.5556640625,"rew_std":955.7378555572526,"iqm":1730.114013671875,"iqm_confidence_interval":[848.3274332682291,3050.9403483072915],"agent":"SAC"},{"env_step":215000.0,"rew":1429.5697509765625,"rew_std":648.9423857356364,"iqm":1418.4546508789062,"iqm_confidence_interval":[736.1454060872396,2217.648681640625],"agent":"SAC"},{"env_step":220000.0,"rew":1728.10546875,"rew_std":925.2835955326163,"iqm":1642.0914713541667,"iqm_confidence_interval":[722.1980387369791,2882.3997395833335],"agent":"SAC"},{"env_step":225000.0,"rew":1997.5099365234375,"rew_std":1183.7667825669266,"iqm":1655.85205078125,"iqm_confidence_interval":[985.1422119140625,3557.1197102864585],"agent":"SAC"},{"env_step":230000.0,"rew":2152.180712890625,"rew_std":840.7990009143479,"iqm":1817.1878662109375,"iqm_confidence_interval":[1581.4125162760417,3174.1283365885415],"agent":"SAC"},{"env_step":235000.0,"rew":2194.316857910156,"rew_std":673.0463746958201,"iqm":2331.4443359375,"iqm_confidence_interval":[1369.1734619140625,2877.5187174479165],"agent":"SAC"},{"env_step":240000.0,"rew":2253.086572265625,"rew_std":1177.115510302391,"iqm":2028.7047932942708,"iqm_confidence_interval":[1036.8601481119792,3703.5294596354165],"agent":"SAC"},{"env_step":245000.0,"rew":2316.2286376953125,"rew_std":1117.8252019819483,"iqm":2303.7169596354165,"iqm_confidence_interval":[1055.2244466145833,3587.3098958333335],"agent":"SAC"},{"env_step":250000.0,"rew":2549.455847167969,"rew_std":1084.187701336248,"iqm":2822.6012369791665,"iqm_confidence_interval":[1117.0762125651042,3513.427001953125],"agent":"SAC"},{"env_step":255000.0,"rew":2546.011181640625,"rew_std":862.4939491058706,"iqm":2500.5267333984375,"iqm_confidence_interval":[1539.4224039713542,3553.4488932291665],"agent":"SAC"},{"env_step":260000.0,"rew":2777.9132080078125,"rew_std":813.6251453201878,"iqm":2819.093017578125,"iqm_confidence_interval":[1862.713623046875,3660.3916829427085],"agent":"SAC"},{"env_step":265000.0,"rew
":2675.9651123046874,"rew_std":1029.449430698658,"iqm":2811.8418782552085,"iqm_confidence_interval":[1408.0743815104167,3794.1632486979165],"agent":"SAC"},{"env_step":270000.0,"rew":2514.9052490234376,"rew_std":1124.8782698523578,"iqm":2861.392333984375,"iqm_confidence_interval":[1048.2747395833333,3495.075439453125],"agent":"SAC"},{"env_step":275000.0,"rew":3179.036376953125,"rew_std":607.7492449757316,"iqm":3177.6503092447915,"iqm_confidence_interval":[2465.8914388020835,3866.8968098958335],"agent":"SAC"},{"env_step":280000.0,"rew":2910.2163330078124,"rew_std":1210.8612926093572,"iqm":2791.435302734375,"iqm_confidence_interval":[1585.8204752604167,4416.846435546875],"agent":"SAC"},{"env_step":285000.0,"rew":3425.8759033203123,"rew_std":976.7204258971652,"iqm":3409.6282552083335,"iqm_confidence_interval":[2300.1615397135415,4559.906412760417],"agent":"SAC"},{"env_step":290000.0,"rew":3364.4786376953125,"rew_std":1288.5817990288117,"iqm":3554.5956217447915,"iqm_confidence_interval":[1742.9095052083333,4730.11376953125],"agent":"SAC"},{"env_step":295000.0,"rew":3307.5587890625,"rew_std":822.0034940166391,"iqm":3246.4475911458335,"iqm_confidence_interval":[2356.705810546875,4260.758626302083],"agent":"SAC"},{"env_step":300000.0,"rew":3306.75478515625,"rew_std":806.6955844456846,"iqm":3014.4081217447915,"iqm_confidence_interval":[2705.4403483072915,4350.734456380208],"agent":"SAC"},{"env_step":305000.0,"rew":3625.849853515625,"rew_std":980.7105282519498,"iqm":3727.7289225260415,"iqm_confidence_interval":[2494.114013671875,4673.794596354167],"agent":"SAC"},{"env_step":310000.0,"rew":3425.81767578125,"rew_std":610.7586462895265,"iqm":3526.7610677083335,"iqm_confidence_interval":[2645.7422688802085,4039.8333333333335],"agent":"SAC"},{"env_step":315000.0,"rew":4136.97744140625,"rew_std":674.0098307295987,"iqm":4147.991780598958,"iqm_confidence_interval":[3440.5952962239585,4941.0205078125],"agent":"SAC"},{"env_step":320000.0,"rew":3651.937939453125,"rew_std":619.3627563908
244,"iqm":3881.1543782552085,"iqm_confidence_interval":[2836.1516927083335,4108.755126953125],"agent":"SAC"},{"env_step":325000.0,"rew":3813.688037109375,"rew_std":619.8694270427005,"iqm":3843.1497395833335,"iqm_confidence_interval":[3050.9657389322915,4457.54150390625],"agent":"SAC"},{"env_step":330000.0,"rew":4016.08896484375,"rew_std":877.8359813093385,"iqm":4268.130859375,"iqm_confidence_interval":[2952.9524739583335,4768.163736979167],"agent":"SAC"},{"env_step":335000.0,"rew":3351.1053955078123,"rew_std":1340.7030859505958,"iqm":3409.1995849609375,"iqm_confidence_interval":[1728.4311930338542,4773.09912109375],"agent":"SAC"},{"env_step":340000.0,"rew":4029.405517578125,"rew_std":694.3555798668357,"iqm":4012.6536458333335,"iqm_confidence_interval":[3209.7360026041665,4837.2890625],"agent":"SAC"},{"env_step":345000.0,"rew":4357.635205078125,"rew_std":243.9905210451901,"iqm":4462.39453125,"iqm_confidence_interval":[4045.9969075520835,4517.338541666667],"agent":"SAC"},{"env_step":350000.0,"rew":3720.547509765625,"rew_std":718.0328362710919,"iqm":3745.6464029947915,"iqm_confidence_interval":[2883.9214680989585,4566.412434895833],"agent":"SAC"},{"env_step":355000.0,"rew":3888.85146484375,"rew_std":922.4671085394701,"iqm":3855.592529296875,"iqm_confidence_interval":[2889.3636067708335,4916.275065104167],"agent":"SAC"},{"env_step":360000.0,"rew":4492.90849609375,"rew_std":427.0038617487227,"iqm":4377.559244791667,"iqm_confidence_interval":[4113.779296875,5044.82470703125],"agent":"SAC"},{"env_step":365000.0,"rew":3824.1125,"rew_std":466.36924421733886,"iqm":3795.31396484375,"iqm_confidence_interval":[3277.9092610677085,4345.49267578125],"agent":"SAC"},{"env_step":370000.0,"rew":3793.528271484375,"rew_std":423.99898561883765,"iqm":3723.669677734375,"iqm_confidence_interval":[3378.8328450520835,4334.930257161458],"agent":"SAC"},{"env_step":375000.0,"rew":4652.83203125,"rew_std":312.9602999232348,"iqm":4602.4091796875,"iqm_confidence_interval":[4353.67041015625,5055.28694
6614583],"agent":"SAC"},{"env_step":380000.0,"rew":4107.30537109375,"rew_std":775.9508890311822,"iqm":4418.271809895833,"iqm_confidence_interval":[3123.5266927083335,4674.585774739583],"agent":"SAC"},{"env_step":385000.0,"rew":4406.541162109375,"rew_std":367.5362551001566,"iqm":4326.8076171875,"iqm_confidence_interval":[4060.4236653645835,4876.874674479167],"agent":"SAC"},{"env_step":390000.0,"rew":4134.431591796875,"rew_std":558.0600379590566,"iqm":4135.73583984375,"iqm_confidence_interval":[3465.0011393229165,4748.268391927083],"agent":"SAC"},{"env_step":395000.0,"rew":4550.18330078125,"rew_std":226.7923284440831,"iqm":4571.38525390625,"iqm_confidence_interval":[4289.759602864583,4809.229329427083],"agent":"SAC"},{"env_step":400000.0,"rew":3855.42265625,"rew_std":591.9787384668502,"iqm":3861.1411946614585,"iqm_confidence_interval":[3179.0397135416665,4565.104817708333],"agent":"SAC"},{"env_step":405000.0,"rew":4034.952392578125,"rew_std":541.1052846818027,"iqm":3977.388427734375,"iqm_confidence_interval":[3428.2539876302085,4668.5390625],"agent":"SAC"},{"env_step":410000.0,"rew":4127.98525390625,"rew_std":706.011830856558,"iqm":4253.806477864583,"iqm_confidence_interval":[3245.044921875,4861.324381510417],"agent":"SAC"},{"env_step":415000.0,"rew":4885.615625,"rew_std":229.95295192904064,"iqm":4960.629069010417,"iqm_confidence_interval":[4596.09375,5081.218098958333],"agent":"SAC"},{"env_step":420000.0,"rew":4803.430078125,"rew_std":513.4374209524454,"iqm":4912.762369791667,"iqm_confidence_interval":[4155.508138020833,5321.6865234375],"agent":"SAC"},{"env_step":425000.0,"rew":4393.957666015625,"rew_std":427.68871366006164,"iqm":4471.551432291667,"iqm_confidence_interval":[3847.70751953125,4829.618489583333],"agent":"SAC"},{"env_step":430000.0,"rew":3938.97265625,"rew_std":863.6299788862947,"iqm":3929.3496907552085,"iqm_confidence_interval":[2924.7867024739585,4926.1904296875],"agent":"SAC"},{"env_step":435000.0,"rew":4272.44384765625,"rew_std":997.5040725809271,"iq
m":4437.238199869792,"iqm_confidence_interval":[2987.696044921875,5159.943522135417],"agent":"SAC"},{"env_step":440000.0,"rew":4870.52685546875,"rew_std":395.9247725369098,"iqm":4934.081380208333,"iqm_confidence_interval":[4360.022623697917,5268.646647135417],"agent":"SAC"},{"env_step":445000.0,"rew":4458.2224609375,"rew_std":696.4949334455875,"iqm":4540.100667317708,"iqm_confidence_interval":[3592.3868001302085,5141.94677734375],"agent":"SAC"},{"env_step":450000.0,"rew":4328.62060546875,"rew_std":492.0183738827475,"iqm":4298.752360026042,"iqm_confidence_interval":[3862.9173990885415,4926.407877604167],"agent":"SAC"},{"env_step":455000.0,"rew":4464.820166015625,"rew_std":971.6304504932598,"iqm":4835.134114583333,"iqm_confidence_interval":[3282.0231119791665,5189.101888020833],"agent":"SAC"},{"env_step":460000.0,"rew":4928.2134765625,"rew_std":171.17832240165666,"iqm":4899.203287760417,"iqm_confidence_interval":[4749.197591145833,5139.430501302083],"agent":"SAC"},{"env_step":465000.0,"rew":4820.03369140625,"rew_std":660.0672310678303,"iqm":5072.874348958333,"iqm_confidence_interval":[3981.7281901041665,5301.434407552083],"agent":"SAC"},{"env_step":470000.0,"rew":4975.690625,"rew_std":416.3427877328529,"iqm":5020.641764322917,"iqm_confidence_interval":[4453.964680989583,5423.12646484375],"agent":"SAC"},{"env_step":475000.0,"rew":4900.019287109375,"rew_std":549.8527721873079,"iqm":5025.114908854167,"iqm_confidence_interval":[4240.366373697917,5388.130208333333],"agent":"SAC"},{"env_step":480000.0,"rew":4733.905322265625,"rew_std":876.2419560665137,"iqm":4990.281087239583,"iqm_confidence_interval":[3567.96240234375,5448.71728515625],"agent":"SAC"},{"env_step":485000.0,"rew":4593.1205078125,"rew_std":480.91717325260674,"iqm":4764.295084635417,"iqm_confidence_interval":[4002.5148111979165,4965.733723958333],"agent":"SAC"},{"env_step":490000.0,"rew":4962.76171875,"rew_std":297.40314389316654,"iqm":5018.493977864583,"iqm_confidence_interval":[4597.000325520833,5266.61735026
0417],"agent":"SAC"},{"env_step":495000.0,"rew":4791.8896484375,"rew_std":331.4742715277328,"iqm":4652.039388020833,"iqm_confidence_interval":[4566.775227864583,5213.990071614583],"agent":"SAC"},{"env_step":500000.0,"rew":4950.46767578125,"rew_std":438.3517698707152,"iqm":4896.762532552083,"iqm_confidence_interval":[4529.126790364583,5503.750813802083],"agent":"SAC"},{"env_step":0.0,"rew":829.9240783691406,"rew_std":213.16720514662308,"iqm":924.9030558268229,"iqm_confidence_interval":[572.5575561523438,960.5938924153646],"agent":"TD3"},{"env_step":5000.0,"rew":539.5705200195313,"rew_std":114.70532131045941,"iqm":527.8493448893229,"iqm_confidence_interval":[419.621826171875,684.4314371744791],"agent":"TD3"},{"env_step":10000.0,"rew":756.800830078125,"rew_std":51.67615449175576,"iqm":761.0208943684896,"iqm_confidence_interval":[694.5470377604166,815.1962687174479],"agent":"TD3"},{"env_step":15000.0,"rew":735.1518920898437,"rew_std":49.79578714039183,"iqm":729.9076741536459,"iqm_confidence_interval":[681.5853881835938,797.6240641276041],"agent":"TD3"},{"env_step":20000.0,"rew":566.7820190429687,"rew_std":128.86571468965482,"iqm":559.4737752278646,"iqm_confidence_interval":[415.3776041666667,710.1801961263021],"agent":"TD3"},{"env_step":25000.0,"rew":546.2086364746094,"rew_std":114.56158906049062,"iqm":598.8281453450521,"iqm_confidence_interval":[404.30999755859375,615.4661051432291],"agent":"TD3"},{"env_step":30000.0,"rew":584.0201599121094,"rew_std":69.366430886614,"iqm":573.4462280273438,"iqm_confidence_interval":[512.0141398111979,671.46533203125],"agent":"TD3"},{"env_step":35000.0,"rew":569.2431213378907,"rew_std":150.3781384089316,"iqm":571.3526204427084,"iqm_confidence_interval":[402.2006429036458,750.3717244466146],"agent":"TD3"},{"env_step":40000.0,"rew":539.8349243164063,"rew_std":133.25074037366434,"iqm":563.3495279947916,"iqm_confidence_interval":[379.60784912109375,677.0299275716146],"agent":"TD3"},{"env_step":45000.0,"rew":592.5176391601562,"rew_std":179.4
9365470179063,"iqm":607.0400797526041,"iqm_confidence_interval":[377.11484781901044,785.9901326497396],"agent":"TD3"},{"env_step":50000.0,"rew":650.7749572753906,"rew_std":186.60011518107405,"iqm":687.1291097005209,"iqm_confidence_interval":[417.8120930989583,843.447265625],"agent":"TD3"},{"env_step":55000.0,"rew":702.7271850585937,"rew_std":233.75226525935827,"iqm":699.2322184244791,"iqm_confidence_interval":[425.00307210286456,949.2679036458334],"agent":"TD3"},{"env_step":60000.0,"rew":770.480322265625,"rew_std":105.0838953860834,"iqm":786.7170613606771,"iqm_confidence_interval":[647.1668497721354,884.4464314778646],"agent":"TD3"},{"env_step":65000.0,"rew":767.2954956054688,"rew_std":185.80310917721403,"iqm":825.8852335611979,"iqm_confidence_interval":[521.1190795898438,921.1991373697916],"agent":"TD3"},{"env_step":70000.0,"rew":881.9124389648438,"rew_std":124.16713120431044,"iqm":892.2832438151041,"iqm_confidence_interval":[726.3288167317709,1013.8692016601562],"agent":"TD3"},{"env_step":75000.0,"rew":809.5596252441406,"rew_std":210.11544260847276,"iqm":840.7815755208334,"iqm_confidence_interval":[538.25390625,1003.7309773763021],"agent":"TD3"},{"env_step":80000.0,"rew":914.9486206054687,"rew_std":185.53253372227354,"iqm":866.2797037760416,"iqm_confidence_interval":[749.0588175455729,1146.3099772135417],"agent":"TD3"},{"env_step":85000.0,"rew":1038.5048583984376,"rew_std":354.55029263760997,"iqm":991.1633097330729,"iqm_confidence_interval":[685.3418986002604,1467.1781005859375],"agent":"TD3"},{"env_step":90000.0,"rew":1095.1505615234375,"rew_std":294.7722406276977,"iqm":1029.7857259114583,"iqm_confidence_interval":[825.5033569335938,1483.0994873046875],"agent":"TD3"},{"env_step":95000.0,"rew":926.5935485839843,"rew_std":358.78989947098245,"iqm":1008.4535522460938,"iqm_confidence_interval":[481.00408935546875,1286.0092366536458],"agent":"TD3"},{"env_step":100000.0,"rew":1249.812744140625,"rew_std":468.10361371058315,"iqm":1109.7845255533855,"iqm_confidence_interva
l":[836.6089070638021,1829.5582682291667],"agent":"TD3"},{"env_step":105000.0,"rew":1086.3578125,"rew_std":289.683933445901,"iqm":1119.0757853190105,"iqm_confidence_interval":[748.8553873697916,1412.0789794921875],"agent":"TD3"},{"env_step":110000.0,"rew":1384.1274536132812,"rew_std":511.3336586709261,"iqm":1256.2594807942708,"iqm_confidence_interval":[882.284423828125,2008.5397135416667],"agent":"TD3"},{"env_step":115000.0,"rew":1471.0005249023438,"rew_std":620.7393531156671,"iqm":1336.388448079427,"iqm_confidence_interval":[840.9969278971354,2220.077189127604],"agent":"TD3"},{"env_step":120000.0,"rew":1260.1736083984374,"rew_std":469.7023744926701,"iqm":1260.9583740234375,"iqm_confidence_interval":[691.15625,1786.0275065104167],"agent":"TD3"},{"env_step":125000.0,"rew":1511.67509765625,"rew_std":648.2241845549785,"iqm":1302.0400797526042,"iqm_confidence_interval":[942.7909342447916,2330.532185872396],"agent":"TD3"},{"env_step":130000.0,"rew":1458.1156372070313,"rew_std":549.6783235933749,"iqm":1327.2992350260417,"iqm_confidence_interval":[938.1090901692709,2141.26123046875],"agent":"TD3"},{"env_step":135000.0,"rew":1574.3567504882812,"rew_std":429.26752013044785,"iqm":1604.1842854817708,"iqm_confidence_interval":[1050.6182454427083,2058.0555419921875],"agent":"TD3"},{"env_step":140000.0,"rew":1368.9617309570312,"rew_std":371.94584620777067,"iqm":1352.2195638020833,"iqm_confidence_interval":[931.149658203125,1769.9747314453125],"agent":"TD3"},{"env_step":145000.0,"rew":1501.2026000976562,"rew_std":400.98189748060196,"iqm":1487.1974283854167,"iqm_confidence_interval":[1043.6548258463542,1973.142578125],"agent":"TD3"},{"env_step":150000.0,"rew":1727.3000610351562,"rew_std":601.7172140285786,"iqm":1626.8328043619792,"iqm_confidence_interval":[1127.7975260416667,2441.1529541015625],"agent":"TD3"},{"env_step":155000.0,"rew":1559.7651000976562,"rew_std":446.35813544513456,"iqm":1646.3952229817708,"iqm_confidence_interval":[1017.4626057942709,2022.75537109375],"agent":"TD
3"},{"env_step":160000.0,"rew":1764.975537109375,"rew_std":732.2193744679822,"iqm":1591.4624837239583,"iqm_confidence_interval":[1065.2289632161458,2654.5350748697915],"agent":"TD3"},{"env_step":165000.0,"rew":1822.7712280273438,"rew_std":678.5803775092464,"iqm":1730.8170572916667,"iqm_confidence_interval":[1132.1419677734375,2662.698689778646],"agent":"TD3"},{"env_step":170000.0,"rew":1985.6068969726562,"rew_std":728.3031488916075,"iqm":1910.0275472005208,"iqm_confidence_interval":[1256.8134765625,2821.941202799479],"agent":"TD3"},{"env_step":175000.0,"rew":2109.0537353515624,"rew_std":782.0241694379353,"iqm":2076.4483642578125,"iqm_confidence_interval":[1259.3393147786458,3005.3229166666665],"agent":"TD3"},{"env_step":180000.0,"rew":2179.0469360351562,"rew_std":905.9575695451357,"iqm":2056.8221028645835,"iqm_confidence_interval":[1259.8915608723958,3278.7978515625],"agent":"TD3"},{"env_step":185000.0,"rew":2229.725476074219,"rew_std":865.0804432932722,"iqm":2190.3130696614585,"iqm_confidence_interval":[1324.1541748046875,3203.8898111979165],"agent":"TD3"},{"env_step":190000.0,"rew":2130.791638183594,"rew_std":708.9686972152747,"iqm":2196.904541015625,"iqm_confidence_interval":[1300.7169189453125,2856.4830729166665],"agent":"TD3"},{"env_step":195000.0,"rew":2176.3774291992186,"rew_std":1007.79712057956,"iqm":2057.643310546875,"iqm_confidence_interval":[1057.2288818359375,3397.5015462239585],"agent":"TD3"},{"env_step":200000.0,"rew":2305.0114013671873,"rew_std":957.4217813148268,"iqm":2259.0384928385415,"iqm_confidence_interval":[1214.5,3451.973876953125],"agent":"TD3"},{"env_step":205000.0,"rew":2238.530725097656,"rew_std":1034.0895668294195,"iqm":2172.598429361979,"iqm_confidence_interval":[1095.5213216145833,3516.4203287760415],"agent":"TD3"},{"env_step":210000.0,"rew":2403.3504272460937,"rew_std":953.5621026200623,"iqm":2470.398193359375,"iqm_confidence_interval":[1326.7755126953125,3386.0760904947915],"agent":"TD3"},{"env_step":215000.0,"rew":2477.0895263671873
,"rew_std":938.5595760089883,"iqm":2550.865966796875,"iqm_confidence_interval":[1381.4463704427083,3520.93017578125],"agent":"TD3"},{"env_step":220000.0,"rew":2316.696423339844,"rew_std":853.0289853639126,"iqm":2350.9966634114585,"iqm_confidence_interval":[1342.1841227213542,3271.8939615885415],"agent":"TD3"},{"env_step":225000.0,"rew":2547.7952514648437,"rew_std":931.6966721701526,"iqm":2657.4053548177085,"iqm_confidence_interval":[1460.1538492838542,3533.2898763020835],"agent":"TD3"},{"env_step":230000.0,"rew":2465.588232421875,"rew_std":943.1721250006842,"iqm":2511.775390625,"iqm_confidence_interval":[1387.5060221354167,3513.1005045572915],"agent":"TD3"},{"env_step":235000.0,"rew":2547.3243774414063,"rew_std":940.3479030837134,"iqm":2770.5331217447915,"iqm_confidence_interval":[1395.3839925130208,3487.5193684895835],"agent":"TD3"},{"env_step":240000.0,"rew":2731.887939453125,"rew_std":1114.621718219357,"iqm":2842.3636067708335,"iqm_confidence_interval":[1444.1981608072917,3993.7285970052085],"agent":"TD3"},{"env_step":245000.0,"rew":2679.134765625,"rew_std":983.2848046758737,"iqm":2902.3509928385415,"iqm_confidence_interval":[1481.5174967447917,3653.83642578125],"agent":"TD3"},{"env_step":250000.0,"rew":2762.9610595703125,"rew_std":1060.952467688137,"iqm":2915.1644694010415,"iqm_confidence_interval":[1509.3499348958333,3902.3920084635415],"agent":"TD3"},{"env_step":255000.0,"rew":2753.7060668945314,"rew_std":1110.3489032409016,"iqm":2886.451171875,"iqm_confidence_interval":[1436.4121500651042,3987.7591959635415],"agent":"TD3"},{"env_step":260000.0,"rew":2606.6949462890625,"rew_std":1016.9661334599042,"iqm":2760.07666015625,"iqm_confidence_interval":[1398.0731608072917,3715.1659342447915],"agent":"TD3"},{"env_step":265000.0,"rew":2693.0078369140624,"rew_std":1088.9586023803274,"iqm":2883.4249674479165,"iqm_confidence_interval":[1411.1376953125,3859.88525390625],"agent":"TD3"},{"env_step":270000.0,"rew":2813.0802001953125,"rew_std":1107.950631195792,"iqm":2979.3451
334635415,"iqm_confidence_interval":[1505.247314453125,4003.2862955729165],"agent":"TD3"},{"env_step":275000.0,"rew":2840.199169921875,"rew_std":1205.5980886899304,"iqm":2990.45751953125,"iqm_confidence_interval":[1433.0826822916667,4187.828857421875],"agent":"TD3"},{"env_step":280000.0,"rew":2649.196984863281,"rew_std":977.7908186825631,"iqm":2771.1956380208335,"iqm_confidence_interval":[1496.7768961588542,3610.7386067708335],"agent":"TD3"},{"env_step":285000.0,"rew":2984.6728271484376,"rew_std":1222.1267736710947,"iqm":3239.305419921875,"iqm_confidence_interval":[1495.8108723958333,4246.451497395833],"agent":"TD3"},{"env_step":290000.0,"rew":2653.69794921875,"rew_std":1022.5000328919515,"iqm":2808.8805338541665,"iqm_confidence_interval":[1406.3455403645833,3749.87744140625],"agent":"TD3"},{"env_step":295000.0,"rew":2903.6751953125,"rew_std":1186.7828461038976,"iqm":3059.5785319010415,"iqm_confidence_interval":[1480.8075358072917,4207.113037109375],"agent":"TD3"},{"env_step":300000.0,"rew":3046.1340087890626,"rew_std":1347.5070082991244,"iqm":3239.381591796875,"iqm_confidence_interval":[1406.5304361979167,4510.792643229167],"agent":"TD3"},{"env_step":305000.0,"rew":3028.121826171875,"rew_std":1355.1566766121175,"iqm":3201.7381184895835,"iqm_confidence_interval":[1385.89306640625,4520.408040364583],"agent":"TD3"},{"env_step":310000.0,"rew":3045.340930175781,"rew_std":1369.5901648511212,"iqm":3232.55810546875,"iqm_confidence_interval":[1463.2810465494792,4562.6923828125],"agent":"TD3"},{"env_step":315000.0,"rew":3028.4850219726563,"rew_std":1246.8533587402467,"iqm":3266.2167154947915,"iqm_confidence_interval":[1508.8992919921875,4332.45166015625],"agent":"TD3"},{"env_step":320000.0,"rew":2986.0309448242188,"rew_std":1193.0953476903455,"iqm":3256.34033203125,"iqm_confidence_interval":[1530.6516520182292,4195.118489583333],"agent":"TD3"},{"env_step":325000.0,"rew":3073.6957885742186,"rew_std":1349.3847990133245,"iqm":3310.2652994791665,"iqm_confidence_interval":[1459.5
026448567708,4514.5234375],"agent":"TD3"},{"env_step":330000.0,"rew":3099.622509765625,"rew_std":1356.848981051641,"iqm":3403.2476399739585,"iqm_confidence_interval":[1420.4486490885417,4465.79931640625],"agent":"TD3"},{"env_step":335000.0,"rew":3085.056433105469,"rew_std":1285.395665056501,"iqm":3366.172607421875,"iqm_confidence_interval":[1542.4388427734375,4402.2119140625],"agent":"TD3"},{"env_step":340000.0,"rew":3071.9107177734377,"rew_std":1235.7570850757356,"iqm":3367.0099283854165,"iqm_confidence_interval":[1569.3037923177083,4306.1474609375],"agent":"TD3"},{"env_step":345000.0,"rew":3190.3670043945312,"rew_std":1315.1221008980763,"iqm":3472.8580729166665,"iqm_confidence_interval":[1540.4414469401042,4516.2958984375],"agent":"TD3"},{"env_step":350000.0,"rew":3151.562292480469,"rew_std":1319.802287809035,"iqm":3421.1626790364585,"iqm_confidence_interval":[1482.6082763671875,4491.857421875],"agent":"TD3"},{"env_step":355000.0,"rew":2993.9238891601562,"rew_std":1345.2920648248396,"iqm":3180.001708984375,"iqm_confidence_interval":[1400.4963785807292,4477.604817708333],"agent":"TD3"},{"env_step":360000.0,"rew":3240.46015625,"rew_std":1393.0734888624074,"iqm":3622.8191731770835,"iqm_confidence_interval":[1470.490478515625,4551.108072916667],"agent":"TD3"},{"env_step":365000.0,"rew":3034.42294921875,"rew_std":1246.2603451907094,"iqm":3263.3148600260415,"iqm_confidence_interval":[1509.1800130208333,4335.368489583333],"agent":"TD3"},{"env_step":370000.0,"rew":3270.672399902344,"rew_std":1393.160059416677,"iqm":3608.8076171875,"iqm_confidence_interval":[1505.1509195963542,4633.247395833333],"agent":"TD3"},{"env_step":375000.0,"rew":3249.9776000976562,"rew_std":1413.7088589889643,"iqm":3535.9663899739585,"iqm_confidence_interval":[1565.8721110026042,4724.507975260417],"agent":"TD3"},{"env_step":380000.0,"rew":3273.317138671875,"rew_std":1383.7569272549713,"iqm":3605.3651529947915,"iqm_confidence_interval":[1515.472412109375,4610.3828125],"agent":"TD3"},{"env_step":3850
00.0,"rew":3252.874890136719,"rew_std":1415.2156000510515,"iqm":3531.2469075520835,"iqm_confidence_interval":[1524.0929768880208,4721.978678385417],"agent":"TD3"},{"env_step":390000.0,"rew":3301.0752197265624,"rew_std":1415.3416560680325,"iqm":3627.8466796875,"iqm_confidence_interval":[1514.631103515625,4702.47021484375],"agent":"TD3"},{"env_step":395000.0,"rew":3278.423193359375,"rew_std":1397.7760225048944,"iqm":3655.669189453125,"iqm_confidence_interval":[1503.4281412760417,4591.7353515625],"agent":"TD3"},{"env_step":400000.0,"rew":2884.5741577148438,"rew_std":1445.511380441809,"iqm":2972.203084309896,"iqm_confidence_interval":[1176.9991861979167,4553.221842447917],"agent":"TD3"},{"env_step":405000.0,"rew":3330.7515502929687,"rew_std":1387.3228046538802,"iqm":3683.735107421875,"iqm_confidence_interval":[1551.3370361328125,4650.876139322917],"agent":"TD3"},{"env_step":410000.0,"rew":2710.1162353515624,"rew_std":1382.221121753278,"iqm":2593.720743815104,"iqm_confidence_interval":[1125.9298095703125,4308.127115885417],"agent":"TD3"},{"env_step":415000.0,"rew":3388.8285400390623,"rew_std":1412.7014251822077,"iqm":3769.463623046875,"iqm_confidence_interval":[1562.707763671875,4672.133626302083],"agent":"TD3"},{"env_step":420000.0,"rew":3448.588269042969,"rew_std":1405.2624520259374,"iqm":3859.49951171875,"iqm_confidence_interval":[1635.9779052734375,4720.6904296875],"agent":"TD3"},{"env_step":425000.0,"rew":3365.4913940429688,"rew_std":1484.9048095701287,"iqm":3660.5003255208335,"iqm_confidence_interval":[1520.9411214192708,4900.017740885417],"agent":"TD3"},{"env_step":430000.0,"rew":3435.064733886719,"rew_std":1556.7886671552396,"iqm":3886.5701497395835,"iqm_confidence_interval":[1424.7179768880208,4846.15966796875],"agent":"TD3"},{"env_step":435000.0,"rew":3276.301806640625,"rew_std":1451.3867925958466,"iqm":3560.5865071614585,"iqm_confidence_interval":[1544.50732421875,4797.141927083333],"agent":"TD3"},{"env_step":440000.0,"rew":3431.6516845703127,"rew_std":1528.64
57745757746,"iqm":3875.1138509114585,"iqm_confidence_interval":[1508.0724283854167,4847.526204427083],"agent":"TD3"},{"env_step":445000.0,"rew":3552.8244384765626,"rew_std":1487.1323046080424,"iqm":3961.4072265625,"iqm_confidence_interval":[1658.235107421875,4947.947265625],"agent":"TD3"},{"env_step":450000.0,"rew":3356.4051879882813,"rew_std":1376.3847353292676,"iqm":3702.5616861979165,"iqm_confidence_interval":[1633.0870361328125,4680.08984375],"agent":"TD3"},{"env_step":455000.0,"rew":3563.1517944335938,"rew_std":1491.4385603812143,"iqm":3977.2855631510415,"iqm_confidence_interval":[1634.7741292317708,4936.918782552083],"agent":"TD3"},{"env_step":460000.0,"rew":3524.6409545898437,"rew_std":1430.2851011633504,"iqm":3919.8612467447915,"iqm_confidence_interval":[1705.1669514973958,4847.28564453125],"agent":"TD3"},{"env_step":465000.0,"rew":3604.3880004882812,"rew_std":1518.1113450254556,"iqm":3995.2644856770835,"iqm_confidence_interval":[1668.4000244140625,5048.280924479167],"agent":"TD3"},{"env_step":470000.0,"rew":3391.0613037109374,"rew_std":1477.569491993997,"iqm":3703.388916015625,"iqm_confidence_interval":[1602.2677408854167,4914.701822916667],"agent":"TD3"},{"env_step":475000.0,"rew":3521.8194580078125,"rew_std":1549.656846893595,"iqm":3898.186279296875,"iqm_confidence_interval":[1528.638427734375,4991.42919921875],"agent":"TD3"},{"env_step":480000.0,"rew":3536.843542480469,"rew_std":1479.3773751625279,"iqm":3875.009521484375,"iqm_confidence_interval":[1649.8277587890625,4982.876139322917],"agent":"TD3"},{"env_step":485000.0,"rew":3698.1513671875,"rew_std":1529.7254531111644,"iqm":4141.316487630208,"iqm_confidence_interval":[1714.9497884114583,5071.2822265625],"agent":"TD3"},{"env_step":490000.0,"rew":3550.1833251953126,"rew_std":1494.588728177185,"iqm":3913.3807779947915,"iqm_confidence_interval":[1690.8838704427083,5020.81591796875],"agent":"TD3"},{"env_step":495000.0,"rew":3703.76689453125,"rew_std":1576.607793115339,"iqm":4101.045084635417,"iqm_confidence
_interval":[1677.4986165364583,5207.55615234375],"agent":"TD3"},{"env_step":500000.0,"rew":3575.063610839844,"rew_std":1511.873916898478,"iqm":3972.8521321614585,"iqm_confidence_interval":[1623.5158284505208,4996.941080729167],"agent":"TD3"}] ================================================ FILE: docs/_static/js/v5.json ================================================ { "$ref": "#/definitions/TopLevelSpec", "$schema": "http://json-schema.org/draft-07/schema#", "definitions": { "Aggregate": { "anyOf": [ { "$ref": "#/definitions/NonArgAggregateOp" }, { "$ref": "#/definitions/ArgmaxDef" }, { "$ref": "#/definitions/ArgminDef" } ] }, "AggregateOp": { "enum": [ "argmax", "argmin", "average", "count", "distinct", "max", "mean", "median", "min", "missing", "product", "q1", "q3", "ci0", "ci1", "stderr", "stdev", "stdevp", "sum", "valid", "values", "variance", "variancep" ], "type": "string" }, "AggregateTransform": { "additionalProperties": false, "properties": { "aggregate": { "description": "Array of objects that define fields to aggregate.", "items": { "$ref": "#/definitions/AggregatedFieldDef" }, "type": "array" }, "groupby": { "description": "The data fields to group by. If not specified, a single group containing all data objects will be used.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" } }, "required": [ "aggregate" ], "type": "object" }, "AggregatedFieldDef": { "additionalProperties": false, "properties": { "as": { "$ref": "#/definitions/FieldName", "description": "The output field names to use for each aggregated field." }, "field": { "$ref": "#/definitions/FieldName", "description": "The data field for which to compute aggregate function. This is required for all aggregation operations except `\"count\"`." }, "op": { "$ref": "#/definitions/AggregateOp", "description": "The aggregation operation to apply to the fields (e.g., `\"sum\"`, `\"average\"`, or `\"count\"`). 
See the [full list of supported aggregation operations](https://vega.github.io/vega-lite/docs/aggregate.html#ops) for more information." } }, "required": [ "op", "as" ], "type": "object" }, "Align": { "enum": [ "left", "center", "right" ], "type": "string" }, "AllSortString": { "anyOf": [ { "$ref": "#/definitions/SortOrder" }, { "$ref": "#/definitions/SortByChannel" }, { "$ref": "#/definitions/SortByChannelDesc" } ] }, "AnyMark": { "anyOf": [ { "$ref": "#/definitions/CompositeMark" }, { "$ref": "#/definitions/CompositeMarkDef" }, { "$ref": "#/definitions/Mark" }, { "$ref": "#/definitions/MarkDef" } ] }, "AnyMarkConfig": { "anyOf": [ { "$ref": "#/definitions/MarkConfig" }, { "$ref": "#/definitions/AreaConfig" }, { "$ref": "#/definitions/BarConfig" }, { "$ref": "#/definitions/RectConfig" }, { "$ref": "#/definitions/LineConfig" }, { "$ref": "#/definitions/TickConfig" } ] }, "AreaConfig": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The horizontal alignment of the text or ranged marks (area, bar, image, rect, rule). One of `\"left\"`, `\"right\"`, `\"center\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "angle": { "anyOf": [ { "description": "The rotation angle of the text, in degrees.", "maximum": 360, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). 
If `false`, the \"aria-hidden\" attribute will be set on the output SVG element, removing the mark item from the ARIA accessibility tree.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRole": { "anyOf": [ { "description": "Sets the type of user interface element of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"role\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRoleDescription": { "anyOf": [ { "description": "A human-readable, author-localized description for the role of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"aria-roledescription\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "aspect": { "anyOf": [ { "description": "Whether to keep aspect ratio of image marks.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "baseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For text marks, the vertical text baseline. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, `\"line-bottom\"`, or an expression reference that provides one of the valid values. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `lineHeight` rather than `fontSize` alone.\n\nFor range marks, the vertical alignment of the marks. One of `\"top\"`, `\"middle\"`, `\"bottom\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." 
}, "blend": { "anyOf": [ { "$ref": "#/definitions/Blend", "description": "The color blend mode for drawing an item on its current background. Any valid [CSS mix-blend-mode](https://developer.mozilla.org/en-US/docs/Web/CSS/mix-blend-mode) value can be used.\n\n__Default value: `\"source-over\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." }, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "anyOf": [ { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the mark. 
Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "dir": { "anyOf": [ { "$ref": "#/definitions/TextDirection", "description": "The direction of the text. One of `\"ltr\"` (left-to-right) or `\"rtl\"` (right-to-left). This property determines on which side is truncated in response to the limit parameter.\n\n__Default value:__ `\"ltr\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "dx": { "anyOf": [ { "description": "The horizontal offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "The vertical offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ellipsis": { "anyOf": [ { "description": "The ellipsis string for text truncated in response to the limit parameter.\n\n__Default value:__ `\"…\"`", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "endAngle": { "anyOf": [ { "description": "The end angle in radians for arc marks. 
A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default fill color. This property has higher precedence than `config.color`. Set to `null` to remove fill.\n\n__Default value:__ (None)" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "filled": { "description": "Whether the mark's color should be used as fill color instead of stroke color.\n\n__Default value:__ `false` for all `point`, `line`, and `rule` marks as well as `geoshape` marks for [`graticule`](https://vega.github.io/vega-lite/docs/data.html#graticule) data sources; otherwise, `true`.\n\n__Note:__ This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).", "type": "boolean" }, "font": { "anyOf": [ { "description": "The typeface to set the text in (e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "The font size, in pixels.\n\n__Default value:__ `11`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style (e.g., `\"italic\"`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "height": { "anyOf": [ { "description": "Height of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "href": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "A URL to load upon mouse click. If defined, the mark acts as a hyperlink." }, { "$ref": "#/definitions/ExprRef" } ] }, "innerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner radius in pixels of arc marks. `innerRadius` is an alias for `radius2`.\n\n__Default value:__ `0`", "minimum": 0 }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method to use for line and area marks. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"step-before\"`: alternate between vertical and horizontal segments, as in a step function.\n- `\"step-after\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "limit": { "anyOf": [ { "description": "The maximum length of the text mark in pixels. The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0` -- indicating no limit", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "line": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/OverlayMarkDef" } ], "description": "A flag for overlaying line on top of area marks, or an object defining the properties of the overlayed lines.\n\n- If this value is an empty object (`{}`) or `true`, lines with default properties will be used.\n\n- If this value is `false`, no lines would be automatically added to area marks.\n\n__Default value:__ `false`." }, "lineBreak": { "anyOf": [ { "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. 
This property is ignored if the text is array-valued.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "The line height in pixels (the spacing between subsequent lines of text) for multi-line text marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "order": { "description": "For line and trail marks, this `order` property can be set to `null` or `false` to make the lines use the original order in the data sources.", "type": [ "null", "boolean" ] }, "orient": { "$ref": "#/definitions/Orientation", "description": "The orientation of a non-stacked bar, tick, area, and line charts. The value is either horizontal (default) or vertical.\n- For bar, rule and tick, this determines whether the size of the bar and tick should be applied to x or y dimension.\n- For area, this property determines the orient property of the Vega output.\n- For line and trail marks, this property determines the sort order of the points in the line if `config.sortLineBy` is not specified. For stacked charts, this is always determined by the orientation of the stack; therefore explicitly specified value will be ignored." }, "outerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer radius in pixels of arc marks. 
`outerRadius` is an alias for `radius`.\n\n__Default value:__ `0`", "minimum": 0 }, "padAngle": { "anyOf": [ { "description": "The angular padding applied to sides of the arc, in radians.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "point": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/OverlayMarkDef" }, { "const": "transparent", "type": "string" } ], "description": "A flag for overlaying points on top of line or area marks, or an object defining the properties of the overlayed points.\n\n- If this property is `\"transparent\"`, transparent points will be used (for enhancing tooltips and selections).\n\n- If this property is an empty object (`{}`) or `true`, filled points with default properties will be used.\n\n- If this property is `false`, no points would be automatically added to line or area marks.\n\n__Default value:__ `false`." }, "radius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For arc mark, the primary (outer) radius in pixels.\n\nFor text marks, polar coordinate radial offset, in pixels, of the text from the origin determined by the `x` and `y` properties.\n\n__Default value:__ `min(plot_width, plot_height)/2`", "minimum": 0 }, "radius2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The secondary (inner) radius in pixels of arc marks.\n\n__Default value:__ `0`", "minimum": 0 }, "shape": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/SymbolShape" }, { "type": "string" } ], "description": "Shape of the point marks. 
Supported values include:\n- plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`.\n- the line symbol `\"stroke\"`\n- centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"`\n- a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n__Default value:__ `\"circle\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "size": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default size for marks.\n- For `point`/`circle`/`square`, this represents the pixel area of the marks. Note that this value sets the area of the symbol; the side lengths will increase with the square root of this value.\n- For `bar`, this represents the band size of the bar, in pixels.\n- For `text`, this represents the font size, in pixels.\n\n__Default value:__\n- `30` for point, circle, square marks; width/height's `step`\n- `2` for bar marks with discrete dimensions;\n- `5` for bar marks with continuous dimensions;\n- `11` for text marks.", "minimum": 0 }, "smooth": { "anyOf": [ { "description": "A boolean flag (default true) indicating if the image should be smoothed when resized. If false, individual pixels should be scaled directly rather than interpolated with smoothing. For SVG rendering, this option may not work in some browsers due to lack of standardization.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "startAngle": { "anyOf": [ { "description": "The start angle in radians for arc marks. 
A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default stroke color. This property has higher precedence than `config.color`. Set to `null` to remove stroke.\n\n__Default value:__ (None)" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOffset": { "anyOf": [ { "description": "The offset in pixels at which to draw the group stroke and fill. 
If unspecified, the default behavior is to dynamically offset stroked groups such that 1 pixel stroke widths align with the pixel grid.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tension": { "anyOf": [ { "description": "Depending on the interpolation type, sets the tension parameter (for line and area marks).", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text", "description": "Placeholder text if the `text` channel is not specified" }, { "$ref": "#/definitions/ExprRef" } ] }, "theta": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians.", "maximum": 360, "minimum": 0 }, "theta2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "timeUnitBandPosition": { "description": "Default relative band position for a time unit. If set to `0`, the marks will be positioned at the beginning of the time unit band step. If set to `0.5`, the marks will be positioned in the middle of the time unit band step.", "type": "number" }, "timeUnitBandSize": { "description": "Default relative band size for a time unit. If set to `1`, the bandwidth of the marks will be equal to the time unit band step. 
If set to `0.5`, bandwidth of the marks will be half of the time unit band step.", "type": "number" }, "tooltip": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/TooltipContent" }, { "$ref": "#/definitions/ExprRef" }, { "type": "null" } ], "description": "The tooltip text string to show upon mouse hover or an object defining which fields should the tooltip be derived from.\n\n- If `tooltip` is `true` or `{\"content\": \"encoding\"}`, then all fields from `encoding` will be used.\n- If `tooltip` is `{\"content\": \"data\"}`, then all fields that appear in the highlighted data point will be used.\n- If set to `null` or `false`, then no tooltip will be used.\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite.\n\n__Default value:__ `null`" }, "url": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "The URL of the image file for image marks." }, { "$ref": "#/definitions/ExprRef" } ] }, "width": { "anyOf": [ { "description": "Width of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "x": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." 
}, "y": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." } }, "type": "object" }, "ArgmaxDef": { "additionalProperties": false, "properties": { "argmax": { "$ref": "#/definitions/FieldName" } }, "required": [ "argmax" ], "type": "object" }, "ArgminDef": { "additionalProperties": false, "properties": { "argmin": { "$ref": "#/definitions/FieldName" } }, "required": [ "argmin" ], "type": "object" }, "AutoSizeParams": { "additionalProperties": false, "properties": { "contains": { "description": "Determines how size calculation should be performed, one of `\"content\"` or `\"padding\"`. The default setting (`\"content\"`) interprets the width and height settings as the data rectangle (plotting) dimensions, to which padding is then added. In contrast, the `\"padding\"` setting includes the padding within the view size calculations, such that the width and height settings indicate the **total** intended size of the view.\n\n__Default value__: `\"content\"`", "enum": [ "content", "padding" ], "type": "string" }, "resize": { "description": "A boolean flag indicating if autosize layout should be re-calculated on every view update.\n\n__Default value__: `false`", "type": "boolean" }, "type": { "$ref": "#/definitions/AutosizeType", "description": "The sizing format type. One of `\"pad\"`, `\"fit\"`, `\"fit-x\"`, `\"fit-y\"`, or `\"none\"`. 
See the [autosize type](https://vega.github.io/vega-lite/docs/size.html#autosize) documentation for descriptions of each.\n\n__Default value__: `\"pad\"`" } }, "type": "object" }, "AutosizeType": { "enum": [ "pad", "none", "fit", "fit-x", "fit-y" ], "type": "string" }, "Axis": { "additionalProperties": false, "properties": { "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG group, removing the axis from the ARIA accessibility tree.\n\n__Default value:__ `true`", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "bandPosition": { "anyOf": [ { "description": "An interpolation fraction indicating where, for `band` scales, axis ticks should be positioned. A value of `0` places ticks at the left edge of their bands. A value of `0.5` places ticks in the middle of their bands.\n\n __Default value:__ `0.5`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of this axis for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If the `aria` property is true, for SVG output the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute) will be set to this description. If the description is unspecified it will be automatically generated.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "domain": { "description": "A boolean flag indicating if the domain (the axis baseline) should be included as part of the axis.\n\n__Default value:__ `true`", "type": "boolean" }, "domainCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for the domain line's ending style. 
One of `\"butt\"`, `\"round\"` or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "domainColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Color of axis domain line.\n\n__Default value:__ `\"gray\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "domainDash": { "anyOf": [ { "description": "An array of alternating [stroke, space] lengths for dashed domain lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "domainDashOffset": { "anyOf": [ { "description": "The pixel offset at which to start drawing with the domain dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "domainOpacity": { "anyOf": [ { "description": "Opacity of the axis domain line.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "domainWidth": { "anyOf": [ { "description": "Stroke width of axis domain line\n\n__Default value:__ `1`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from 
[numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "grid": { "description": "A boolean flag indicating if grid lines should be included as part of the axis\n\n__Default value:__ `true` for [continuous scales](https://vega.github.io/vega-lite/docs/scale.html#continuous) that are not binned; otherwise, `false`.", "type": "boolean" }, "gridCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for grid lines' ending style. One of `\"butt\"`, `\"round\"` or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "gridColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Color of gridlines.\n\n__Default value:__ `\"lightGray\"`." 
}, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisColor" } ] }, "gridDash": { "anyOf": [ { "description": "An array of alternating [stroke, space] lengths for dashed grid lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumberArray" } ] }, "gridDashOffset": { "anyOf": [ { "description": "The pixel offset at which to start drawing with the grid dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "gridOpacity": { "anyOf": [ { "description": "The stroke opacity of grid (value between [0,1])\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "gridWidth": { "anyOf": [ { "description": "The grid width, in pixels.\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelAlign": { "anyOf": [ { "$ref": "#/definitions/Align", "description": "Horizontal text alignment of axis tick labels, overriding the default setting for the current axis orientation." }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisLabelAlign" } ] }, "labelAngle": { "anyOf": [ { "description": "The rotation angle of the axis labels.\n\n__Default value:__ `-90` for nominal and ordinal fields; `0` otherwise.", "maximum": 360, "minimum": -360, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline", "description": "Vertical text baseline of axis tick labels, overriding the default setting for the current axis orientation. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. 
The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the *lineHeight* rather than *fontSize* alone." }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisLabelBaseline" } ] }, "labelBound": { "anyOf": [ { "description": "Indicates if labels should be hidden if they exceed the axis range. If `false` (the default) no bounds overlap analysis is performed. If `true`, labels will be hidden if they exceed the axis range by more than 1 pixel. If this property is a number, it specifies the pixel tolerance: the maximum amount by which a label bounding box may exceed the axis range.\n\n__Default value:__ `false`.", "type": [ "number", "boolean" ] }, { "$ref": "#/definitions/ExprRef" } ] }, "labelColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the tick label; can be a hex color code or a regular color name." }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisColor" } ] }, "labelExpr": { "description": "[Vega expression](https://vega.github.io/vega/docs/expressions/) for customizing labels.\n\n__Note:__ The label text and value can be accessed via the `label` and `value` properties of the axis's backing `datum` object.", "type": "string" }, "labelFlush": { "description": "Indicates if the first and last axis labels should be aligned flush with the scale range. Flush alignment for a horizontal axis will left-align the first label and right-align the last label. For vertical axes, bottom and top text baselines are applied instead. If this property is a number, it also indicates the number of pixels by which to offset the first and last labels; for example, a value of 2 will flush-align the first and last labels and also push them 2 pixels outward from the center of the axis. 
The additional adjustment can sometimes help the labels better visually group with corresponding axis ticks.\n\n__Default value:__ `true` for axis of a continuous x-scale. Otherwise, `false`.", "type": [ "boolean", "number" ] }, "labelFlushOffset": { "anyOf": [ { "description": "Indicates the number of pixels by which to offset flush-adjusted labels. For example, a value of `2` will push flush-adjusted labels 2 pixels outward from the center of the axis. Offsets can help the labels better visually group with corresponding axis ticks.\n\n__Default value:__ `0`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFont": { "anyOf": [ { "description": "The font of the tick label.", "type": "string" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisString" } ] }, "labelFontSize": { "anyOf": [ { "description": "The font size of the label, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "Font style of axis tick labels." }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisLabelFontStyle" } ] }, "labelFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "Font weight of axis tick labels." 
}, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisLabelFontWeight" } ] }, "labelLimit": { "anyOf": [ { "description": "Maximum allowed pixel width of axis tick labels.\n\n__Default value:__ `180`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelLineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line label text or label text with `\"line-top\"` or `\"line-bottom\"` baseline.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelOffset": { "anyOf": [ { "description": "Position offset in pixels to apply to labels, in addition to tickOffset.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelOpacity": { "anyOf": [ { "description": "The opacity of the labels.", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelOverlap": { "anyOf": [ { "$ref": "#/definitions/LabelOverlap" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The strategy to use for resolving overlap of axis labels. If `false` (the default), no overlap reduction is attempted. If set to `true` or `\"parity\"`, a strategy of removing every other label is used (this works well for standard linear axes). If set to `\"greedy\"`, a linear scan of the labels is performed, removing any labels that overlap with the last visible label (this often works better for log-scaled axes).\n\n__Default value:__ `true` for non-nominal fields with non-log scales; `\"greedy\"` for log scales; otherwise `false`." 
}, "labelPadding": { "anyOf": [ { "description": "The padding in pixels between labels and ticks.\n\n__Default value:__ `2`", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelSeparation": { "anyOf": [ { "description": "The minimum separation that must be between label bounding boxes for them to be considered non-overlapping (default `0`). This property is ignored if *labelOverlap* resolution is not enabled.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labels": { "description": "A boolean flag indicating if labels should be included as part of the axis.\n\n__Default value:__ `true`.", "type": "boolean" }, "maxExtent": { "anyOf": [ { "description": "The maximum extent in pixels that axis ticks and labels should use. This determines a maximum offset value for axis titles.\n\n__Default value:__ `undefined`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "minExtent": { "anyOf": [ { "description": "The minimum extent in pixels that axis ticks and labels should use. This determines a minimum offset value for axis titles.\n\n__Default value:__ `30` for y-axis; `undefined` for x-axis.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The offset, in pixels, by which to displace the axis from the edge of the enclosing group or data rectangle.\n\n__Default value:__ derived from the [axis config](https://vega.github.io/vega-lite/docs/config.html#facet-scale-config)'s `offset` (`0` by default)" }, "orient": { "anyOf": [ { "$ref": "#/definitions/AxisOrient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The orientation of the axis. One of `\"top\"`, `\"bottom\"`, `\"left\"` or `\"right\"`. 
The orientation can be used to further specialize the axis type (e.g., a y-axis oriented towards the right edge of the chart).\n\n__Default value:__ `\"bottom\"` for x-axes and `\"left\"` for y-axes." }, "position": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The anchor position of the axis in pixels. For x-axes with top or bottom orientation, this sets the axis group x coordinate. For y-axes with left or right orientation, this sets the axis group y coordinate.\n\n__Default value__: `0`" }, "style": { "anyOf": [ { "type": "string" }, { "items": { "type": "string" }, "type": "array" } ], "description": "A string or array of strings indicating the name of custom styles to apply to the axis. A style is a named collection of axis properties defined within the [style configuration](https://vega.github.io/vega-lite/docs/mark.html#style-config). If style is an array, later styles will override earlier styles.\n\n__Default value:__ (none)\n\n__Note:__ Any specified style will augment the default style. For example, an x-axis mark with `\"style\": \"foo\"` will use `config.axisX` and `config.style.foo` (the specified style `\"foo\"` has higher precedence)." }, "tickBand": { "anyOf": [ { "description": "For band scales, indicates if ticks and grid lines should be placed at the `\"center\"` of a band (default) or at the band `\"extent\"`s to indicate intervals.", "enum": [ "center", "extent" ], "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "tickCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for the tick lines' ending style. 
One of `\"butt\"`, `\"round\"` or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "tickColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the axis's tick.\n\n__Default value:__ `\"gray\"`" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisColor" } ] }, "tickCount": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/TimeInterval" }, { "$ref": "#/definitions/TimeIntervalStep" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A desired number of ticks, for axes visualizing quantitative scales. The resulting number may be different so that values are \"nice\" (multiples of 2, 5, 10) and lie within the underlying scale's range.\n\nFor scales of type `\"time\"` or `\"utc\"`, the tick count can instead be a time interval specifier. Legal string values are `\"millisecond\"`, `\"second\"`, `\"minute\"`, `\"hour\"`, `\"day\"`, `\"week\"`, `\"month\"`, and `\"year\"`. Alternatively, an object-valued interval specifier of the form `{\"interval\": \"month\", \"step\": 3}` includes a desired number of interval steps. 
Here, ticks are generated for each quarter (Jan, Apr, Jul, Oct) boundary.\n\n__Default value__: Determine using a formula `ceil(width/40)` for x and `ceil(height/40)` for y.", "minimum": 0 }, "tickDash": { "anyOf": [ { "description": "An array of alternating [stroke, space] lengths for dashed tick mark lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumberArray" } ] }, "tickDashOffset": { "anyOf": [ { "description": "The pixel offset at which to start drawing with the tick mark dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "tickExtra": { "description": "Boolean flag indicating if an extra axis tick should be added for the initial position of the axis. This flag is useful for styling axes for `band` scales such that ticks are placed on band boundaries rather in the middle of a band. Use in conjunction with `\"bandPosition\": 1` and an axis `\"padding\"` value of `0`.", "type": "boolean" }, "tickMinStep": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The minimum desired step between axis ticks, in terms of scale domain values. For example, a value of `1` indicates that ticks should not be less than 1 unit apart. If `tickMinStep` is specified, the `tickCount` value will be adjusted, if necessary, to enforce the minimum step value." 
}, "tickOffset": { "anyOf": [ { "description": "Position offset in pixels to apply to ticks, labels, and gridlines.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tickOpacity": { "anyOf": [ { "description": "Opacity of the ticks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "tickRound": { "description": "Boolean flag indicating if pixel position values should be rounded to the nearest integer.\n\n__Default value:__ `true`", "type": "boolean" }, "tickSize": { "anyOf": [ { "description": "The size in pixels of axis ticks.\n\n__Default value:__ `5`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "tickWidth": { "anyOf": [ { "description": "The width, in pixels, of ticks.\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "ticks": { "description": "Boolean value that determines whether the axis should include ticks.\n\n__Default value:__ `true`", "type": "boolean" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "titleAlign": { "anyOf": [ { "$ref": "#/definitions/Align", "description": "Horizontal text alignment of axis titles." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleAnchor": { "anyOf": [ { "$ref": "#/definitions/TitleAnchor", "description": "Text anchor position for placing axis titles." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleAngle": { "anyOf": [ { "description": "Angle in degrees of axis titles.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline", "description": "Vertical text baseline for axis titles. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the *lineHeight* rather than *fontSize* alone." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Color of the title, can be in hex color code or regular color name." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFont": { "anyOf": [ { "description": "Font of the title. 
(e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontSize": { "anyOf": [ { "description": "Font size of the title.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "Font style of the title." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "Font weight of the title. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleLimit": { "anyOf": [ { "description": "Maximum allowed pixel width of axis titles.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleLineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line title text or title text with `\"line-top\"` or `\"line-bottom\"` baseline.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleOpacity": { "anyOf": [ { "description": "Opacity of the axis title.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titlePadding": { "anyOf": [ { "description": "The padding, in pixels, between title and axis.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleX": { "anyOf": [ { "description": "X-coordinate of the axis title relative to the axis group.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleY": { "anyOf": [ { "description": "Y-coordinate of the axis title relative to the axis group.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "translate": { "anyOf": [ { "description": "Coordinate space translation offset for axis layout. By default, axes are translated by a 0.5 pixel offset for both the x and y coordinates in order to align stroked lines with the pixel grid. 
However, for vector graphics output these pixel-specific adjustments may be undesirable, in which case translate can be changed (for example, to zero).\n\n__Default value:__ `0.5`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "values": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "items": { "type": "string" }, "type": "array" }, { "items": { "type": "boolean" }, "type": "array" }, { "items": { "$ref": "#/definitions/DateTime" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Explicitly set the visible axis tick values." }, "zindex": { "description": "A non-negative integer indicating the z-index of the axis. If zindex is 0, axes should be drawn behind all chart elements. To put them in front, set `zindex` to `1` or more.\n\n__Default value:__ `0` (behind the marks).", "minimum": 0, "type": "number" } }, "type": "object" }, "AxisConfig": { "additionalProperties": false, "properties": { "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG group, removing the axis from the ARIA accessibility tree.\n\n__Default value:__ `true`", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "bandPosition": { "anyOf": [ { "description": "An interpolation fraction indicating where, for `band` scales, axis ticks should be positioned. A value of `0` places ticks at the left edge of their bands. A value of `0.5` places ticks in the middle of their bands.\n\n __Default value:__ `0.5`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of this axis for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). 
If the `aria` property is true, for SVG output the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute) will be set to this description. If the description is unspecified it will be automatically generated.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "disable": { "description": "Disable axis by default.", "type": "boolean" }, "domain": { "description": "A boolean flag indicating if the domain (the axis baseline) should be included as part of the axis.\n\n__Default value:__ `true`", "type": "boolean" }, "domainCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for the domain line's ending style. One of `\"butt\"`, `\"round\"` or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "domainColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Color of axis domain line.\n\n__Default value:__ `\"gray\"`." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "domainDash": { "anyOf": [ { "description": "An array of alternating [stroke, space] lengths for dashed domain lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "domainDashOffset": { "anyOf": [ { "description": "The pixel offset at which to start drawing with the domain dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "domainOpacity": { "anyOf": [ { "description": "Opacity of the axis domain line.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "domainWidth": { "anyOf": [ { "description": "Stroke width of axis domain line\n\n__Default value:__ `1`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. 
One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "grid": { "description": "A boolean flag indicating if grid lines should be included as part of the axis\n\n__Default value:__ `true` for [continuous scales](https://vega.github.io/vega-lite/docs/scale.html#continuous) that are not binned; otherwise, `false`.", "type": "boolean" }, "gridCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for grid lines' ending style. One of `\"butt\"`, `\"round\"` or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "gridColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Color of gridlines.\n\n__Default value:__ `\"lightGray\"`." 
}, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisColor" } ] }, "gridDash": { "anyOf": [ { "description": "An array of alternating [stroke, space] lengths for dashed grid lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumberArray" } ] }, "gridDashOffset": { "anyOf": [ { "description": "The pixel offset at which to start drawing with the grid dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "gridOpacity": { "anyOf": [ { "description": "The stroke opacity of grid (value between [0,1])\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "gridWidth": { "anyOf": [ { "description": "The grid width, in pixels.\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelAlign": { "anyOf": [ { "$ref": "#/definitions/Align", "description": "Horizontal text alignment of axis tick labels, overriding the default setting for the current axis orientation." }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisLabelAlign" } ] }, "labelAngle": { "anyOf": [ { "description": "The rotation angle of the axis labels.\n\n__Default value:__ `-90` for nominal and ordinal fields; `0` otherwise.", "maximum": 360, "minimum": -360, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline", "description": "Vertical text baseline of axis tick labels, overriding the default setting for the current axis orientation. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. 
The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the *lineHeight* rather than *fontSize* alone." }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisLabelBaseline" } ] }, "labelBound": { "anyOf": [ { "description": "Indicates if labels should be hidden if they exceed the axis range. If `false` (the default) no bounds overlap analysis is performed. If `true`, labels will be hidden if they exceed the axis range by more than 1 pixel. If this property is a number, it specifies the pixel tolerance: the maximum amount by which a label bounding box may exceed the axis range.\n\n__Default value:__ `false`.", "type": [ "number", "boolean" ] }, { "$ref": "#/definitions/ExprRef" } ] }, "labelColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the tick label, can be in hex color code or regular color name." }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisColor" } ] }, "labelExpr": { "description": "[Vega expression](https://vega.github.io/vega/docs/expressions/) for customizing labels.\n\n__Note:__ The label text and value can be assessed via the `label` and `value` properties of the axis's backing `datum` object.", "type": "string" }, "labelFlush": { "description": "Indicates if the first and last axis labels should be aligned flush with the scale range. Flush alignment for a horizontal axis will left-align the first label and right-align the last label. For vertical axes, bottom and top text baselines are applied instead. If this property is a number, it also indicates the number of pixels by which to offset the first and last labels; for example, a value of 2 will flush-align the first and last labels and also push them 2 pixels outward from the center of the axis. 
The additional adjustment can sometimes help the labels better visually group with corresponding axis ticks.\n\n__Default value:__ `true` for axis of a continuous x-scale. Otherwise, `false`.", "type": [ "boolean", "number" ] }, "labelFlushOffset": { "anyOf": [ { "description": "Indicates the number of pixels by which to offset flush-adjusted labels. For example, a value of `2` will push flush-adjusted labels 2 pixels outward from the center of the axis. Offsets can help the labels better visually group with corresponding axis ticks.\n\n__Default value:__ `0`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFont": { "anyOf": [ { "description": "The font of the tick label.", "type": "string" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisString" } ] }, "labelFontSize": { "anyOf": [ { "description": "The font size of the label, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "Font style of the title." }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisLabelFontStyle" } ] }, "labelFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "Font weight of axis tick labels." 
}, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisLabelFontWeight" } ] }, "labelLimit": { "anyOf": [ { "description": "Maximum allowed pixel width of axis tick labels.\n\n__Default value:__ `180`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelLineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line label text or label text with `\"line-top\"` or `\"line-bottom\"` baseline.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelOffset": { "anyOf": [ { "description": "Position offset in pixels to apply to labels, in addition to tickOffset.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelOpacity": { "anyOf": [ { "description": "The opacity of the labels.", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelOverlap": { "anyOf": [ { "$ref": "#/definitions/LabelOverlap" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The strategy to use for resolving overlap of axis labels. If `false` (the default), no overlap reduction is attempted. If set to `true` or `\"parity\"`, a strategy of removing every other label is used (this works well for standard linear axes). If set to `\"greedy\"`, a linear scan of the labels is performed, removing any labels that overlaps with the last visible label (this often works better for log-scaled axes).\n\n__Default value:__ `true` for non-nominal fields with non-log scales; `\"greedy\"` for log scales; otherwise `false`." 
}, "labelPadding": { "anyOf": [ { "description": "The padding in pixels between labels and ticks.\n\n__Default value:__ `2`", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "labelSeparation": { "anyOf": [ { "description": "The minimum separation that must be between label bounding boxes for them to be considered non-overlapping (default `0`). This property is ignored if *labelOverlap* resolution is not enabled.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labels": { "description": "A boolean flag indicating if labels should be included as part of the axis.\n\n__Default value:__ `true`.", "type": "boolean" }, "maxExtent": { "anyOf": [ { "description": "The maximum extent in pixels that axis ticks and labels should use. This determines a maximum offset value for axis titles.\n\n__Default value:__ `undefined`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "minExtent": { "anyOf": [ { "description": "The minimum extent in pixels that axis ticks and labels should use. This determines a minimum offset value for axis titles.\n\n__Default value:__ `30` for y-axis; `undefined` for x-axis.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The offset, in pixels, by which to displace the axis from the edge of the enclosing group or data rectangle.\n\n__Default value:__ derived from the [axis config](https://vega.github.io/vega-lite/docs/config.html#facet-scale-config)'s `offset` (`0` by default)" }, "orient": { "anyOf": [ { "$ref": "#/definitions/AxisOrient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The orientation of the axis. One of `\"top\"`, `\"bottom\"`, `\"left\"` or `\"right\"`. 
The orientation can be used to further specialize the axis type (e.g., a y-axis oriented towards the right edge of the chart).\n\n__Default value:__ `\"bottom\"` for x-axes and `\"left\"` for y-axes." }, "position": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The anchor position of the axis in pixels. For x-axes with top or bottom orientation, this sets the axis group x coordinate. For y-axes with left or right orientation, this sets the axis group y coordinate.\n\n__Default value__: `0`" }, "style": { "anyOf": [ { "type": "string" }, { "items": { "type": "string" }, "type": "array" } ], "description": "A string or array of strings indicating the name of custom styles to apply to the axis. A style is a named collection of axis property defined within the [style configuration](https://vega.github.io/vega-lite/docs/mark.html#style-config). If style is an array, later styles will override earlier styles.\n\n__Default value:__ (none) __Note:__ Any specified style will augment the default style. For example, an x-axis mark with `\"style\": \"foo\"` will use `config.axisX` and `config.style.foo` (the specified style `\"foo\"` has higher precedence)." }, "tickBand": { "anyOf": [ { "description": "For band scales, indicates if ticks and grid lines should be placed at the `\"center\"` of a band (default) or at the band `\"extent\"`s to indicate intervals", "enum": [ "center", "extent" ], "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "tickCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for the tick lines' ending style. 
One of `\"butt\"`, `\"round\"` or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "tickColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the axis's tick.\n\n__Default value:__ `\"gray\"`" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisColor" } ] }, "tickCount": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/TimeInterval" }, { "$ref": "#/definitions/TimeIntervalStep" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A desired number of ticks, for axes visualizing quantitative scales. The resulting number may be different so that values are \"nice\" (multiples of 2, 5, 10) and lie within the underlying scale's range.\n\nFor scales of type `\"time\"` or `\"utc\"`, the tick count can instead be a time interval specifier. Legal string values are `\"millisecond\"`, `\"second\"`, `\"minute\"`, `\"hour\"`, `\"day\"`, `\"week\"`, `\"month\"`, and `\"year\"`. Alternatively, an object-valued interval specifier of the form `{\"interval\": \"month\", \"step\": 3}` includes a desired number of interval steps. 
Here, ticks are generated for each quarter (Jan, Apr, Jul, Oct) boundary.\n\n__Default value__: Determine using a formula `ceil(width/40)` for x and `ceil(height/40)` for y.", "minimum": 0 }, "tickDash": { "anyOf": [ { "description": "An array of alternating [stroke, space] lengths for dashed tick mark lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumberArray" } ] }, "tickDashOffset": { "anyOf": [ { "description": "The pixel offset at which to start drawing with the tick mark dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "tickExtra": { "description": "Boolean flag indicating if an extra axis tick should be added for the initial position of the axis. This flag is useful for styling axes for `band` scales such that ticks are placed on band boundaries rather in the middle of a band. Use in conjunction with `\"bandPosition\": 1` and an axis `\"padding\"` value of `0`.", "type": "boolean" }, "tickMinStep": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The minimum desired step between axis ticks, in terms of scale domain values. For example, a value of `1` indicates that ticks should not be less than 1 unit apart. If `tickMinStep` is specified, the `tickCount` value will be adjusted, if necessary, to enforce the minimum step value." 
}, "tickOffset": { "anyOf": [ { "description": "Position offset in pixels to apply to ticks, labels, and gridlines.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tickOpacity": { "anyOf": [ { "description": "Opacity of the ticks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "tickRound": { "description": "Boolean flag indicating if pixel position values should be rounded to the nearest integer.\n\n__Default value:__ `true`", "type": "boolean" }, "tickSize": { "anyOf": [ { "description": "The size in pixels of axis ticks.\n\n__Default value:__ `5`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "tickWidth": { "anyOf": [ { "description": "The width, in pixels, of ticks.\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ConditionalAxisNumber" } ] }, "ticks": { "description": "Boolean value that determines whether the axis should include ticks.\n\n__Default value:__ `true`", "type": "boolean" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "titleAlign": { "anyOf": [ { "$ref": "#/definitions/Align", "description": "Horizontal text alignment of axis titles." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleAnchor": { "anyOf": [ { "$ref": "#/definitions/TitleAnchor", "description": "Text anchor position for placing axis titles." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleAngle": { "anyOf": [ { "description": "Angle in degrees of axis titles.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline", "description": "Vertical text baseline for axis titles. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the *lineHeight* rather than *fontSize* alone." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Color of the title, can be in hex color code or regular color name." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFont": { "anyOf": [ { "description": "Font of the title. 
(e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontSize": { "anyOf": [ { "description": "Font size of the title.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "Font style of the title." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "Font weight of the title. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleLimit": { "anyOf": [ { "description": "Maximum allowed pixel width of axis titles.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleLineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line title text or title text with `\"line-top\"` or `\"line-bottom\"` baseline.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleOpacity": { "anyOf": [ { "description": "Opacity of the axis title.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titlePadding": { "anyOf": [ { "description": "The padding, in pixels, between title and axis.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleX": { "anyOf": [ { "description": "X-coordinate of the axis title relative to the axis group.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleY": { "anyOf": [ { "description": "Y-coordinate of the axis title relative to the axis group.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "translate": { "anyOf": [ { "description": "Coordinate space translation offset for axis layout. By default, axes are translated by a 0.5 pixel offset for both the x and y coordinates in order to align stroked lines with the pixel grid. 
However, for vector graphics output these pixel-specific adjustments may be undesirable, in which case translate can be changed (for example, to zero).\n\n__Default value:__ `0.5`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "values": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "items": { "type": "string" }, "type": "array" }, { "items": { "type": "boolean" }, "type": "array" }, { "items": { "$ref": "#/definitions/DateTime" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Explicitly set the visible axis tick values." }, "zindex": { "description": "A non-negative integer indicating the z-index of the axis. If zindex is 0, axes should be drawn behind all chart elements. To put them in front, set `zindex` to `1` or more.\n\n__Default value:__ `0` (behind the marks).", "minimum": 0, "type": "number" } }, "type": "object" }, "AxisOrient": { "enum": [ "top", "bottom", "left", "right" ], "type": "string" }, "AxisResolveMap": { "additionalProperties": false, "properties": { "x": { "$ref": "#/definitions/ResolveMode" }, "y": { "$ref": "#/definitions/ResolveMode" } }, "type": "object" }, "BBox": { "anyOf": [ { "items": { "type": "number" }, "maxItems": 4, "minItems": 4, "type": "array" }, { "items": { "type": "number" }, "maxItems": 6, "minItems": 6, "type": "array" } ], "description": "Bounding box https://tools.ietf.org/html/rfc7946#section-5" }, "BarConfig": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The horizontal alignment of the text or ranged marks (area, bar, image, rect, rule). One of `\"left\"`, `\"right\"`, `\"center\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." 
}, "angle": { "anyOf": [ { "description": "The rotation angle of the text, in degrees.", "maximum": 360, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG element, removing the mark item from the ARIA accessibility tree.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRole": { "anyOf": [ { "description": "Sets the type of user interface element of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"role\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRoleDescription": { "anyOf": [ { "description": "A human-readable, author-localized description for the role of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"aria-roledescription\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "aspect": { "anyOf": [ { "description": "Whether to keep aspect ratio of image marks.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "baseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For text marks, the vertical text baseline. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, `\"line-bottom\"`, or an expression reference that provides one of the valid values. 
The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `lineHeight` rather than `fontSize` alone.\n\nFor range marks, the vertical alignment of the marks. One of `\"top\"`, `\"middle\"`, `\"bottom\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "binSpacing": { "description": "Offset between bars for binned field. The ideal value for this is either 0 (preferred by statisticians) or 1 (Vega-Lite default, D3 example style).\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, "blend": { "anyOf": [ { "$ref": "#/definitions/Blend", "description": "The color blend mode for drawing an item on its current background. Any valid [CSS mix-blend-mode](https://developer.mozilla.org/en-US/docs/Web/CSS/mix-blend-mode) value can be used.\n\n__Default value: `\"source-over\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." 
}, "continuousBandSize": { "description": "The default size of the bars on continuous scales.\n\n__Default value:__ `5`", "minimum": 0, "type": "number" }, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusEnd": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For vertical bars, top-left and top-right corner radius.\n\n- For horizontal bars, top-right and bottom-right corner radius." }, "cornerRadiusTopLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "anyOf": [ { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the mark. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only).
If specified, this property determines the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "dir": { "anyOf": [ { "$ref": "#/definitions/TextDirection", "description": "The direction of the text. One of `\"ltr\"` (left-to-right) or `\"rtl\"` (right-to-left). This property determines on which side is truncated in response to the limit parameter.\n\n__Default value:__ `\"ltr\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "discreteBandSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RelativeBandSize" } ], "description": "The default size of the bars with discrete dimensions. If unspecified, the default size is `step-2`, which provides 2 pixel offset between bars.", "minimum": 0 }, "dx": { "anyOf": [ { "description": "The horizontal offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "The vertical offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ellipsis": { "anyOf": [ { "description": "The ellipsis string for text truncated in response to the limit parameter.\n\n__Default value:__ `\"…\"`", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "endAngle": { "anyOf": [ { "description": "The end angle in radians for arc marks. A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default fill color. 
This property has higher precedence than `config.color`. Set to `null` to remove fill.\n\n__Default value:__ (None)" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "filled": { "description": "Whether the mark's color should be used as fill color instead of stroke color.\n\n__Default value:__ `false` for all `point`, `line`, and `rule` marks as well as `geoshape` marks for [`graticule`](https://vega.github.io/vega-lite/docs/data.html#graticule) data sources; otherwise, `true`.\n\n__Note:__ This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).", "type": "boolean" }, "font": { "anyOf": [ { "description": "The typeface to set the text in (e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "The font size, in pixels.\n\n__Default value:__ `11`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style (e.g., `\"italic\"`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "height": { "anyOf": [ { "description": "Height of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "href": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "A URL to load upon mouse click. If defined, the mark acts as a hyperlink." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "innerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner radius in pixels of arc marks. `innerRadius` is an alias for `radius2`.\n\n__Default value:__ `0`", "minimum": 0 }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method to use for line and area marks. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"step-before\"`: alternate between vertical and horizontal segments, as in a step function.\n- `\"step-after\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." }, { "$ref": "#/definitions/ExprRef" } ] }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. 
In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "limit": { "anyOf": [ { "description": "The maximum length of the text mark in pixels. The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0` -- indicating no limit", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineBreak": { "anyOf": [ { "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. This property is ignored if the text is array-valued.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "The line height in pixels (the spacing between subsequent lines of text) for multi-line text marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "minBandSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The minimum band size for bar and rectangle marks. __Default value:__ `0.25`" }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "order": { "description": "For line and trail marks, this `order` property can be set to `null` or `false` to make the lines use the original order in the data sources.", "type": [ "null", "boolean" ] }, "orient": { "$ref": "#/definitions/Orientation", "description": "The orientation of a non-stacked bar, tick, area, and line charts. 
The value is either horizontal (default) or vertical.\n- For bar, rule and tick, this determines whether the size of the bar and tick should be applied to x or y dimension.\n- For area, this property determines the orient property of the Vega output.\n- For line and trail marks, this property determines the sort order of the points in the line if `config.sortLineBy` is not specified. For stacked charts, this is always determined by the orientation of the stack; therefore explicitly specified value will be ignored." }, "outerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer radius in pixels of arc marks. `outerRadius` is an alias for `radius`.\n\n__Default value:__ `0`", "minimum": 0 }, "padAngle": { "anyOf": [ { "description": "The angular padding applied to sides of the arc, in radians.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "radius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For arc mark, the primary (outer) radius in pixels.\n\nFor text marks, polar coordinate radial offset, in pixels, of the text from the origin determined by the `x` and `y` properties.\n\n__Default value:__ `min(plot_width, plot_height)/2`", "minimum": 0 }, "radius2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The secondary (inner) radius in pixels of arc marks.\n\n__Default value:__ `0`", "minimum": 0 }, "shape": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/SymbolShape" }, { "type": "string" } ], "description": "Shape of the point marks. 
Supported values include:\n- plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`.\n- the line symbol `\"stroke\"`\n- centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"`\n- a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n__Default value:__ `\"circle\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "size": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default size for marks.\n- For `point`/`circle`/`square`, this represents the pixel area of the marks. Note that this value sets the area of the symbol; the side lengths will increase with the square root of this value.\n- For `bar`, this represents the band size of the bar, in pixels.\n- For `text`, this represents the font size, in pixels.\n\n__Default value:__\n- `30` for point, circle, square marks; width/height's `step`\n- `2` for bar marks with discrete dimensions;\n- `5` for bar marks with continuous dimensions;\n- `11` for text marks.", "minimum": 0 }, "smooth": { "anyOf": [ { "description": "A boolean flag (default true) indicating if the image should be smoothed when resized. If false, individual pixels should be scaled directly rather than interpolated with smoothing. For SVG rendering, this option may not work in some browsers due to lack of standardization.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "startAngle": { "anyOf": [ { "description": "The start angle in radians for arc marks. 
A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default stroke color. This property has higher precedence than `config.color`. Set to `null` to remove stroke.\n\n__Default value:__ (None)" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOffset": { "anyOf": [ { "description": "The offset in pixels at which to draw the group stroke and fill. 
If unspecified, the default behavior is to dynamically offset stroked groups such that 1 pixel stroke widths align with the pixel grid.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tension": { "anyOf": [ { "description": "Depending on the interpolation type, sets the tension parameter (for line and area marks).", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text", "description": "Placeholder text if the `text` channel is not specified" }, { "$ref": "#/definitions/ExprRef" } ] }, "theta": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians.", "maximum": 360, "minimum": 0 }, "theta2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "timeUnitBandPosition": { "description": "Default relative band position for a time unit. If set to `0`, the marks will be positioned at the beginning of the time unit band step. If set to `0.5`, the marks will be positioned in the middle of the time unit band step.", "type": "number" }, "timeUnitBandSize": { "description": "Default relative band size for a time unit. If set to `1`, the bandwidth of the marks will be equal to the time unit band step. 
If set to `0.5`, bandwidth of the marks will be half of the time unit band step.", "type": "number" }, "tooltip": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/TooltipContent" }, { "$ref": "#/definitions/ExprRef" }, { "type": "null" } ], "description": "The tooltip text string to show upon mouse hover or an object defining which fields should the tooltip be derived from.\n\n- If `tooltip` is `true` or `{\"content\": \"encoding\"}`, then all fields from `encoding` will be used.\n- If `tooltip` is `{\"content\": \"data\"}`, then all fields that appear in the highlighted data point will be used.\n- If set to `null` or `false`, then no tooltip will be used.\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite.\n\n__Default value:__ `null`" }, "url": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "The URL of the image file for image marks." }, { "$ref": "#/definitions/ExprRef" } ] }, "width": { "anyOf": [ { "description": "Width of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "x": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." 
}, "y": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." } }, "type": "object" }, "BaseTitleNoValueRefs": { "additionalProperties": false, "properties": { "align": { "$ref": "#/definitions/Align", "description": "Horizontal text alignment for title text. One of `\"left\"`, `\"center\"`, or `\"right\"`." }, "anchor": { "anyOf": [ { "$ref": "#/definitions/TitleAnchor", "description": "The anchor position for placing the title and subtitle text. One of `\"start\"`, `\"middle\"`, or `\"end\"`. For example, with an orientation of top these anchor positions map to a left-, center-, or right-aligned title." }, { "$ref": "#/definitions/ExprRef" } ] }, "angle": { "anyOf": [ { "description": "Angle in degrees of title and subtitle text.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG group, removing the title from the ARIA accessibility tree.\n\n__Default value:__ `true`", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "baseline": { "$ref": "#/definitions/TextBaseline", "description": "Vertical text baseline for title and subtitle text. 
One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the *lineHeight* rather than *fontSize* alone." }, "color": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Text color for title text." }, { "$ref": "#/definitions/ExprRef" } ] }, "dx": { "anyOf": [ { "description": "Delta offset for title and subtitle text x-coordinate.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "Delta offset for title and subtitle text y-coordinate.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "font": { "anyOf": [ { "description": "Font name for title text.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "Font size in pixels for title text.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "Font style for title text." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "Font weight for title text. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "frame": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/TitleFrame" }, { "type": "string" } ], "description": "The reference frame for the anchor position, one of `\"bounds\"` (to anchor relative to the full bounding box) or `\"group\"` (to anchor relative to the group width or height)." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "limit": { "anyOf": [ { "description": "The maximum allowed length in pixels of title and subtitle text.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line title text or title text with `\"line-top\"` or `\"line-bottom\"` baseline.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "offset": { "anyOf": [ { "description": "The orthogonal offset in pixels by which to displace the title group from its position along the edge of the chart.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "orient": { "anyOf": [ { "$ref": "#/definitions/TitleOrient", "description": "Default title orientation (`\"top\"`, `\"bottom\"`, `\"left\"`, or `\"right\"`)" }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Text color for subtitle text." }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleFont": { "anyOf": [ { "description": "Font name for subtitle text.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleFontSize": { "anyOf": [ { "description": "Font size in pixels for subtitle text.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "Font style for subtitle text." }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "Font weight for subtitle text. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleLineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line subtitle text.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitlePadding": { "anyOf": [ { "description": "The padding in pixels between title and subtitle text.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "zindex": { "anyOf": [ { "description": "The integer z-index indicating the layering of the title group relative to other axis, mark, and legend groups.\n\n__Default value:__ `0`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] } }, "type": "object" }, "Baseline": { "enum": [ "top", "middle", "bottom" ], "type": "string" }, "BinExtent": { "anyOf": [ { "items": { "type": "number" }, "maxItems": 2, "minItems": 2, "type": "array" }, { "$ref": "#/definitions/ParameterExtent" } ] }, "BinParams": { "additionalProperties": false, "description": "Binning properties or boolean flag for determining whether to bin data or not.", "properties": { "anchor": { "description": "A value in the binned domain at which to anchor the bins, shifting the bin boundaries if necessary to ensure that a boundary aligns with the anchor value.\n\n__Default value:__ the minimum bin extent value", "type": "number" }, "base": { "description": "The number base to use for automatic bin determination (default is base 10).\n\n__Default value:__ `10`", "type": "number" }, "binned": { "description": "When set to `true`, Vega-Lite treats the input data as already binned.", "type": "boolean" }, "divide": { "description": "Scale factors indicating allowable subdivisions. The default value is [5, 2], which indicates that for base 10 numbers (the default base), the method may consider dividing bin sizes by 5 and/or 2. 
For example, for an initial step size of 10, the method can check if bin sizes of 2 (= 10/5), 5 (= 10/2), or 1 (= 10/(5*2)) might also satisfy the given constraints.\n\n__Default value:__ `[5, 2]`", "items": { "type": "number" }, "maxItems": 2, "minItems": 1, "type": "array" }, "extent": { "$ref": "#/definitions/BinExtent", "description": "A two-element (`[min, max]`) array indicating the range of desired bin values." }, "maxbins": { "description": "Maximum number of bins.\n\n__Default value:__ `6` for `row`, `column` and `shape` channels; `10` for other channels", "minimum": 2, "type": "number" }, "minstep": { "description": "A minimum allowable step size (particularly useful for integer values).", "type": "number" }, "nice": { "description": "If true, attempts to make the bin boundaries use human-friendly boundaries, such as multiples of ten.\n\n__Default value:__ `true`", "type": "boolean" }, "step": { "description": "An exact step size to use between bins.\n\n__Note:__ If provided, options such as maxbins will be ignored.", "type": "number" }, "steps": { "description": "An array of allowable step sizes to choose from.", "items": { "type": "number" }, "minItems": 1, "type": "array" } }, "type": "object" }, "BinTransform": { "additionalProperties": false, "properties": { "as": { "anyOf": [ { "$ref": "#/definitions/FieldName" }, { "items": { "$ref": "#/definitions/FieldName" }, "type": "array" } ], "description": "The output fields at which to write the start and end bin values. This can be either a string or an array of strings with two elements denoting the name for the fields for bin start and bin end respectively. If a single string (e.g., `\"val\"`) is provided, the end field will be `\"val_end\"`." }, "bin": { "anyOf": [ { "const": true, "type": "boolean" }, { "$ref": "#/definitions/BinParams" } ], "description": "An object indicating bin properties, or simply `true` for using default bin parameters." 
}, "field": { "$ref": "#/definitions/FieldName", "description": "The data field to bin." } }, "required": [ "bin", "field", "as" ], "type": "object" }, "BindCheckbox": { "additionalProperties": false, "properties": { "debounce": { "description": "If defined, delays event handling until the specified milliseconds have elapsed since the last event was fired.", "type": "number" }, "element": { "$ref": "#/definitions/Element", "description": "An optional CSS selector string indicating the parent element to which the input element should be added. By default, all input elements are added within the parent container of the Vega view." }, "input": { "const": "checkbox", "type": "string" }, "name": { "description": "By default, the signal name is used to label input elements. This `name` property can be used instead to specify a custom label for the bound signal.", "type": "string" } }, "required": [ "input" ], "type": "object" }, "BindDirect": { "additionalProperties": false, "properties": { "debounce": { "description": "If defined, delays event handling until the specified milliseconds have elapsed since the last event was fired.", "type": "number" }, "element": { "anyOf": [ { "$ref": "#/definitions/Element" }, { "additionalProperties": false, "type": "object" } ], "description": "An input element that exposes a _value_ property and supports the [EventTarget](https://developer.mozilla.org/en-US/docs/Web/API/EventTarget) interface, or a CSS selector string to such an element. When the element updates and dispatches an event, the _value_ property will be used as the new, bound signal value. When the signal updates independent of the element, the _value_ property will be set to the signal value and a new event will be dispatched on the element." 
}, "event": { "description": "The event (default `\"input\"`) to listen for to track changes on the external element.", "type": "string" } }, "required": [ "element" ], "type": "object" }, "BindInput": { "additionalProperties": false, "properties": { "autocomplete": { "description": "A hint for form autofill. See the [HTML autocomplete attribute](https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes/autocomplete) for additional information.", "type": "string" }, "debounce": { "description": "If defined, delays event handling until the specified milliseconds have elapsed since the last event was fired.", "type": "number" }, "element": { "$ref": "#/definitions/Element", "description": "An optional CSS selector string indicating the parent element to which the input element should be added. By default, all input elements are added within the parent container of the Vega view." }, "input": { "description": "The type of input element to use. The valid values are `\"checkbox\"`, `\"radio\"`, `\"range\"`, `\"select\"`, and any other legal [HTML form input type](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input).", "type": "string" }, "name": { "description": "By default, the signal name is used to label input elements. This `name` property can be used instead to specify a custom label for the bound signal.", "type": "string" }, "placeholder": { "description": "Text that appears in the form control when it has no value set.", "type": "string" } }, "type": "object" }, "BindRadioSelect": { "additionalProperties": false, "properties": { "debounce": { "description": "If defined, delays event handling until the specified milliseconds have elapsed since the last event was fired.", "type": "number" }, "element": { "$ref": "#/definitions/Element", "description": "An optional CSS selector string indicating the parent element to which the input element should be added. By default, all input elements are added within the parent container of the Vega view." 
}, "input": { "enum": [ "radio", "select" ], "type": "string" }, "labels": { "description": "An array of label strings to represent the `options` values. If unspecified, the `options` value will be coerced to a string and used as the label.", "items": { "type": "string" }, "type": "array" }, "name": { "description": "By default, the signal name is used to label input elements. This `name` property can be used instead to specify a custom label for the bound signal.", "type": "string" }, "options": { "description": "An array of options to select from.", "items": {}, "type": "array" } }, "required": [ "input", "options" ], "type": "object" }, "BindRange": { "additionalProperties": false, "properties": { "debounce": { "description": "If defined, delays event handling until the specified milliseconds have elapsed since the last event was fired.", "type": "number" }, "element": { "$ref": "#/definitions/Element", "description": "An optional CSS selector string indicating the parent element to which the input element should be added. By default, all input elements are added within the parent container of the Vega view." }, "input": { "const": "range", "type": "string" }, "max": { "description": "Sets the maximum slider value. Defaults to the larger of the signal value and `100`.", "type": "number" }, "min": { "description": "Sets the minimum slider value. Defaults to the smaller of the signal value and `0`.", "type": "number" }, "name": { "description": "By default, the signal name is used to label input elements. This `name` property can be used instead to specify a custom label for the bound signal.", "type": "string" }, "step": { "description": "Sets the minimum slider increment. 
If undefined, the step size will be automatically determined based on the `min` and `max` values.", "type": "number" } }, "required": [ "input" ], "type": "object" }, "Binding": { "anyOf": [ { "$ref": "#/definitions/BindCheckbox" }, { "$ref": "#/definitions/BindRadioSelect" }, { "$ref": "#/definitions/BindRange" }, { "$ref": "#/definitions/BindInput" }, { "$ref": "#/definitions/BindDirect" } ] }, "BinnedTimeUnit": { "anyOf": [ { "enum": [ "binnedyear", "binnedyearquarter", "binnedyearquartermonth", "binnedyearmonth", "binnedyearmonthdate", "binnedyearmonthdatehours", "binnedyearmonthdatehoursminutes", "binnedyearmonthdatehoursminutesseconds", "binnedyearweek", "binnedyearweekday", "binnedyearweekdayhours", "binnedyearweekdayhoursminutes", "binnedyearweekdayhoursminutesseconds", "binnedyeardayofyear" ], "type": "string" }, { "enum": [ "binnedutcyear", "binnedutcyearquarter", "binnedutcyearquartermonth", "binnedutcyearmonth", "binnedutcyearmonthdate", "binnedutcyearmonthdatehours", "binnedutcyearmonthdatehoursminutes", "binnedutcyearmonthdatehoursminutesseconds", "binnedutcyearweek", "binnedutcyearweekday", "binnedutcyearweekdayhours", "binnedutcyearweekdayhoursminutes", "binnedutcyearweekdayhoursminutesseconds", "binnedutcyeardayofyear" ], "type": "string" } ] }, "Blend": { "enum": [ null, "multiply", "screen", "overlay", "darken", "lighten", "color-dodge", "color-burn", "hard-light", "soft-light", "difference", "exclusion", "hue", "saturation", "color", "luminosity" ], "type": [ "null", "string" ] }, "BoxPlot": { "const": "boxplot", "type": "string" }, "BoxPlotConfig": { "additionalProperties": false, "properties": { "box": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "extent": { "anyOf": [ { "const": "min-max", "type": "string" }, { "type": "number" } ], "description": "The extent of the whiskers. 
Available options include:\n- `\"min-max\"`: min and max are the lower and upper whiskers respectively.\n- A number representing multiple of the interquartile range. This number will be multiplied by the IQR to determine whisker boundary, which spans from the smallest data to the largest data within the range _[Q1 - k * IQR, Q3 + k * IQR]_ where _Q1_ and _Q3_ are the first and third quartiles while _IQR_ is the interquartile range (_Q3-Q1_).\n\n__Default value:__ `1.5`." }, "median": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "outliers": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "rule": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "size": { "description": "Size of the box and median tick of a box plot", "type": "number" }, "ticks": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] } }, "type": "object" }, "BoxPlotDef": { "additionalProperties": false, "properties": { "box": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "clip": { "description": "Whether a composite mark be clipped to the enclosing group’s width and height.", "type": "boolean" }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." }, "extent": { "anyOf": [ { "const": "min-max", "type": "string" }, { "type": "number" } ], "description": "The extent of the whiskers. Available options include:\n- `\"min-max\"`: min and max are the lower and upper whiskers respectively.\n- A number representing multiple of the interquartile range. 
This number will be multiplied by the IQR to determine whisker boundary, which spans from the smallest data to the largest data within the range _[Q1 - k * IQR, Q3 + k * IQR]_ where _Q1_ and _Q3_ are the first and third quartiles while _IQR_ is the interquartile range (_Q3-Q1_).\n\n__Default value:__ `1.5`." }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "median": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "opacity": { "description": "The opacity (value between [0,1]) of the mark.", "maximum": 1, "minimum": 0, "type": "number" }, "orient": { "$ref": "#/definitions/Orientation", "description": "Orientation of the box plot. This is normally automatically determined based on types of fields on x and y channels. However, an explicit `orient` be specified when the orientation is ambiguous.\n\n__Default value:__ `\"vertical\"`." }, "outliers": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "rule": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "size": { "description": "Size of the box and median tick of a box plot", "type": "number" }, "ticks": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "type": { "$ref": "#/definitions/BoxPlot", "description": "The mark type. This could a primitive mark type (one of `\"bar\"`, `\"circle\"`, `\"square\"`, `\"tick\"`, `\"line\"`, `\"area\"`, `\"point\"`, `\"geoshape\"`, `\"rule\"`, and `\"text\"`) or a composite mark type (`\"boxplot\"`, `\"errorband\"`, `\"errorbar\"`)." 
} }, "required": [ "type" ], "type": "object" }, "BrushConfig": { "additionalProperties": false, "properties": { "cursor": { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the interval mark. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, "fill": { "$ref": "#/definitions/Color", "description": "The fill color of the interval mark.\n\n__Default value:__ `\"#333333\"`" }, "fillOpacity": { "description": "The fill opacity of the interval mark (a value between `0` and `1`).\n\n__Default value:__ `0.125`", "type": "number" }, "stroke": { "$ref": "#/definitions/Color", "description": "The stroke color of the interval mark.\n\n__Default value:__ `\"#ffffff\"`" }, "strokeDash": { "description": "An array of alternating stroke and space lengths, for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, "strokeDashOffset": { "description": "The offset (in pixels) with which to begin drawing the stroke dash array.", "type": "number" }, "strokeOpacity": { "description": "The stroke opacity of the interval mark (a value between `0` and `1`).", "type": "number" }, "strokeWidth": { "description": "The stroke width of the interval mark.", "type": "number" } }, "type": "object" }, "CalculateTransform": { "additionalProperties": false, "properties": { "as": { "$ref": "#/definitions/FieldName", "description": "The field for storing the computed formula value." }, "calculate": { "description": "A [expression](https://vega.github.io/vega-lite/docs/types.html#expression) string. 
Use the variable `datum` to refer to the current data object.", "type": "string" } }, "required": [ "calculate", "as" ], "type": "object" }, "Categorical": { "enum": [ "accent", "category10", "category20", "category20b", "category20c", "dark2", "paired", "pastel1", "pastel2", "set1", "set2", "set3", "tableau10", "tableau20" ], "type": "string" }, "Color": { "anyOf": [ { "$ref": "#/definitions/ColorName" }, { "$ref": "#/definitions/HexColor" }, { "type": "string" } ] }, "ColorDef": { "$ref": "#/definitions/MarkPropDef<(Gradient|string|null)>" }, "ColorName": { "enum": [ "black", "silver", "gray", "white", "maroon", "red", "purple", "fuchsia", "green", "lime", "olive", "yellow", "navy", "blue", "teal", "aqua", "orange", "aliceblue", "antiquewhite", "aquamarine", "azure", "beige", "bisque", "blanchedalmond", "blueviolet", "brown", "burlywood", "cadetblue", "chartreuse", "chocolate", "coral", "cornflowerblue", "cornsilk", "crimson", "cyan", "darkblue", "darkcyan", "darkgoldenrod", "darkgray", "darkgreen", "darkgrey", "darkkhaki", "darkmagenta", "darkolivegreen", "darkorange", "darkorchid", "darkred", "darksalmon", "darkseagreen", "darkslateblue", "darkslategray", "darkslategrey", "darkturquoise", "darkviolet", "deeppink", "deepskyblue", "dimgray", "dimgrey", "dodgerblue", "firebrick", "floralwhite", "forestgreen", "gainsboro", "ghostwhite", "gold", "goldenrod", "greenyellow", "grey", "honeydew", "hotpink", "indianred", "indigo", "ivory", "khaki", "lavender", "lavenderblush", "lawngreen", "lemonchiffon", "lightblue", "lightcoral", "lightcyan", "lightgoldenrodyellow", "lightgray", "lightgreen", "lightgrey", "lightpink", "lightsalmon", "lightseagreen", "lightskyblue", "lightslategray", "lightslategrey", "lightsteelblue", "lightyellow", "limegreen", "linen", "magenta", "mediumaquamarine", "mediumblue", "mediumorchid", "mediumpurple", "mediumseagreen", "mediumslateblue", "mediumspringgreen", "mediumturquoise", "mediumvioletred", "midnightblue", "mintcream", "mistyrose", 
"moccasin", "navajowhite", "oldlace", "olivedrab", "orangered", "orchid", "palegoldenrod", "palegreen", "paleturquoise", "palevioletred", "papayawhip", "peachpuff", "peru", "pink", "plum", "powderblue", "rosybrown", "royalblue", "saddlebrown", "salmon", "sandybrown", "seagreen", "seashell", "sienna", "skyblue", "slateblue", "slategray", "slategrey", "snow", "springgreen", "steelblue", "tan", "thistle", "tomato", "turquoise", "violet", "wheat", "whitesmoke", "yellowgreen", "rebeccapurple" ], "type": "string" }, "ColorScheme": { "anyOf": [ { "$ref": "#/definitions/Categorical" }, { "$ref": "#/definitions/SequentialSingleHue" }, { "$ref": "#/definitions/SequentialMultiHue" }, { "$ref": "#/definitions/Diverging" }, { "$ref": "#/definitions/Cyclical" } ] }, "Encoding": { "additionalProperties": false, "properties": { "angle": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Rotation angle of point and text marks." }, "color": { "$ref": "#/definitions/ColorDef", "description": "Color of the marks – either fill or stroke color based on the `filled` property of mark definition. By default, `color` represents fill color for `\"area\"`, `\"bar\"`, `\"tick\"`, `\"text\"`, `\"trail\"`, `\"circle\"`, and `\"square\"` / stroke color for `\"line\"` and `\"point\"`.\n\n__Default value:__ If undefined, the default color depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `color` property.\n\n_Note:_ 1) For fine-grained control over both fill and stroke colors of the marks, please use the `fill` and `stroke` channels. The `fill` or `stroke` encodings have higher precedence than `color`, thus may override the `color` encoding if conflicting encodings are specified. 2) See the scale documentation for more information about customizing [color scheme](https://vega.github.io/vega-lite/docs/scale.html#scheme)." 
}, "description": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" } ], "description": "A text description of this mark for ARIA accessibility (SVG output only). For SVG output the `\"aria-label\"` attribute will be set to this description." }, "detail": { "anyOf": [ { "$ref": "#/definitions/FieldDefWithoutScale" }, { "items": { "$ref": "#/definitions/FieldDefWithoutScale" }, "type": "array" } ], "description": "Additional levels of detail for grouping data in aggregate views and in line, trail, and area marks without mapping data to a specific visual channel." }, "fill": { "$ref": "#/definitions/ColorDef", "description": "Fill color of the marks. __Default value:__ If undefined, the default color depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `color` property.\n\n_Note:_ The `fill` encoding has higher precedence than `color`, thus may override the `color` encoding if conflicting encodings are specified." }, "fillOpacity": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Fill opacity of the marks.\n\n__Default value:__ If undefined, the default opacity depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `fillOpacity` property." }, "href": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" } ], "description": "A URL to load upon mouse click." }, "key": { "$ref": "#/definitions/FieldDefWithoutScale", "description": "A data field to use as a unique key for data binding. When a visualization’s data is updated, the key value will be used to match data elements to existing mark instances. Use a key channel to enable object constancy for transitions over dynamic data." }, "latitude": { "$ref": "#/definitions/LatLongDef", "description": "Latitude position of geographically projected marks." 
}, "latitude2": { "$ref": "#/definitions/Position2Def", "description": "Latitude-2 position for geographically projected ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`." }, "longitude": { "$ref": "#/definitions/LatLongDef", "description": "Longitude position of geographically projected marks." }, "longitude2": { "$ref": "#/definitions/Position2Def", "description": "Longitude-2 position for geographically projected ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`." }, "opacity": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Opacity of the marks.\n\n__Default value:__ If undefined, the default opacity depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `opacity` property." }, "order": { "anyOf": [ { "$ref": "#/definitions/OrderFieldDef" }, { "items": { "$ref": "#/definitions/OrderFieldDef" }, "type": "array" }, { "$ref": "#/definitions/OrderValueDef" }, { "$ref": "#/definitions/OrderOnlyDef" } ], "description": "Order of the marks.\n- For stacked marks, this `order` channel encodes [stack order](https://vega.github.io/vega-lite/docs/stack.html#order).\n- For line and trail marks, this `order` channel encodes order of data points in the lines. This can be useful for creating [a connected scatterplot](https://vega.github.io/vega-lite/examples/connected_scatterplot.html). Setting `order` to `{\"value\": null}` makes the line marks use the original order in the data sources.\n- Otherwise, this `order` channel encodes layer order of the marks.\n\n__Note__: In aggregate plots, `order` field should be `aggregate`d to avoid creating additional aggregation grouping." }, "radius": { "$ref": "#/definitions/PolarDef", "description": "The outer radius in pixels of arc marks." }, "radius2": { "$ref": "#/definitions/Position2Def", "description": "The inner radius in pixels of arc marks." }, "shape": { "$ref": "#/definitions/ShapeDef", "description": "Shape of the mark.\n\n1. 
For `point` marks the supported values include: - plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`. - the line symbol `\"stroke\"` - centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"` - a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n2. For `geoshape` marks it should be a field definition of the geojson data\n\n__Default value:__ If undefined, the default shape depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#point-config)'s `shape` property. (`\"circle\"` if unset.)" }, "size": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Size of the mark.\n- For `\"point\"`, `\"square\"` and `\"circle\"`, – the symbol size, or pixel area of the mark.\n- For `\"bar\"` and `\"tick\"` – the bar and tick's size.\n- For `\"text\"` – the text's font size.\n- Size is unsupported for `\"line\"`, `\"area\"`, and `\"rect\"`. (Use `\"trail\"` instead of line with varying size)" }, "stroke": { "$ref": "#/definitions/ColorDef", "description": "Stroke color of the marks. __Default value:__ If undefined, the default color depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `color` property.\n\n_Note:_ The `stroke` encoding has higher precedence than `color`, thus may override the `color` encoding if conflicting encodings are specified." }, "strokeDash": { "$ref": "#/definitions/NumericArrayMarkPropDef", "description": "Stroke dash of the marks.\n\n__Default value:__ `[1,0]` (No dash)." 
}, "strokeOpacity": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Stroke opacity of the marks.\n\n__Default value:__ If undefined, the default opacity depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `strokeOpacity` property." }, "strokeWidth": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Stroke width of the marks.\n\n__Default value:__ If undefined, the default stroke width depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `strokeWidth` property." }, "text": { "$ref": "#/definitions/TextDef", "description": "Text of the `text` mark." }, "theta": { "$ref": "#/definitions/PolarDef", "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians." }, "theta2": { "$ref": "#/definitions/Position2Def", "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "tooltip": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" }, { "items": { "$ref": "#/definitions/StringFieldDef" }, "type": "array" }, { "type": "null" } ], "description": "The tooltip text to show upon mouse hover. Specifying `tooltip` encoding overrides [the `tooltip` property in the mark definition](https://vega.github.io/vega-lite/docs/mark.html#mark-def).\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite." }, "url": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" } ], "description": "The URL of an image mark." 
}, "x": { "$ref": "#/definitions/PositionDef", "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "$ref": "#/definitions/Position2Def", "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "xError": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/ValueDef" } ], "description": "Error value of x coordinates for error specified `\"errorbar\"` and `\"errorband\"`." }, "xError2": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/ValueDef" } ], "description": "Secondary error value of x coordinates for error specified `\"errorbar\"` and `\"errorband\"`." }, "xOffset": { "$ref": "#/definitions/OffsetDef", "description": "Offset of x-position of the marks" }, "y": { "$ref": "#/definitions/PositionDef", "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "$ref": "#/definitions/Position2Def", "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "yError": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/ValueDef" } ], "description": "Error value of y coordinates for error specified `\"errorbar\"` and `\"errorband\"`." 
}, "yError2": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/ValueDef" } ], "description": "Secondary error value of y coordinates for error specified `\"errorbar\"` and `\"errorband\"`." }, "yOffset": { "$ref": "#/definitions/OffsetDef", "description": "Offset of y-position of the marks" } }, "type": "object" }, "CompositeMark": { "anyOf": [ { "$ref": "#/definitions/BoxPlot" }, { "$ref": "#/definitions/ErrorBar" }, { "$ref": "#/definitions/ErrorBand" } ] }, "CompositeMarkDef": { "anyOf": [ { "$ref": "#/definitions/BoxPlotDef" }, { "$ref": "#/definitions/ErrorBarDef" }, { "$ref": "#/definitions/ErrorBandDef" } ] }, "CompositionConfig": { "additionalProperties": false, "properties": { "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "spacing": { "description": "The default spacing in pixels between composed sub-views.\n\n__Default value__: `20`", "type": "number" } }, "type": "object" }, "ConditionalMarkPropFieldOrDatumDef": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate" }, { "$ref": "#/definitions/ConditionalParameter" } ] }, "ConditionalMarkPropFieldOrDatumDef": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate>" }, { "$ref": "#/definitions/ConditionalParameter>" } ] }, "ConditionalStringFieldDef": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate" }, { "$ref": 
"#/definitions/ConditionalParameter" } ] }, "ConditionalValueDef<(Gradient|string|null|ExprRef)>": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate>" }, { "$ref": "#/definitions/ConditionalParameter>" } ] }, "ConditionalValueDef<(Text|ExprRef)>": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate>" }, { "$ref": "#/definitions/ConditionalParameter>" } ] }, "ConditionalValueDef<(number[]|ExprRef)>": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate>" }, { "$ref": "#/definitions/ConditionalParameter>" } ] }, "ConditionalValueDef<(number|ExprRef)>": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate>" }, { "$ref": "#/definitions/ConditionalParameter>" } ] }, "ConditionalValueDef<(string|ExprRef)>": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate>" }, { "$ref": "#/definitions/ConditionalParameter>" } ] }, "ConditionalValueDef<(string|null|ExprRef)>": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate>" }, { "$ref": "#/definitions/ConditionalParameter>" } ] }, "ConditionalValueDef": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate>" }, { "$ref": "#/definitions/ConditionalParameter>" } ] }, "ConditionalAxisColor": { "$ref": "#/definitions/ConditionalAxisProperty<(Color|null)>" }, "ConditionalAxisLabelAlign": { "$ref": "#/definitions/ConditionalAxisProperty<(Align|null)>" }, "ConditionalAxisLabelBaseline": { "$ref": "#/definitions/ConditionalAxisProperty<(TextBaseline|null)>" }, "ConditionalAxisLabelFontStyle": { "$ref": "#/definitions/ConditionalAxisProperty<(FontStyle|null)>" }, "ConditionalAxisLabelFontWeight": { "$ref": "#/definitions/ConditionalAxisProperty<(FontWeight|null)>" }, "ConditionalAxisNumber": { "$ref": "#/definitions/ConditionalAxisProperty<(number|null)>" }, "ConditionalAxisNumberArray": { "$ref": "#/definitions/ConditionalAxisProperty<(number[]|null)>" }, "ConditionalAxisProperty<(Align|null)>": { "anyOf": [ { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": 
"#/definitions/ConditionalPredicate<(ValueDef<(Align|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(Align|null)>|ExprRef)>" }, "type": "array" } ] }, "value": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "condition", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(Align|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(Align|null)>|ExprRef)>" }, "type": "array" } ] }, "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "condition", "expr" ], "type": "object" } ] }, "ConditionalAxisProperty<(Color|null)>": { "anyOf": [ { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(Color|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(Color|null)>|ExprRef)>" }, "type": "array" } ] }, "value": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "condition", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(Color|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(Color|null)>|ExprRef)>" }, "type": "array" } ] }, "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "condition", "expr" ], "type": "object" } ] }, "ConditionalAxisProperty<(FontStyle|null)>": { "anyOf": [ { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(FontStyle|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(FontStyle|null)>|ExprRef)>" }, "type": "array" } ] }, "value": { "anyOf": [ { "$ref": "#/definitions/FontStyle" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "condition", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(FontStyle|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(FontStyle|null)>|ExprRef)>" }, "type": "array" } ] }, "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "condition", "expr" ], "type": "object" } ] }, "ConditionalAxisProperty<(FontWeight|null)>": { "anyOf": [ { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(FontWeight|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(FontWeight|null)>|ExprRef)>" }, "type": "array" } ] }, "value": { "anyOf": [ { "$ref": "#/definitions/FontWeight" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "condition", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(FontWeight|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(FontWeight|null)>|ExprRef)>" }, "type": "array" } ] }, "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "condition", "expr" ], "type": "object" } ] }, "ConditionalAxisProperty<(TextBaseline|null)>": { "anyOf": [ { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(TextBaseline|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(TextBaseline|null)>|ExprRef)>" }, "type": "array" } ] }, "value": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "condition", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(TextBaseline|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(TextBaseline|null)>|ExprRef)>" }, "type": "array" } ] }, "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "condition", "expr" ], "type": "object" } ] }, "ConditionalAxisProperty<(number[]|null)>": { "anyOf": [ { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(number[]|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(number[]|null)>|ExprRef)>" }, "type": "array" } ] }, "value": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "condition", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(number[]|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(number[]|null)>|ExprRef)>" }, "type": "array" } ] }, "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "condition", "expr" ], "type": "object" } ] }, "ConditionalAxisProperty<(number|null)>": { "anyOf": [ { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(number|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(number|null)>|ExprRef)>" }, "type": "array" } ] }, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": [ "number", "null" ] } }, "required": [ "condition", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(number|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(number|null)>|ExprRef)>" }, "type": "array" } ] }, "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "condition", "expr" ], "type": "object" } ] }, "ConditionalAxisProperty<(string|null)>": { "anyOf": [ { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(string|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(string|null)>|ExprRef)>" }, "type": "array" } ] }, "value": { "description": "A constant value 
in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": [ "string", "null" ] } }, "required": [ "condition", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(string|null)>|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalPredicate<(ValueDef<(string|null)>|ExprRef)>" }, "type": "array" } ] }, "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "condition", "expr" ], "type": "object" } ] }, "ConditionalAxisString": { "$ref": "#/definitions/ConditionalAxisProperty<(string|null)>" }, "ConditionalParameter": { "anyOf": [ { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). 
If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "param" ], "type": "object" }, { "additionalProperties": false, "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. 
If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "param" ], "type": "object" } ] }, "ConditionalParameter>": { "anyOf": [ { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). 
If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/TypeForShape", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "param" ], "type": "object" }, { "additionalProperties": false, "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. 
If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "param" ], "type": "object" } ] }, "ConditionalParameter": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). 
If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "param" ], "type": "object" }, "ConditionalParameter>": { "additionalProperties": false, "properties": { "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." 
}, "value": { "anyOf": [ { "$ref": "#/definitions/Gradient" }, { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "param", "value" ], "type": "object" }, "ConditionalParameter>": { "additionalProperties": false, "properties": { "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "value": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "param", "value" ], "type": "object" }, "ConditionalParameter>": { "additionalProperties": false, "properties": { "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "value": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "param", "value" ], "type": "object" }, "ConditionalParameter>": { "additionalProperties": false, "properties": { "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "param", "value" ], "type": "object" }, "ConditionalParameter>": { "additionalProperties": false, "properties": { "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "value": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "param", "value" ], "type": "object" }, "ConditionalParameter>": { "additionalProperties": false, "properties": { "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." 
}, "value": { "anyOf": [ { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "param", "value" ], "type": "object" }, "ConditionalParameter>": { "additionalProperties": false, "properties": { "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." }, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "required": [ "param", "value" ], "type": "object" }, "ConditionalPredicate<(ValueDef<(Align|null)>|ExprRef)>": { "anyOf": [ { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "test", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" } }, "required": [ "expr", "test" ], "type": "object" } ] }, "ConditionalPredicate<(ValueDef<(Color|null)>|ExprRef)>": { "anyOf": [ { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "test", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" } }, "required": [ "expr", "test" ], "type": "object" } ] }, "ConditionalPredicate<(ValueDef<(FontStyle|null)>|ExprRef)>": { "anyOf": [ { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "$ref": "#/definitions/FontStyle" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "test", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" } }, "required": [ "expr", "test" ], "type": "object" } ] }, "ConditionalPredicate<(ValueDef<(FontWeight|null)>|ExprRef)>": { "anyOf": [ { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "$ref": "#/definitions/FontWeight" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "test", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" } }, "required": [ "expr", "test" ], "type": "object" } ] }, "ConditionalPredicate<(ValueDef<(TextBaseline|null)>|ExprRef)>": { "anyOf": [ { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "test", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" } }, "required": [ "expr", "test" ], "type": "object" } ] }, "ConditionalPredicate<(ValueDef<(number[]|null)>|ExprRef)>": { "anyOf": [ { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "type": "null" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "test", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" } }, "required": [ "expr", "test" ], "type": "object" } ] }, "ConditionalPredicate<(ValueDef<(number|null)>|ExprRef)>": { "anyOf": [ { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": [ "number", "null" ] } }, "required": [ "test", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite 
parameters).", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" } }, "required": [ "expr", "test" ], "type": "object" } ] }, "ConditionalPredicate<(ValueDef<(string|null)>|ExprRef)>": { "anyOf": [ { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": [ "string", "null" ] } }, "required": [ "test", "value" ], "type": "object" }, { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" } }, "required": [ "expr", "test" ], "type": "object" } ] }, "ConditionalPredicate": { "anyOf": [ { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). 
See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). 
This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. 
If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "test" ], "type": "object" }, { "additionalProperties": false, "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "test" ], "type": "object" } ] }, "ConditionalPredicate>": { "anyOf": [ { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. 
To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." 
}, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. 
For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/TypeForShape", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a 
[custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "test" ], "type": "object" }, { "additionalProperties": false, "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." 
}, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, 
string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "test" ], "type": "object" } ] }, "ConditionalPredicate": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). 
If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. 
One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a 
[custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "required": [ "test" ], "type": "object" }, "ConditionalPredicate>": { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "$ref": "#/definitions/Gradient" }, { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "test", "value" ], "type": "object" }, "ConditionalPredicate>": { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "test", "value" ], "type": "object" }, "ConditionalPredicate>": { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "test", "value" ], "type": "object" }, "ConditionalPredicate>": { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "test", "value" ], "type": "object" }, "ConditionalPredicate>": { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "required": [ "test", "value" ], "type": "object" }, "ConditionalPredicate>": { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "anyOf": [ { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "test", "value" ], "type": "object" }, "ConditionalPredicate>": { "additionalProperties": false, "properties": { "test": { "$ref": "#/definitions/PredicateComposition", "description": "Predicate for triggering the condition" }, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "required": [ "test", "value" ], "type": "object" }, "Config": { "additionalProperties": false, "properties": { "arc": { "$ref": "#/definitions/RectConfig", "description": "Arc-specific Config" }, "area": { "$ref": "#/definitions/AreaConfig", "description": "Area-Specific Config" }, "aria": { "description": "A boolean flag indicating if ARIA default attributes should be included for marks and guides (SVG output only). If false, the `\"aria-hidden\"` attribute will be set for all guides, removing them from the ARIA accessibility tree and Vega-Lite will not generate default descriptions for marks.\n\n__Default value:__ `true`.", "type": "boolean" }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. 
Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "axis": { "$ref": "#/definitions/AxisConfig", "description": "Axis configuration, which determines default properties for all `x` and `y` [axes](https://vega.github.io/vega-lite/docs/axis.html). For a full list of axis configuration options, please see the [corresponding section of the axis documentation](https://vega.github.io/vega-lite/docs/axis.html#config)." }, "axisBand": { "$ref": "#/definitions/AxisConfig", "description": "Config for axes with \"band\" scales." }, "axisBottom": { "$ref": "#/definitions/AxisConfig", "description": "Config for x-axis along the bottom edge of the chart." }, "axisDiscrete": { "$ref": "#/definitions/AxisConfig", "description": "Config for axes with \"point\" or \"band\" scales." }, "axisLeft": { "$ref": "#/definitions/AxisConfig", "description": "Config for y-axis along the left edge of the chart." }, "axisPoint": { "$ref": "#/definitions/AxisConfig", "description": "Config for axes with \"point\" scales." }, "axisQuantitative": { "$ref": "#/definitions/AxisConfig", "description": "Config for quantitative axes." }, "axisRight": { "$ref": "#/definitions/AxisConfig", "description": "Config for y-axis along the right edge of the chart." }, "axisTemporal": { "$ref": "#/definitions/AxisConfig", "description": "Config for temporal axes." }, "axisTop": { "$ref": "#/definitions/AxisConfig", "description": "Config for x-axis along the top edge of the chart." }, "axisX": { "$ref": "#/definitions/AxisConfig", "description": "X-axis specific config." }, "axisXBand": { "$ref": "#/definitions/AxisConfig", "description": "Config for x-axes with \"band\" scales." }, "axisXDiscrete": { "$ref": "#/definitions/AxisConfig", "description": "Config for x-axes with \"point\" or \"band\" scales." }, "axisXPoint": { "$ref": "#/definitions/AxisConfig", "description": "Config for x-axes with \"point\" scales." 
}, "axisXQuantitative": { "$ref": "#/definitions/AxisConfig", "description": "Config for x-quantitative axes." }, "axisXTemporal": { "$ref": "#/definitions/AxisConfig", "description": "Config for x-temporal axes." }, "axisY": { "$ref": "#/definitions/AxisConfig", "description": "Y-axis specific config." }, "axisYBand": { "$ref": "#/definitions/AxisConfig", "description": "Config for y-axes with \"band\" scales." }, "axisYDiscrete": { "$ref": "#/definitions/AxisConfig", "description": "Config for y-axes with \"point\" or \"band\" scales." }, "axisYPoint": { "$ref": "#/definitions/AxisConfig", "description": "Config for y-axes with \"point\" scales." }, "axisYQuantitative": { "$ref": "#/definitions/AxisConfig", "description": "Config for y-quantitative axes." }, "axisYTemporal": { "$ref": "#/definitions/AxisConfig", "description": "Config for y-temporal axes." }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "bar": { "$ref": "#/definitions/BarConfig", "description": "Bar-Specific Config" }, "boxplot": { "$ref": "#/definitions/BoxPlotConfig", "description": "Box Config" }, "circle": { "$ref": "#/definitions/MarkConfig", "description": "Circle-Specific Config" }, "concat": { "$ref": "#/definitions/CompositionConfig", "description": "Default configuration for all concatenation and repeat view composition operators (`concat`, `hconcat`, `vconcat`, and `repeat`)" }, "countTitle": { "description": "Default axis and legend title for count fields.\n\n__Default value:__ `'Count of Records`.", "type": "string" }, "customFormatTypes": { "description": "Allow the `formatType` property for text marks and guides to accept a custom formatter function [registered as a Vega expression](https://vega.github.io/vega-lite/usage/compile.html#format-type).", "type": "boolean" }, "errorband": { "$ref": 
"#/definitions/ErrorBandConfig", "description": "ErrorBand Config" }, "errorbar": { "$ref": "#/definitions/ErrorBarConfig", "description": "ErrorBar Config" }, "facet": { "$ref": "#/definitions/CompositionConfig", "description": "Default configuration for the `facet` view composition operator" }, "fieldTitle": { "description": "Defines how Vega-Lite generates title for fields. There are three possible styles:\n- `\"verbal\"` (Default) - displays function in a verbal style (e.g., \"Sum of field\", \"Year-month of date\", \"field (binned)\").\n- `\"function\"` - displays function using parentheses and capitalized texts (e.g., \"SUM(field)\", \"YEARMONTH(date)\", \"BIN(field)\").\n- `\"plain\"` - displays only the field name without functions (e.g., \"field\", \"date\", \"field\").", "enum": [ "verbal", "functional", "plain" ], "type": "string" }, "font": { "description": "Default font for all text marks, titles, and labels.", "type": "string" }, "geoshape": { "$ref": "#/definitions/MarkConfig", "description": "Geoshape-Specific Config" }, "header": { "$ref": "#/definitions/HeaderConfig", "description": "Header configuration, which determines default properties for all [headers](https://vega.github.io/vega-lite/docs/header.html).\n\nFor a full list of header configuration options, please see the [corresponding section of in the header documentation](https://vega.github.io/vega-lite/docs/header.html#config)." }, "headerColumn": { "$ref": "#/definitions/HeaderConfig", "description": "Header configuration, which determines default properties for column [headers](https://vega.github.io/vega-lite/docs/header.html).\n\nFor a full list of header configuration options, please see the [corresponding section of in the header documentation](https://vega.github.io/vega-lite/docs/header.html#config)." 
}, "headerFacet": { "$ref": "#/definitions/HeaderConfig", "description": "Header configuration, which determines default properties for non-row/column facet [headers](https://vega.github.io/vega-lite/docs/header.html).\n\nFor a full list of header configuration options, please see the [corresponding section of in the header documentation](https://vega.github.io/vega-lite/docs/header.html#config)." }, "headerRow": { "$ref": "#/definitions/HeaderConfig", "description": "Header configuration, which determines default properties for row [headers](https://vega.github.io/vega-lite/docs/header.html).\n\nFor a full list of header configuration options, please see the [corresponding section of in the header documentation](https://vega.github.io/vega-lite/docs/header.html#config)." }, "image": { "$ref": "#/definitions/RectConfig", "description": "Image-specific Config" }, "legend": { "$ref": "#/definitions/LegendConfig", "description": "Legend configuration, which determines default properties for all [legends](https://vega.github.io/vega-lite/docs/legend.html). For a full list of legend configuration options, please see the [corresponding section of in the legend documentation](https://vega.github.io/vega-lite/docs/legend.html#config)." }, "line": { "$ref": "#/definitions/LineConfig", "description": "Line-Specific Config" }, "lineBreak": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. This property provides a global default for text marks, which is overridden by mark or style config settings, and by the lineBreak mark encoding channel. If signal-valued, either string or regular expression (regexp) values are valid." }, "locale": { "$ref": "#/definitions/Locale", "description": "Locale definitions for string parsing and formatting of number and date values. 
The locale object should contain `number` and/or `time` properties with [locale definitions](https://vega.github.io/vega/docs/api/locale/). Locale definitions provided in the config block may be overridden by the View constructor locale option." }, "mark": { "$ref": "#/definitions/MarkConfig", "description": "Mark Config" }, "normalizedNumberFormat": { "description": "If normalizedNumberFormatType is not specified, D3 number format for axis labels, text marks, and tooltips of normalized stacked fields (fields with `stack: \"normalize\"`). For example `\"s\"` for SI units. Use [D3's number format pattern](https://github.com/d3/d3-format#locale_format).\n\nIf `config.normalizedNumberFormatType` is specified and `config.customFormatTypes` is `true`, this value will be passed as `format` alongside `datum.value` to the `config.numberFormatType` function. __Default value:__ `%`", "type": "string" }, "normalizedNumberFormatType": { "description": "[Custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type) for `config.normalizedNumberFormat`.\n\n__Default value:__ `undefined` -- This is equilvalent to call D3-format, which is exposed as [`format` in Vega-Expression](https://vega.github.io/vega/docs/expressions/#format). __Note:__ You must also set `customFormatTypes` to `true` to use this feature.", "type": "string" }, "numberFormat": { "description": "If numberFormatType is not specified, D3 number format for guide labels, text marks, and tooltips of non-normalized fields (fields *without* `stack: \"normalize\"`). For example `\"s\"` for SI units. 
Use [D3's number format pattern](https://github.com/d3/d3-format#locale_format).\n\nIf `config.numberFormatType` is specified and `config.customFormatTypes` is `true`, this value will be passed as `format` alongside `datum.value` to the `config.numberFormatType` function.", "type": "string" }, "numberFormatType": { "description": "[Custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type) for `config.numberFormat`.\n\n__Default value:__ `undefined` -- This is equilvalent to call D3-format, which is exposed as [`format` in Vega-Expression](https://vega.github.io/vega/docs/expressions/#format). __Note:__ You must also set `customFormatTypes` to `true` to use this feature.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "Dynamic variables or selections that parameterize a visualization.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "point": { "$ref": "#/definitions/MarkConfig", "description": "Point-Specific Config" }, "projection": { "$ref": "#/definitions/ProjectionConfig", "description": "Projection configuration, which determines default properties for all [projections](https://vega.github.io/vega-lite/docs/projection.html). For a full list of projection configuration options, please see the [corresponding section of the projection documentation](https://vega.github.io/vega-lite/docs/projection.html#config)." 
}, "range": { "$ref": "#/definitions/RangeConfig", "description": "An object hash that defines default range arrays or schemes for using with scales. For a full list of scale range configuration options, please see the [corresponding section of the scale documentation](https://vega.github.io/vega-lite/docs/scale.html#config)." }, "rect": { "$ref": "#/definitions/RectConfig", "description": "Rect-Specific Config" }, "rule": { "$ref": "#/definitions/MarkConfig", "description": "Rule-Specific Config" }, "scale": { "$ref": "#/definitions/ScaleConfig", "description": "Scale configuration determines default properties for all [scales](https://vega.github.io/vega-lite/docs/scale.html). For a full list of scale configuration options, please see the [corresponding section of the scale documentation](https://vega.github.io/vega-lite/docs/scale.html#config)." }, "selection": { "$ref": "#/definitions/SelectionConfig", "description": "An object hash for defining default properties for each type of selections." }, "square": { "$ref": "#/definitions/MarkConfig", "description": "Square-Specific Config" }, "style": { "$ref": "#/definitions/StyleConfigIndex", "description": "An object hash that defines key-value mappings to determine default properties for marks with a given [style](https://vega.github.io/vega-lite/docs/mark.html#mark-def). The keys represent styles names; the values have to be valid [mark configuration objects](https://vega.github.io/vega-lite/docs/mark.html#config)." 
}, "text": { "$ref": "#/definitions/MarkConfig", "description": "Text-Specific Config" }, "tick": { "$ref": "#/definitions/TickConfig", "description": "Tick-Specific Config" }, "timeFormat": { "description": "Default time format for raw time values (without time units) in text marks, legend labels and header labels.\n\n__Default value:__ `\"%b %d, %Y\"` __Note:__ Axes automatically determine the format for each label automatically so this config does not affect axes.", "type": "string" }, "timeFormatType": { "description": "[Custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type) for `config.timeFormat`.\n\n__Default value:__ `undefined` -- This is equilvalent to call D3-time-format, which is exposed as [`timeFormat` in Vega-Expression](https://vega.github.io/vega/docs/expressions/#timeFormat). __Note:__ You must also set `customFormatTypes` to `true` and there must *not* be a `timeUnit` defined to use this feature.", "type": "string" }, "title": { "$ref": "#/definitions/TitleConfig", "description": "Title configuration, which determines default properties for all [titles](https://vega.github.io/vega-lite/docs/title.html). For a full list of title configuration options, please see the [corresponding section of the title documentation](https://vega.github.io/vega-lite/docs/title.html#config)." }, "tooltipFormat": { "$ref": "#/definitions/FormatConfig", "description": "Define [custom format configuration](https://vega.github.io/vega-lite/docs/config.html#format) for tooltips. If unspecified, default format config will be applied." }, "trail": { "$ref": "#/definitions/LineConfig", "description": "Trail-Specific Config" }, "view": { "$ref": "#/definitions/ViewConfig", "description": "Default properties for [single view plots](https://vega.github.io/vega-lite/docs/spec.html#single)." 
} }, "type": "object" }, "CsvDataFormat": { "additionalProperties": false, "properties": { "parse": { "anyOf": [ { "$ref": "#/definitions/Parse" }, { "type": "null" } ], "description": "If set to `null`, disable type inference based on the spec and only use type inference based on the data. Alternatively, a parsing directive object can be provided for explicit data types. Each property of the object corresponds to a field name, and the value to the desired data type (one of `\"number\"`, `\"boolean\"`, `\"date\"`, or null (do not parse the field)). For example, `\"parse\": {\"modified_on\": \"date\"}` parses the `modified_on` field in each input record a Date value.\n\nFor `\"date\"`, we parse data based using JavaScript's [`Date.parse()`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date/parse). For Specific date formats can be provided (e.g., `{foo: \"date:'%m%d%Y'\"}`), using the [d3-time-format syntax](https://github.com/d3/d3-time-format#locale_format). UTC date format parsing is supported similarly (e.g., `{foo: \"utc:'%m%d%Y'\"}`). See more about [UTC time](https://vega.github.io/vega-lite/docs/timeunit.html#utc)" }, "type": { "description": "Type of input data: `\"json\"`, `\"csv\"`, `\"tsv\"`, `\"dsv\"`.\n\n__Default value:__ The default format type is determined by the extension of the file URL. 
If no extension is detected, `\"json\"` will be used by default.", "enum": [ "csv", "tsv" ], "type": "string" } }, "type": "object" }, "Cursor": { "enum": [ "auto", "default", "none", "context-menu", "help", "pointer", "progress", "wait", "cell", "crosshair", "text", "vertical-text", "alias", "copy", "move", "no-drop", "not-allowed", "e-resize", "n-resize", "ne-resize", "nw-resize", "s-resize", "se-resize", "sw-resize", "w-resize", "ew-resize", "ns-resize", "nesw-resize", "nwse-resize", "col-resize", "row-resize", "all-scroll", "zoom-in", "zoom-out", "grab", "grabbing" ], "type": "string" }, "Cyclical": { "enum": [ "rainbow", "sinebow" ], "type": "string" }, "Data": { "anyOf": [ { "$ref": "#/definitions/DataSource" }, { "$ref": "#/definitions/Generator" } ] }, "DataFormat": { "anyOf": [ { "$ref": "#/definitions/CsvDataFormat" }, { "$ref": "#/definitions/DsvDataFormat" }, { "$ref": "#/definitions/JsonDataFormat" }, { "$ref": "#/definitions/TopoDataFormat" } ] }, "DataSource": { "anyOf": [ { "$ref": "#/definitions/UrlData" }, { "$ref": "#/definitions/InlineData" }, { "$ref": "#/definitions/NamedData" } ] }, "Datasets": { "$ref": "#/definitions/Dict" }, "DateTime": { "additionalProperties": false, "description": "Object for defining datetime in Vega-Lite Filter. If both month and quarter are provided, month has higher precedence. `day` cannot be combined with other date. We accept string for month and day names.", "properties": { "date": { "description": "Integer value representing the date (day of the month) from 1-31.", "maximum": 31, "minimum": 1, "type": "number" }, "day": { "anyOf": [ { "$ref": "#/definitions/Day" }, { "type": "string" } ], "description": "Value representing the day of a week. 
This can be one of: (1) integer value -- `1` represents Monday; (2) case-insensitive day name (e.g., `\"Monday\"`); (3) case-insensitive, 3-character short day name (e.g., `\"Mon\"`).\n\n**Warning:** A DateTime definition object with `day`** should not be combined with `year`, `quarter`, `month`, or `date`." }, "hours": { "description": "Integer value representing the hour of a day from 0-23.", "maximum": 24, "minimum": 0, "type": "number" }, "milliseconds": { "description": "Integer value representing the millisecond segment of time.", "maximum": 1000, "minimum": 0, "type": "number" }, "minutes": { "description": "Integer value representing the minute segment of time from 0-59.", "maximum": 60, "minimum": 0, "type": "number" }, "month": { "anyOf": [ { "$ref": "#/definitions/Month" }, { "type": "string" } ], "description": "One of: (1) integer value representing the month from `1`-`12`. `1` represents January; (2) case-insensitive month name (e.g., `\"January\"`); (3) case-insensitive, 3-character short month name (e.g., `\"Jan\"`)." }, "quarter": { "description": "Integer value representing the quarter of the year (from 1-4).", "maximum": 4, "minimum": 1, "type": "number" }, "seconds": { "description": "Integer value representing the second segment (0-59) of a time value", "maximum": 60, "minimum": 0, "type": "number" }, "utc": { "description": "A boolean flag indicating if date time is in utc time. If false, the date time is in local time", "type": "boolean" }, "year": { "description": "Integer value representing the year.", "type": "number" } }, "type": "object" }, "DatumDef": { "additionalProperties": false, "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "Day": { "maximum": 7, "minimum": 1, "type": "number" }, "DensityTransform": { "additionalProperties": false, "properties": { "as": { "description": "The output fields for the sample value and corresponding density estimate.\n\n__Default value:__ `[\"value\", \"density\"]`", "items": { "$ref": "#/definitions/FieldName" }, "maxItems": 2, "minItems": 2, "type": "array" }, "bandwidth": { "description": "The bandwidth (standard deviation) of the Gaussian kernel. 
If unspecified or set to zero, the bandwidth value is automatically estimated from the input data using Scott’s rule.", "type": "number" }, "counts": { "description": "A boolean flag indicating if the output values should be probability estimates (false) or smoothed counts (true).\n\n__Default value:__ `false`", "type": "boolean" }, "cumulative": { "description": "A boolean flag indicating whether to produce density estimates (false) or cumulative density estimates (true).\n\n__Default value:__ `false`", "type": "boolean" }, "density": { "$ref": "#/definitions/FieldName", "description": "The data field for which to perform density estimation." }, "extent": { "description": "A [min, max] domain from which to sample the distribution. If unspecified, the extent will be determined by the observed minimum and maximum values of the density value field.", "items": { "type": "number" }, "maxItems": 2, "minItems": 2, "type": "array" }, "groupby": { "description": "The data fields to group by. If not specified, a single group containing all data objects will be used.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "maxsteps": { "description": "The maximum number of samples to take along the extent domain for plotting the density.\n\n__Default value:__ `200`", "type": "number" }, "minsteps": { "description": "The minimum number of samples to take along the extent domain for plotting the density.\n\n__Default value:__ `25`", "type": "number" }, "steps": { "description": "The exact number of samples to take along the extent domain for plotting the density. If specified, overrides both minsteps and maxsteps to set an exact number of uniform samples. 
Potentially useful in conjunction with a fixed extent to ensure consistent sample points for stacked densities.", "type": "number" } }, "required": [ "density" ], "type": "object" }, "DerivedStream": { "additionalProperties": false, "properties": { "between": { "items": { "$ref": "#/definitions/Stream" }, "type": "array" }, "consume": { "type": "boolean" }, "debounce": { "type": "number" }, "filter": { "anyOf": [ { "$ref": "#/definitions/Expr" }, { "items": { "$ref": "#/definitions/Expr" }, "type": "array" } ] }, "markname": { "type": "string" }, "marktype": { "$ref": "#/definitions/MarkType" }, "stream": { "$ref": "#/definitions/Stream" }, "throttle": { "type": "number" } }, "required": [ "stream" ], "type": "object" }, "Dict": { "additionalProperties": { "$ref": "#/definitions/InlineDataset" }, "type": "object" }, "Dict": { "additionalProperties": { "$ref": "#/definitions/SelectionInit" }, "type": "object" }, "Dict": { "additionalProperties": { "$ref": "#/definitions/SelectionInitInterval" }, "type": "object" }, "Dict": { "additionalProperties": {}, "type": "object" }, "Diverging": { "enum": [ "blueorange", "blueorange-3", "blueorange-4", "blueorange-5", "blueorange-6", "blueorange-7", "blueorange-8", "blueorange-9", "blueorange-10", "blueorange-11", "brownbluegreen", "brownbluegreen-3", "brownbluegreen-4", "brownbluegreen-5", "brownbluegreen-6", "brownbluegreen-7", "brownbluegreen-8", "brownbluegreen-9", "brownbluegreen-10", "brownbluegreen-11", "purplegreen", "purplegreen-3", "purplegreen-4", "purplegreen-5", "purplegreen-6", "purplegreen-7", "purplegreen-8", "purplegreen-9", "purplegreen-10", "purplegreen-11", "pinkyellowgreen", "pinkyellowgreen-3", "pinkyellowgreen-4", "pinkyellowgreen-5", "pinkyellowgreen-6", "pinkyellowgreen-7", "pinkyellowgreen-8", "pinkyellowgreen-9", "pinkyellowgreen-10", "pinkyellowgreen-11", "purpleorange", "purpleorange-3", "purpleorange-4", "purpleorange-5", "purpleorange-6", "purpleorange-7", "purpleorange-8", "purpleorange-9", 
"purpleorange-10", "purpleorange-11", "redblue", "redblue-3", "redblue-4", "redblue-5", "redblue-6", "redblue-7", "redblue-8", "redblue-9", "redblue-10", "redblue-11", "redgrey", "redgrey-3", "redgrey-4", "redgrey-5", "redgrey-6", "redgrey-7", "redgrey-8", "redgrey-9", "redgrey-10", "redgrey-11", "redyellowblue", "redyellowblue-3", "redyellowblue-4", "redyellowblue-5", "redyellowblue-6", "redyellowblue-7", "redyellowblue-8", "redyellowblue-9", "redyellowblue-10", "redyellowblue-11", "redyellowgreen", "redyellowgreen-3", "redyellowgreen-4", "redyellowgreen-5", "redyellowgreen-6", "redyellowgreen-7", "redyellowgreen-8", "redyellowgreen-9", "redyellowgreen-10", "redyellowgreen-11", "spectral", "spectral-3", "spectral-4", "spectral-5", "spectral-6", "spectral-7", "spectral-8", "spectral-9", "spectral-10", "spectral-11" ], "type": "string" }, "DomainUnionWith": { "additionalProperties": false, "properties": { "unionWith": { "description": "Customized domain values to be union with the field's values or explicitly defined domain. Should be an array of valid scale domain values.", "items": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/DateTime" } ] }, "type": "array" } }, "required": [ "unionWith" ], "type": "object" }, "DsvDataFormat": { "additionalProperties": false, "properties": { "delimiter": { "description": "The delimiter between records. The delimiter must be a single character (i.e., a single 16-bit code unit); so, ASCII delimiters are fine, but emoji delimiters are not.", "maxLength": 1, "minLength": 1, "type": "string" }, "parse": { "anyOf": [ { "$ref": "#/definitions/Parse" }, { "type": "null" } ], "description": "If set to `null`, disable type inference based on the spec and only use type inference based on the data. Alternatively, a parsing directive object can be provided for explicit data types. 
Each property of the object corresponds to a field name, and the value to the desired data type (one of `\"number\"`, `\"boolean\"`, `\"date\"`, or null (do not parse the field)). For example, `\"parse\": {\"modified_on\": \"date\"}` parses the `modified_on` field in each input record a Date value.\n\nFor `\"date\"`, we parse data based using JavaScript's [`Date.parse()`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date/parse). For Specific date formats can be provided (e.g., `{foo: \"date:'%m%d%Y'\"}`), using the [d3-time-format syntax](https://github.com/d3/d3-time-format#locale_format). UTC date format parsing is supported similarly (e.g., `{foo: \"utc:'%m%d%Y'\"}`). See more about [UTC time](https://vega.github.io/vega-lite/docs/timeunit.html#utc)" }, "type": { "const": "dsv", "description": "Type of input data: `\"json\"`, `\"csv\"`, `\"tsv\"`, `\"dsv\"`.\n\n__Default value:__ The default format type is determined by the extension of the file URL. If no extension is detected, `\"json\"` will be used by default.", "type": "string" } }, "required": [ "delimiter" ], "type": "object" }, "Element": { "type": "string" }, "EncodingSortField": { "additionalProperties": false, "description": "A sort definition for sorting a discrete scale in an encoding field definition.", "properties": { "field": { "$ref": "#/definitions/Field", "description": "The data [field](https://vega.github.io/vega-lite/docs/field.html) to sort by.\n\n__Default value:__ If unspecified, defaults to the field specified in the outer data reference." }, "op": { "$ref": "#/definitions/NonArgAggregateOp", "description": "An [aggregate operation](https://vega.github.io/vega-lite/docs/aggregate.html#ops) to perform on the field prior to sorting (e.g., `\"count\"`, `\"mean\"` and `\"median\"`). An aggregation is required when there are multiple values of the sort field for each encoded data field. 
The input data objects will be aggregated, grouped by the encoded data field.\n\nFor a full list of operations, please see the documentation for [aggregate](https://vega.github.io/vega-lite/docs/aggregate.html#ops).\n\n__Default value:__ `\"sum\"` for stacked plots. Otherwise, `\"min\"`." }, "order": { "anyOf": [ { "$ref": "#/definitions/SortOrder" }, { "type": "null" } ], "description": "The sort order. One of `\"ascending\"` (default), `\"descending\"`, or `null` (no not sort)." } }, "type": "object" }, "ErrorBand": { "const": "errorband", "type": "string" }, "ErrorBandConfig": { "additionalProperties": false, "properties": { "band": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "borders": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "extent": { "$ref": "#/definitions/ErrorBarExtent", "description": "The extent of the band. Available options include:\n- `\"ci\"`: Extend the band to the confidence interval of the mean.\n- `\"stderr\"`: The size of band are set to the value of standard error, extending from the mean.\n- `\"stdev\"`: The size of band are set to the value of standard deviation, extending from the mean.\n- `\"iqr\"`: Extend the band to the q1 and q3.\n\n__Default value:__ `\"stderr\"`." }, "interpolate": { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method for the error band. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: a piecewise constant function (a step function) consisting of alternating horizontal and vertical lines. The y-value changes at the midpoint of each pair of adjacent x-values.\n- `\"step-before\"`: a piecewise constant function (a step function) consisting of alternating horizontal and vertical lines. 
The y-value changes before the x-value.\n- `\"step-after\"`: a piecewise constant function (a step function) consisting of alternating horizontal and vertical lines. The y-value changes after the x-value.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." }, "tension": { "description": "The tension parameter for the interpolation type of the error band.", "maximum": 1, "minimum": 0, "type": "number" } }, "type": "object" }, "ErrorBandDef": { "additionalProperties": false, "properties": { "band": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "borders": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "clip": { "description": "Whether a composite mark be clipped to the enclosing group’s width and height.", "type": "boolean" }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." }, "extent": { "$ref": "#/definitions/ErrorBarExtent", "description": "The extent of the band. 
Available options include:\n- `\"ci\"`: Extend the band to the confidence interval of the mean.\n- `\"stderr\"`: The size of band are set to the value of standard error, extending from the mean.\n- `\"stdev\"`: The size of band are set to the value of standard deviation, extending from the mean.\n- `\"iqr\"`: Extend the band to the q1 and q3.\n\n__Default value:__ `\"stderr\"`." }, "interpolate": { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method for the error band. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: a piecewise constant function (a step function) consisting of alternating horizontal and vertical lines. The y-value changes at the midpoint of each pair of adjacent x-values.\n- `\"step-before\"`: a piecewise constant function (a step function) consisting of alternating horizontal and vertical lines. The y-value changes before the x-value.\n- `\"step-after\"`: a piecewise constant function (a step function) consisting of alternating horizontal and vertical lines. The y-value changes after the x-value.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." 
}, "opacity": { "description": "The opacity (value between [0,1]) of the mark.", "maximum": 1, "minimum": 0, "type": "number" }, "orient": { "$ref": "#/definitions/Orientation", "description": "Orientation of the error band. This is normally automatically determined, but can be specified when the orientation is ambiguous and cannot be automatically determined." }, "tension": { "description": "The tension parameter for the interpolation type of the error band.", "maximum": 1, "minimum": 0, "type": "number" }, "type": { "$ref": "#/definitions/ErrorBand", "description": "The mark type. This could a primitive mark type (one of `\"bar\"`, `\"circle\"`, `\"square\"`, `\"tick\"`, `\"line\"`, `\"area\"`, `\"point\"`, `\"geoshape\"`, `\"rule\"`, and `\"text\"`) or a composite mark type (`\"boxplot\"`, `\"errorband\"`, `\"errorbar\"`)." } }, "required": [ "type" ], "type": "object" }, "ErrorBar": { "const": "errorbar", "type": "string" }, "ErrorBarConfig": { "additionalProperties": false, "properties": { "extent": { "$ref": "#/definitions/ErrorBarExtent", "description": "The extent of the rule. Available options include:\n- `\"ci\"`: Extend the rule to the confidence interval of the mean.\n- `\"stderr\"`: The size of rule are set to the value of standard error, extending from the mean.\n- `\"stdev\"`: The size of rule are set to the value of standard deviation, extending from the mean.\n- `\"iqr\"`: Extend the rule to the q1 and q3.\n\n__Default value:__ `\"stderr\"`." 
}, "rule": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "size": { "description": "Size of the ticks of an error bar", "type": "number" }, "thickness": { "description": "Thickness of the ticks and the bar of an error bar", "type": "number" }, "ticks": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] } }, "type": "object" }, "ErrorBarDef": { "additionalProperties": false, "properties": { "clip": { "description": "Whether a composite mark be clipped to the enclosing group’s width and height.", "type": "boolean" }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." }, "extent": { "$ref": "#/definitions/ErrorBarExtent", "description": "The extent of the rule. Available options include:\n- `\"ci\"`: Extend the rule to the confidence interval of the mean.\n- `\"stderr\"`: The size of rule are set to the value of standard error, extending from the mean.\n- `\"stdev\"`: The size of rule are set to the value of standard deviation, extending from the mean.\n- `\"iqr\"`: Extend the rule to the q1 and q3.\n\n__Default value:__ `\"stderr\"`." }, "opacity": { "description": "The opacity (value between [0,1]) of the mark.", "maximum": 1, "minimum": 0, "type": "number" }, "orient": { "$ref": "#/definitions/Orientation", "description": "Orientation of the error bar. This is normally automatically determined, but can be specified when the orientation is ambiguous and cannot be automatically determined." 
}, "rule": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "size": { "description": "Size of the ticks of an error bar", "type": "number" }, "thickness": { "description": "Thickness of the ticks and the bar of an error bar", "type": "number" }, "ticks": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/AnyMarkConfig" } ] }, "type": { "$ref": "#/definitions/ErrorBar", "description": "The mark type. This could a primitive mark type (one of `\"bar\"`, `\"circle\"`, `\"square\"`, `\"tick\"`, `\"line\"`, `\"area\"`, `\"point\"`, `\"geoshape\"`, `\"rule\"`, and `\"text\"`) or a composite mark type (`\"boxplot\"`, `\"errorband\"`, `\"errorbar\"`)." } }, "required": [ "type" ], "type": "object" }, "ErrorBarExtent": { "enum": [ "ci", "iqr", "stderr", "stdev" ], "type": "string" }, "EventStream": { "anyOf": [ { "additionalProperties": false, "properties": { "between": { "items": { "$ref": "#/definitions/Stream" }, "type": "array" }, "consume": { "type": "boolean" }, "debounce": { "type": "number" }, "filter": { "anyOf": [ { "$ref": "#/definitions/Expr" }, { "items": { "$ref": "#/definitions/Expr" }, "type": "array" } ] }, "markname": { "type": "string" }, "marktype": { "$ref": "#/definitions/MarkType" }, "source": { "enum": [ "view", "scope" ], "type": "string" }, "throttle": { "type": "number" }, "type": { "$ref": "#/definitions/EventType" } }, "required": [ "type" ], "type": "object" }, { "additionalProperties": false, "properties": { "between": { "items": { "$ref": "#/definitions/Stream" }, "type": "array" }, "consume": { "type": "boolean" }, "debounce": { "type": "number" }, "filter": { "anyOf": [ { "$ref": "#/definitions/Expr" }, { "items": { "$ref": "#/definitions/Expr" }, "type": "array" } ] }, "markname": { "type": "string" }, "marktype": { "$ref": "#/definitions/MarkType" }, "source": { "const": "window", "type": "string" }, "throttle": { "type": "number" }, "type": { "$ref": "#/definitions/WindowEventType" } }, 
"required": [ "source", "type" ], "type": "object" } ] }, "EventType": { "enum": [ "click", "dblclick", "dragenter", "dragleave", "dragover", "keydown", "keypress", "keyup", "mousedown", "mousemove", "mouseout", "mouseover", "mouseup", "mousewheel", "pointerdown", "pointermove", "pointerout", "pointerover", "pointerup", "timer", "touchend", "touchmove", "touchstart", "wheel" ], "type": "string" }, "Expr": { "type": "string" }, "ExprRef": { "additionalProperties": false, "properties": { "expr": { "description": "Vega expression (which can refer to Vega-Lite parameters).", "type": "string" } }, "required": [ "expr" ], "type": "object" }, "ExtentTransform": { "additionalProperties": false, "properties": { "extent": { "$ref": "#/definitions/FieldName", "description": "The field of which to get the extent." }, "param": { "$ref": "#/definitions/ParameterName", "description": "The output parameter produced by the extent transform." } }, "required": [ "extent", "param" ], "type": "object" }, "FacetEncodingFieldDef": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. 
String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. 
The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). 
See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "header": { "anyOf": [ { "$ref": "#/definitions/Header" }, { "type": "null" } ], "description": "An object defining properties of a facet's header." }, "sort": { "anyOf": [ { "$ref": "#/definitions/SortArray" }, { "$ref": "#/definitions/SortOrder" }, { "$ref": "#/definitions/EncodingSortField" }, { "type": "null" } ], "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` is not supported for `row` and `column`." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. 
An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types 
(number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FacetFieldDef": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). 
If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "header": { "anyOf": [ { "$ref": "#/definitions/Header" }, { "type": "null" } ], "description": "An object defining properties of a facet's header." }, "sort": { "anyOf": [ { "$ref": "#/definitions/SortArray" }, { "$ref": "#/definitions/SortOrder" }, { "$ref": "#/definitions/EncodingSortField" }, { "type": "null" } ], "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` is not supported for `row` and `column`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. 
or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FacetMapping": { "additionalProperties": false, "properties": { "column": { "$ref": "#/definitions/FacetFieldDef", "description": "A field definition for the horizontal facet of trellis plots." }, "row": { "$ref": "#/definitions/FacetFieldDef", "description": "A field definition for the vertical facet of trellis plots." } }, "type": "object" }, "FacetedEncoding": { "additionalProperties": false, "properties": { "angle": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Rotation angle of point and text marks." }, "color": { "$ref": "#/definitions/ColorDef", "description": "Color of the marks – either fill or stroke color based on the `filled` property of mark definition. 
By default, `color` represents fill color for `\"area\"`, `\"bar\"`, `\"tick\"`, `\"text\"`, `\"trail\"`, `\"circle\"`, and `\"square\"` / stroke color for `\"line\"` and `\"point\"`.\n\n__Default value:__ If undefined, the default color depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `color` property.\n\n_Note:_ 1) For fine-grained control over both fill and stroke colors of the marks, please use the `fill` and `stroke` channels. The `fill` or `stroke` encodings have higher precedence than `color`, thus may override the `color` encoding if conflicting encodings are specified. 2) See the scale documentation for more information about customizing [color scheme](https://vega.github.io/vega-lite/docs/scale.html#scheme)." }, "column": { "$ref": "#/definitions/RowColumnEncodingFieldDef", "description": "A field definition for the horizontal facet of trellis plots." }, "description": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" } ], "description": "A text description of this mark for ARIA accessibility (SVG output only). For SVG output the `\"aria-label\"` attribute will be set to this description." }, "detail": { "anyOf": [ { "$ref": "#/definitions/FieldDefWithoutScale" }, { "items": { "$ref": "#/definitions/FieldDefWithoutScale" }, "type": "array" } ], "description": "Additional levels of detail for grouping data in aggregate views and in line, trail, and area marks without mapping data to a specific visual channel." }, "facet": { "$ref": "#/definitions/FacetEncodingFieldDef", "description": "A field definition for the (flexible) facet of trellis plots.\n\nIf either `row` or `column` is specified, this channel will be ignored." }, "fill": { "$ref": "#/definitions/ColorDef", "description": "Fill color of the marks. 
__Default value:__ If undefined, the default color depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `color` property.\n\n_Note:_ The `fill` encoding has higher precedence than `color`, thus may override the `color` encoding if conflicting encodings are specified." }, "fillOpacity": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Fill opacity of the marks.\n\n__Default value:__ If undefined, the default opacity depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `fillOpacity` property." }, "href": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" } ], "description": "A URL to load upon mouse click." }, "key": { "$ref": "#/definitions/FieldDefWithoutScale", "description": "A data field to use as a unique key for data binding. When a visualization’s data is updated, the key value will be used to match data elements to existing mark instances. Use a key channel to enable object constancy for transitions over dynamic data." }, "latitude": { "$ref": "#/definitions/LatLongDef", "description": "Latitude position of geographically projected marks." }, "latitude2": { "$ref": "#/definitions/Position2Def", "description": "Latitude-2 position for geographically projected ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`." }, "longitude": { "$ref": "#/definitions/LatLongDef", "description": "Longitude position of geographically projected marks." }, "longitude2": { "$ref": "#/definitions/Position2Def", "description": "Longitude-2 position for geographically projected ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`." }, "opacity": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Opacity of the marks.\n\n__Default value:__ If undefined, the default opacity depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `opacity` property." 
}, "order": { "anyOf": [ { "$ref": "#/definitions/OrderFieldDef" }, { "items": { "$ref": "#/definitions/OrderFieldDef" }, "type": "array" }, { "$ref": "#/definitions/OrderValueDef" }, { "$ref": "#/definitions/OrderOnlyDef" } ], "description": "Order of the marks.\n- For stacked marks, this `order` channel encodes [stack order](https://vega.github.io/vega-lite/docs/stack.html#order).\n- For line and trail marks, this `order` channel encodes order of data points in the lines. This can be useful for creating [a connected scatterplot](https://vega.github.io/vega-lite/examples/connected_scatterplot.html). Setting `order` to `{\"value\": null}` makes the line marks use the original order in the data sources.\n- Otherwise, this `order` channel encodes layer order of the marks.\n\n__Note__: In aggregate plots, `order` field should be `aggregate`d to avoid creating additional aggregation grouping." }, "radius": { "$ref": "#/definitions/PolarDef", "description": "The outer radius in pixels of arc marks." }, "radius2": { "$ref": "#/definitions/Position2Def", "description": "The inner radius in pixels of arc marks." }, "row": { "$ref": "#/definitions/RowColumnEncodingFieldDef", "description": "A field definition for the vertical facet of trellis plots." }, "shape": { "$ref": "#/definitions/ShapeDef", "description": "Shape of the mark.\n\n1. For `point` marks the supported values include: - plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`. - the line symbol `\"stroke\"` - centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"` - a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n2. 
For `geoshape` marks it should be a field definition of the geojson data\n\n__Default value:__ If undefined, the default shape depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#point-config)'s `shape` property. (`\"circle\"` if unset.)" }, "size": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Size of the mark.\n- For `\"point\"`, `\"square\"` and `\"circle\"`, – the symbol size, or pixel area of the mark.\n- For `\"bar\"` and `\"tick\"` – the bar and tick's size.\n- For `\"text\"` – the text's font size.\n- Size is unsupported for `\"line\"`, `\"area\"`, and `\"rect\"`. (Use `\"trail\"` instead of line with varying size)" }, "stroke": { "$ref": "#/definitions/ColorDef", "description": "Stroke color of the marks. __Default value:__ If undefined, the default color depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `color` property.\n\n_Note:_ The `stroke` encoding has higher precedence than `color`, thus may override the `color` encoding if conflicting encodings are specified." }, "strokeDash": { "$ref": "#/definitions/NumericArrayMarkPropDef", "description": "Stroke dash of the marks.\n\n__Default value:__ `[1,0]` (No dash)." }, "strokeOpacity": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Stroke opacity of the marks.\n\n__Default value:__ If undefined, the default opacity depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `strokeOpacity` property." }, "strokeWidth": { "$ref": "#/definitions/NumericMarkPropDef", "description": "Stroke width of the marks.\n\n__Default value:__ If undefined, the default stroke width depends on [mark config](https://vega.github.io/vega-lite/docs/config.html#mark-config)'s `strokeWidth` property." }, "text": { "$ref": "#/definitions/TextDef", "description": "Text of the `text` mark." 
}, "theta": { "$ref": "#/definitions/PolarDef", "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians." }, "theta2": { "$ref": "#/definitions/Position2Def", "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "tooltip": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" }, { "items": { "$ref": "#/definitions/StringFieldDef" }, "type": "array" }, { "type": "null" } ], "description": "The tooltip text to show upon mouse hover. Specifying `tooltip` encoding overrides [the `tooltip` property in the mark definition](https://vega.github.io/vega-lite/docs/mark.html#mark-def).\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite." }, "url": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" } ], "description": "The URL of an image mark." }, "x": { "$ref": "#/definitions/PositionDef", "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "$ref": "#/definitions/Position2Def", "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "xError": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/ValueDef" } ], "description": "Error value of x coordinates for error specified `\"errorbar\"` and `\"errorband\"`." 
}, "xError2": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/ValueDef" } ], "description": "Secondary error value of x coordinates for error specified `\"errorbar\"` and `\"errorband\"`." }, "xOffset": { "$ref": "#/definitions/OffsetDef", "description": "Offset of x-position of the marks" }, "y": { "$ref": "#/definitions/PositionDef", "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "$ref": "#/definitions/Position2Def", "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "yError": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/ValueDef" } ], "description": "Error value of y coordinates for error specified `\"errorbar\"` and `\"errorband\"`." }, "yError2": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/ValueDef" } ], "description": "Secondary error value of y coordinates for error specified `\"errorbar\"` and `\"errorband\"`." }, "yOffset": { "$ref": "#/definitions/OffsetDef", "description": "Offset of y-position of the marks" } }, "type": "object" }, "FacetedUnitSpec": { "additionalProperties": false, "description": "Unit spec that can have a composite mark and row or column channels (shorthand for a facet spec).", "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. 
The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." 
}, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "encoding": { "$ref": "#/definitions/FacetedEncoding", "description": "A key-value mapping between encoding channels and definition of fields." }, "height": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The height of a visualization.\n\n- For a plot with a continuous y-field, height should be a number.\n- For a plot with either a discrete y-field or no y-field, height can be either a number indicating a fixed height or an object in the form of `{step: number}` defining the height per discrete step. (No y-field is equivalent to having one discrete step.)\n- To enable responsive sizing on height, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousHeight` for a plot with a continuous y-field and `config.view.discreteHeight` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the height of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`height`](https://vega.github.io/vega-lite/docs/size.html) documentation." }, "mark": { "$ref": "#/definitions/AnyMark", "description": "A string describing the mark type (one of `\"bar\"`, `\"circle\"`, `\"square\"`, `\"tick\"`, `\"line\"`, `\"area\"`, `\"point\"`, `\"rule\"`, `\"geoshape\"`, and `\"text\"`) or a [mark definition object](https://vega.github.io/vega-lite/docs/mark.html#mark-def)." 
}, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "params": { "description": "An array of parameters that may either be simple variables, or more complex selections that map user input to data queries.", "items": { "$ref": "#/definitions/SelectionParameter" }, "type": "array" }, "projection": { "$ref": "#/definitions/Projection", "description": "An object defining properties of geographic projection, which will be applied to `shape` path for `\"geoshape\"` marks and to `latitude` and `\"longitude\"` channels for other marks." }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." 
}, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "view": { "$ref": "#/definitions/ViewBackground", "description": "An object defining the view background's fill and stroke.\n\n__Default value:__ none (transparent)" }, "width": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The width of a visualization.\n\n- For a plot with a continuous x-field, width should be a number.\n- For a plot with either a discrete x-field or no x-field, width can be either a number indicating a fixed width or an object in the form of `{step: number}` defining the width per discrete step. (No x-field is equivalent to having one discrete step.)\n- To enable responsive sizing on width, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousWidth` for a plot with a continuous x-field and `config.view.discreteWidth` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the width of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`width`](https://vega.github.io/vega-lite/docs/size.html) documentation." } }, "required": [ "mark" ], "type": "object" }, "Feature": { "additionalProperties": false, "description": "A feature object which contains a geometry and associated properties. https://tools.ietf.org/html/rfc7946#section-3.2", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. 
https://tools.ietf.org/html/rfc7946#section-5" }, "geometry": { "$ref": "#/definitions/Geometry", "description": "The feature's geometry" }, "id": { "description": "A value that uniquely identifies this feature in a https://tools.ietf.org/html/rfc7946#section-3.2.", "type": [ "string", "number" ] }, "properties": { "$ref": "#/definitions/GeoJsonProperties", "description": "Properties associated with this feature." }, "type": { "const": "Feature", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "geometry", "properties", "type" ], "type": "object" }, "Feature": { "additionalProperties": false, "description": "A feature object which contains a geometry and associated properties. https://tools.ietf.org/html/rfc7946#section-3.2", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. https://tools.ietf.org/html/rfc7946#section-5" }, "geometry": { "$ref": "#/definitions/Geometry", "description": "The feature's geometry" }, "id": { "description": "A value that uniquely identifies this feature in a https://tools.ietf.org/html/rfc7946#section-3.2.", "type": [ "string", "number" ] }, "properties": { "$ref": "#/definitions/GeoJsonProperties", "description": "Properties associated with this feature." }, "type": { "const": "Feature", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "geometry", "properties", "type" ], "type": "object" }, "FeatureCollection": { "additionalProperties": false, "description": "A collection of feature objects. https://tools.ietf.org/html/rfc7946#section-3.3", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. 
https://tools.ietf.org/html/rfc7946#section-5" }, "features": { "items": { "$ref": "#/definitions/Feature" }, "type": "array" }, "type": { "const": "FeatureCollection", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "features", "type" ], "type": "object" }, "Field": { "anyOf": [ { "$ref": "#/definitions/FieldName" }, { "$ref": "#/definitions/RepeatRef" } ] }, "FieldDefWithoutScale": { "$ref": "#/definitions/TypedFieldDef", "description": "Field Def without scale (and without bin: \"binned\" support)." }, "FieldEqualPredicate": { "additionalProperties": false, "properties": { "equal": { "anyOf": [ { "type": "string" }, { "type": "number" }, { "type": "boolean" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The value that the field should be equal to." }, "field": { "$ref": "#/definitions/FieldName", "description": "Field to be tested." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit for the field to be tested." } }, "required": [ "equal", "field" ], "type": "object" }, "FieldGTEPredicate": { "additionalProperties": false, "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "Field to be tested." }, "gte": { "anyOf": [ { "type": "string" }, { "type": "number" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The value that the field should be greater than or equals to." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit for the field to be tested." 
} }, "required": [ "field", "gte" ], "type": "object" }, "FieldGTPredicate": { "additionalProperties": false, "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "Field to be tested." }, "gt": { "anyOf": [ { "type": "string" }, { "type": "number" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The value that the field should be greater than." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit for the field to be tested." } }, "required": [ "field", "gt" ], "type": "object" }, "FieldLTEPredicate": { "additionalProperties": false, "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "Field to be tested." }, "lte": { "anyOf": [ { "type": "string" }, { "type": "number" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The value that the field should be less than or equals to." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit for the field to be tested." } }, "required": [ "field", "lte" ], "type": "object" }, "FieldLTPredicate": { "additionalProperties": false, "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "Field to be tested." }, "lt": { "anyOf": [ { "type": "string" }, { "type": "number" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The value that the field should be less than." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit for the field to be tested." 
} }, "required": [ "field", "lt" ], "type": "object" }, "FieldName": { "type": "string" }, "FieldOneOfPredicate": { "additionalProperties": false, "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "Field to be tested." }, "oneOf": { "anyOf": [ { "items": { "type": "string" }, "type": "array" }, { "items": { "type": "number" }, "type": "array" }, { "items": { "type": "boolean" }, "type": "array" }, { "items": { "$ref": "#/definitions/DateTime" }, "type": "array" } ], "description": "A set of values that the `field`'s value should be a member of, for a data item included in the filtered data." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit for the field to be tested." } }, "required": [ "field", "oneOf" ], "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). 
If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. 
To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition,(string|null)>": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/TypeForShape", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format 
and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." 
}, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a 
[custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldOrDatumDefWithCondition": { "additionalProperties": false, "description": "A FieldDef with Condition { condition: {value: ...}, field: ..., ... }", "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. 
To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(string|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." 
}, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. 
or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "FieldRange": { "additionalProperties": false, "properties": { "field": { "type": "string" } }, "required": [ "field" ], "type": "object" }, "FieldRangePredicate": { "additionalProperties": false, "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "Field to be tested." 
}, "range": { "anyOf": [ { "items": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/DateTime" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ] }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ], "description": "An array of inclusive minimum and maximum values for a field value of a data item to be included in the filtered data.", "maxItems": 2, "minItems": 2 }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit for the field to be tested." } }, "required": [ "field", "range" ], "type": "object" }, "FieldValidPredicate": { "additionalProperties": false, "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "Field to be tested." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit for the field to be tested." }, "valid": { "description": "If set to true the field's value has to be valid, meaning both not `null` and not [`NaN`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/NaN).", "type": "boolean" } }, "required": [ "field", "valid" ], "type": "object" }, "FilterTransform": { "additionalProperties": false, "properties": { "filter": { "$ref": "#/definitions/PredicateComposition", "description": "The `filter` property must be a predication definition, which can take one of the following forms:\n\n1) an [expression](https://vega.github.io/vega-lite/docs/types.html#expression) string, where `datum` can be used to refer to the current data object. 
For example, `{filter: \"datum.b2 > 60\"}` would make the output data include only items that have values in the field `b2` over 60.\n\n2) one of the [field predicates](https://vega.github.io/vega-lite/docs/predicate.html#field-predicate): [`equal`](https://vega.github.io/vega-lite/docs/predicate.html#field-equal-predicate), [`lt`](https://vega.github.io/vega-lite/docs/predicate.html#lt-predicate), [`lte`](https://vega.github.io/vega-lite/docs/predicate.html#lte-predicate), [`gt`](https://vega.github.io/vega-lite/docs/predicate.html#gt-predicate), [`gte`](https://vega.github.io/vega-lite/docs/predicate.html#gte-predicate), [`range`](https://vega.github.io/vega-lite/docs/predicate.html#range-predicate), [`oneOf`](https://vega.github.io/vega-lite/docs/predicate.html#one-of-predicate), or [`valid`](https://vega.github.io/vega-lite/docs/predicate.html#valid-predicate),\n\n3) a [selection predicate](https://vega.github.io/vega-lite/docs/predicate.html#selection-predicate), which defines the names of a selection that the data point should belong to (or a logical composition of selections).\n\n4) a [logical composition](https://vega.github.io/vega-lite/docs/predicate.html#composition) of (1), (2), or (3)." } }, "required": [ "filter" ], "type": "object" }, "Fit": { "anyOf": [ { "$ref": "#/definitions/GeoJsonFeature" }, { "$ref": "#/definitions/GeoJsonFeatureCollection" }, { "items": { "$ref": "#/definitions/GeoJsonFeature" }, "type": "array" } ] }, "FlattenTransform": { "additionalProperties": false, "properties": { "as": { "description": "The output field names for extracted array values.\n\n__Default value:__ The field name of the corresponding array field", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "flatten": { "description": "An array of one or more data fields containing arrays to flatten. If multiple fields are specified, their array values should have a parallel structure, ideally with the same length.
If the lengths of parallel arrays do not match, the longest array will be used with `null` values added for missing entries.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" } }, "required": [ "flatten" ], "type": "object" }, "FoldTransform": { "additionalProperties": false, "properties": { "as": { "description": "The output field names for the key and value properties produced by the fold transform. __Default value:__ `[\"key\", \"value\"]`", "items": { "$ref": "#/definitions/FieldName" }, "maxItems": 2, "minItems": 2, "type": "array" }, "fold": { "description": "An array of data fields indicating the properties to fold.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" } }, "required": [ "fold" ], "type": "object" }, "FontStyle": { "type": "string" }, "FontWeight": { "enum": [ "normal", "bold", "lighter", "bolder", 100, 200, 300, 400, 500, 600, 700, 800, 900 ], "type": [ "string", "number" ] }, "FormatConfig": { "additionalProperties": false, "properties": { "normalizedNumberFormat": { "description": "If normalizedNumberFormatType is not specified, D3 number format for axis labels, text marks, and tooltips of normalized stacked fields (fields with `stack: \"normalize\"`). For example `\"s\"` for SI units. Use [D3's number format pattern](https://github.com/d3/d3-format#locale_format).\n\nIf `config.normalizedNumberFormatType` is specified and `config.customFormatTypes` is `true`, this value will be passed as `format` alongside `datum.value` to the `config.numberFormatType` function. __Default value:__ `%`", "type": "string" }, "normalizedNumberFormatType": { "description": "[Custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type) for `config.normalizedNumberFormat`.\n\n__Default value:__ `undefined` -- This is equivalent to calling D3-format, which is exposed as [`format` in Vega-Expression](https://vega.github.io/vega/docs/expressions/#format).
__Note:__ You must also set `customFormatTypes` to `true` to use this feature.", "type": "string" }, "numberFormat": { "description": "If numberFormatType is not specified, D3 number format for guide labels, text marks, and tooltips of non-normalized fields (fields *without* `stack: \"normalize\"`). For example `\"s\"` for SI units. Use [D3's number format pattern](https://github.com/d3/d3-format#locale_format).\n\nIf `config.numberFormatType` is specified and `config.customFormatTypes` is `true`, this value will be passed as `format` alongside `datum.value` to the `config.numberFormatType` function.", "type": "string" }, "numberFormatType": { "description": "[Custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type) for `config.numberFormat`.\n\n__Default value:__ `undefined` -- This is equivalent to calling D3-format, which is exposed as [`format` in Vega-Expression](https://vega.github.io/vega/docs/expressions/#format). __Note:__ You must also set `customFormatTypes` to `true` to use this feature.", "type": "string" }, "timeFormat": { "description": "Default time format for raw time values (without time units) in text marks, legend labels and header labels.\n\n__Default value:__ `\"%b %d, %Y\"` __Note:__ Axes automatically determine the format for each label, so this config does not affect axes.", "type": "string" }, "timeFormatType": { "description": "[Custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type) for `config.timeFormat`.\n\n__Default value:__ `undefined` -- This is equivalent to calling D3-time-format, which is exposed as [`timeFormat` in Vega-Expression](https://vega.github.io/vega/docs/expressions/#timeFormat).
__Note:__ You must also set `customFormatTypes` to `true` and there must *not* be a `timeUnit` defined to use this feature.", "type": "string" } }, "type": "object" }, "Generator": { "anyOf": [ { "$ref": "#/definitions/SequenceGenerator" }, { "$ref": "#/definitions/SphereGenerator" }, { "$ref": "#/definitions/GraticuleGenerator" } ] }, "ConcatSpec": { "additionalProperties": false, "description": "Base interface for a generalized concatenation specification.", "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. 
The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "concat": { "description": "A list of views to be concatenated.", "items": { "$ref": "#/definitions/Spec" }, "type": "array" }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." 
}, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" } }, "required": [ "concat" ], "type": "object" }, "FacetSpec": { "additionalProperties": false, "description": "Base interface for a facet specification.", "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. 
One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." 
}, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "facet": { "anyOf": [ { "$ref": "#/definitions/FacetFieldDef" }, { "$ref": "#/definitions/FacetMapping" } ], "description": "Definition for how to facet the data. One of: 1) [a field definition for faceting the plot by one field](https://vega.github.io/vega-lite/docs/facet.html#field-def) 2) [An object that maps `row` and `column` channels to their field definitions](https://vega.github.io/vega-lite/docs/facet.html#mapping)" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "spec": { "anyOf": [ { "$ref": "#/definitions/LayerSpec" }, { "$ref": "#/definitions/FacetedUnitSpec" } ], "description": "A specification of the view that gets faceted." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." 
}, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" } }, "required": [ "facet", "spec" ], "type": "object" }, "HConcatSpec": { "additionalProperties": false, "description": "Base interface for a horizontal concatenation specification.", "properties": { "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\n__Default value:__ `false`", "type": "boolean" }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "hconcat": { "description": "A list of views to be concatenated and put into a row.", "items": { "$ref": "#/definitions/Spec" }, "type": "array" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." 
}, "spacing": { "description": "The spacing in pixels between sub-views of the concat operator.\n\n__Default value__: `10`", "type": "number" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" } }, "required": [ "hconcat" ], "type": "object" }, "Spec": { "anyOf": [ { "$ref": "#/definitions/FacetedUnitSpec" }, { "$ref": "#/definitions/LayerSpec" }, { "$ref": "#/definitions/RepeatSpec" }, { "$ref": "#/definitions/FacetSpec" }, { "$ref": "#/definitions/ConcatSpec" }, { "$ref": "#/definitions/VConcatSpec" }, { "$ref": "#/definitions/HConcatSpec" } ], "description": "Any specification in Vega-Lite." }, "GenericUnitSpec": { "additionalProperties": false, "description": "Base interface for a unit (single-view) specification.", "properties": { "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "encoding": { "$ref": "#/definitions/Encoding", "description": "A key-value mapping between encoding channels and definition of fields." }, "mark": { "$ref": "#/definitions/AnyMark", "description": "A string describing the mark type (one of `\"bar\"`, `\"circle\"`, `\"square\"`, `\"tick\"`, `\"line\"`, `\"area\"`, `\"point\"`, `\"rule\"`, `\"geoshape\"`, and `\"text\"`) or a [mark definition object](https://vega.github.io/vega-lite/docs/mark.html#mark-def)." 
}, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "params": { "description": "An array of parameters that may either be simple variables, or more complex selections that map user input to data queries.", "items": { "$ref": "#/definitions/SelectionParameter" }, "type": "array" }, "projection": { "$ref": "#/definitions/Projection", "description": "An object defining properties of geographic projection, which will be applied to `shape` path for `\"geoshape\"` marks and to `latitude` and `\"longitude\"` channels for other marks." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" } }, "required": [ "mark" ], "type": "object" }, "VConcatSpec": { "additionalProperties": false, "description": "Base interface for a vertical concatenation specification.", "properties": { "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\n__Default value:__ `false`", "type": "boolean" }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. 
If no data is set, it is derived from the parent." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "description": "The spacing in pixels between sub-views of the concat operator.\n\n__Default value__: `10`", "type": "number" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "vconcat": { "description": "A list of views to be concatenated and put into a column.", "items": { "$ref": "#/definitions/Spec" }, "type": "array" } }, "required": [ "vconcat" ], "type": "object" }, "GeoJsonFeature": { "$ref": "#/definitions/Feature" }, "GeoJsonFeatureCollection": { "$ref": "#/definitions/FeatureCollection" }, "GeoJsonProperties": { "anyOf": [ { "type": "object" }, { "type": "null" } ] }, "Geometry": { "anyOf": [ { "$ref": "#/definitions/Point" }, { "$ref": "#/definitions/MultiPoint" }, { "$ref": "#/definitions/LineString" }, { "$ref": "#/definitions/MultiLineString" }, { "$ref": "#/definitions/Polygon" }, { "$ref": "#/definitions/MultiPolygon" }, { "$ref": "#/definitions/GeometryCollection" } ], "description": "Union of geometry objects. https://tools.ietf.org/html/rfc7946#section-3" }, "GeometryCollection": { "additionalProperties": false, "description": "Geometry Collection https://tools.ietf.org/html/rfc7946#section-3.1.8", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. 
https://tools.ietf.org/html/rfc7946#section-5" }, "geometries": { "items": { "$ref": "#/definitions/Geometry" }, "type": "array" }, "type": { "const": "GeometryCollection", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "geometries", "type" ], "type": "object" }, "Gradient": { "anyOf": [ { "$ref": "#/definitions/LinearGradient" }, { "$ref": "#/definitions/RadialGradient" } ] }, "GradientStop": { "additionalProperties": false, "properties": { "color": { "$ref": "#/definitions/Color", "description": "The color value at this point in the gradient." }, "offset": { "description": "The offset fraction for the color stop, indicating its position within the gradient.", "type": "number" } }, "required": [ "offset", "color" ], "type": "object" }, "GraticuleGenerator": { "additionalProperties": false, "properties": { "graticule": { "anyOf": [ { "const": true, "type": "boolean" }, { "$ref": "#/definitions/GraticuleParams" } ], "description": "Generate graticule GeoJSON data for geographic reference lines." }, "name": { "description": "Provide a placeholder name and bind data at runtime.", "type": "string" } }, "required": [ "graticule" ], "type": "object" }, "GraticuleParams": { "additionalProperties": false, "properties": { "extent": { "$ref": "#/definitions/Vector2>", "description": "Sets both the major and minor extents to the same values." }, "extentMajor": { "$ref": "#/definitions/Vector2>", "description": "The major extent of the graticule as a two-element array of coordinates." }, "extentMinor": { "$ref": "#/definitions/Vector2>", "description": "The minor extent of the graticule as a two-element array of coordinates." }, "precision": { "description": "The precision of the graticule in degrees.\n\n__Default value:__ `2.5`", "type": "number" }, "step": { "$ref": "#/definitions/Vector2", "description": "Sets both the major and minor step angles to the same values." 
}, "stepMajor": { "$ref": "#/definitions/Vector2", "description": "The major step angles of the graticule.\n\n\n__Default value:__ `[90, 360]`" }, "stepMinor": { "$ref": "#/definitions/Vector2", "description": "The minor step angles of the graticule.\n\n__Default value:__ `[10, 10]`" } }, "type": "object" }, "Header": { "additionalProperties": false, "description": "Headers of row / column channels for faceted plots.", "properties": { "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. 
One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "labelAlign": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Horizontal text alignment of header labels. One of `\"left\"`, `\"center\"`, or `\"right\"`." }, "labelAnchor": { "$ref": "#/definitions/TitleAnchor", "description": "The anchor position for placing the labels. One of `\"start\"`, `\"middle\"`, or `\"end\"`. For example, with a label orientation of top these anchor positions map to a left-, center-, or right-aligned label." }, "labelAngle": { "description": "The rotation angle of the header labels.\n\n__Default value:__ `0` for column header, `-90` for row header.", "maximum": 360, "minimum": -360, "type": "number" }, "labelBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The vertical text baseline for the header labels. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `titleLineHeight` rather than `titleFontSize` alone." }, "labelColor": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The color of the header label, can be in hex color code or regular color name." 
}, "labelExpr": { "description": "[Vega expression](https://vega.github.io/vega/docs/expressions/) for customizing labels.\n\n__Note:__ The label text and value can be accessed via the `label` and `value` properties of the header's backing `datum` object.", "type": "string" }, "labelFont": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font of the header label." }, "labelFontSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font size of the header label, in pixels.", "minimum": 0 }, "labelFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font style of the header label." }, "labelFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font weight of the header label." }, "labelLimit": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The maximum length of the header label in pixels. The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0`, indicating no limit" }, "labelLineHeight": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Line height in pixels for multi-line header labels or title text with `\"line-top\"` or `\"line-bottom\"` baseline." }, "labelOrient": { "$ref": "#/definitions/Orient", "description": "The orientation of the header label. One of `\"top\"`, `\"bottom\"`, `\"left\"` or `\"right\"`."
}, "labelPadding": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The padding, in pixel, between facet header's label and the plot.\n\n__Default value:__ `10`" }, "labels": { "description": "A boolean flag indicating if labels should be included as part of the header.\n\n__Default value:__ `true`.", "type": "boolean" }, "orient": { "$ref": "#/definitions/Orient", "description": "Shortcut for setting both labelOrient and titleOrient." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "titleAlign": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Horizontal text alignment (to the anchor) of header titles." }, "titleAnchor": { "$ref": "#/definitions/TitleAnchor", "description": "The anchor position for placing the title. One of `\"start\"`, `\"middle\"`, or `\"end\"`. 
For example, with an orientation of top these anchor positions map to a left-, center-, or right-aligned title." }, "titleAngle": { "description": "The rotation angle of the header title.\n\n__Default value:__ `0`.", "maximum": 360, "minimum": -360, "type": "number" }, "titleBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The vertical text baseline for the header title. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `titleLineHeight` rather than `titleFontSize` alone.\n\n__Default value:__ `\"middle\"`" }, "titleColor": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Color of the header title, can be in hex color code or regular color name." }, "titleFont": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Font of the header title. (e.g., `\"Helvetica Neue\"`)." }, "titleFontSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Font size of the header title.", "minimum": 0 }, "titleFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font style of the header title." }, "titleFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Font weight of the header title. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, "titleLimit": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The maximum length of the header title in pixels. 
The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0`, indicating no limit" }, "titleLineHeight": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Line height in pixels for multi-line header title text or title text with `\"line-top\"` or `\"line-bottom\"` baseline." }, "titleOrient": { "$ref": "#/definitions/Orient", "description": "The orientation of the header title. One of `\"top\"`, `\"bottom\"`, `\"left\"` or `\"right\"`." }, "titlePadding": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The padding, in pixel, between facet header's title and the label.\n\n__Default value:__ `10`" } }, "type": "object" }, "HeaderConfig": { "additionalProperties": false, "properties": { "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." 
}, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "labelAlign": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Horizontal text alignment of header labels. One of `\"left\"`, `\"center\"`, or `\"right\"`." }, "labelAnchor": { "$ref": "#/definitions/TitleAnchor", "description": "The anchor position for placing the labels. One of `\"start\"`, `\"middle\"`, or `\"end\"`. For example, with a label orientation of top these anchor positions map to a left-, center-, or right-aligned label." }, "labelAngle": { "description": "The rotation angle of the header labels.\n\n__Default value:__ `0` for column header, `-90` for row header.", "maximum": 360, "minimum": -360, "type": "number" }, "labelBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The vertical text baseline for the header labels. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `titleLineHeight` rather than `titleFontSize` alone." }, "labelColor": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The color of the header label, can be in hex color code or regular color name." 
}, "labelExpr": { "description": "[Vega expression](https://vega.github.io/vega/docs/expressions/) for customizing labels.\n\n__Note:__ The label text and value can be assessed via the `label` and `value` properties of the header's backing `datum` object.", "type": "string" }, "labelFont": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font of the header label." }, "labelFontSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font size of the header label, in pixels.", "minimum": 0 }, "labelFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font style of the header label." }, "labelFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font weight of the header label." }, "labelLimit": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The maximum length of the header label in pixels. The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0`, indicating no limit" }, "labelLineHeight": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Line height in pixels for multi-line header labels or title text with `\"line-top\"` or `\"line-bottom\"` baseline." }, "labelOrient": { "$ref": "#/definitions/Orient", "description": "The orientation of the header label. One of `\"top\"`, `\"bottom\"`, `\"left\"` or `\"right\"`." 
}, "labelPadding": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The padding, in pixel, between facet header's label and the plot.\n\n__Default value:__ `10`" }, "labels": { "description": "A boolean flag indicating if labels should be included as part of the header.\n\n__Default value:__ `true`.", "type": "boolean" }, "orient": { "$ref": "#/definitions/Orient", "description": "Shortcut for setting both labelOrient and titleOrient." }, "title": { "description": "Set to null to disable title for the axis, legend, or header.", "type": "null" }, "titleAlign": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Horizontal text alignment (to the anchor) of header titles." }, "titleAnchor": { "$ref": "#/definitions/TitleAnchor", "description": "The anchor position for placing the title. One of `\"start\"`, `\"middle\"`, or `\"end\"`. For example, with an orientation of top these anchor positions map to a left-, center-, or right-aligned title." }, "titleAngle": { "description": "The rotation angle of the header title.\n\n__Default value:__ `0`.", "maximum": 360, "minimum": -360, "type": "number" }, "titleBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The vertical text baseline for the header title. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `titleLineHeight` rather than `titleFontSize` alone.\n\n__Default value:__ `\"middle\"`" }, "titleColor": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Color of the header title, can be in hex color code or regular color name." 
}, "titleFont": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Font of the header title. (e.g., `\"Helvetica Neue\"`)." }, "titleFontSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Font size of the header title.", "minimum": 0 }, "titleFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The font style of the header title." }, "titleFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Font weight of the header title. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, "titleLimit": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The maximum length of the header title in pixels. The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0`, indicating no limit" }, "titleLineHeight": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Line height in pixels for multi-line header title text or title text with `\"line-top\"` or `\"line-bottom\"` baseline." }, "titleOrient": { "$ref": "#/definitions/Orient", "description": "The orientation of the header title. One of `\"top\"`, `\"bottom\"`, `\"left\"` or `\"right\"`." 
}, "titlePadding": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The padding, in pixel, between facet header's title and the label.\n\n__Default value:__ `10`" } }, "type": "object" }, "HexColor": { "format": "color-hex", "type": "string" }, "ImputeMethod": { "enum": [ "value", "median", "max", "min", "mean" ], "type": "string" }, "ImputeParams": { "additionalProperties": false, "properties": { "frame": { "description": "A frame specification as a two-element array used to control the window over which the specified method is applied. The array entries should either be a number indicating the offset from the current data object, or null to indicate unbounded rows preceding or following the current data object. For example, the value `[-5, 5]` indicates that the window should include five objects preceding and five objects following the current object.\n\n__Default value:__: `[null, null]` indicating that the window includes all objects.", "items": { "type": [ "null", "number" ] }, "maxItems": 2, "minItems": 2, "type": "array" }, "keyvals": { "anyOf": [ { "items": {}, "type": "array" }, { "$ref": "#/definitions/ImputeSequence" } ], "description": "Defines the key values that should be considered for imputation. An array of key values or an object defining a [number sequence](https://vega.github.io/vega-lite/docs/impute.html#sequence-def).\n\nIf provided, this will be used in addition to the key values observed within the input data. If not provided, the values will be derived from all unique values of the `key` field. For `impute` in `encoding`, the key field is the x-field if the y-field is imputed, or vice versa.\n\nIf there is no impute grouping, this property _must_ be specified." }, "method": { "$ref": "#/definitions/ImputeMethod", "description": "The imputation method to use for the field value of imputed data objects. 
One of `\"value\"`, `\"mean\"`, `\"median\"`, `\"max\"` or `\"min\"`.\n\n__Default value:__ `\"value\"`" }, "value": { "description": "The field value to use when the imputation `method` is `\"value\"`." } }, "type": "object" }, "ImputeSequence": { "additionalProperties": false, "properties": { "start": { "description": "The starting value of the sequence. __Default value:__ `0`", "type": "number" }, "step": { "description": "The step value between sequence entries. __Default value:__ `1` or `-1` if `stop < start`", "type": "number" }, "stop": { "description": "The ending value(exclusive) of the sequence.", "type": "number" } }, "required": [ "stop" ], "type": "object" }, "ImputeTransform": { "additionalProperties": false, "properties": { "frame": { "description": "A frame specification as a two-element array used to control the window over which the specified method is applied. The array entries should either be a number indicating the offset from the current data object, or null to indicate unbounded rows preceding or following the current data object. For example, the value `[-5, 5]` indicates that the window should include five objects preceding and five objects following the current object.\n\n__Default value:__: `[null, null]` indicating that the window includes all objects.", "items": { "type": [ "null", "number" ] }, "maxItems": 2, "minItems": 2, "type": "array" }, "groupby": { "description": "An optional array of fields by which to group the values. Imputation will then be performed on a per-group basis.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "impute": { "$ref": "#/definitions/FieldName", "description": "The data field for which the missing values should be imputed." }, "key": { "$ref": "#/definitions/FieldName", "description": "A key field that uniquely identifies data objects within a group. Missing key values (those occurring in the data but not in the current group) will be imputed." 
}, "keyvals": { "anyOf": [ { "items": {}, "type": "array" }, { "$ref": "#/definitions/ImputeSequence" } ], "description": "Defines the key values that should be considered for imputation. An array of key values or an object defining a [number sequence](https://vega.github.io/vega-lite/docs/impute.html#sequence-def).\n\nIf provided, this will be used in addition to the key values observed within the input data. If not provided, the values will be derived from all unique values of the `key` field. For `impute` in `encoding`, the key field is the x-field if the y-field is imputed, or vice versa.\n\nIf there is no impute grouping, this property _must_ be specified." }, "method": { "$ref": "#/definitions/ImputeMethod", "description": "The imputation method to use for the field value of imputed data objects. One of `\"value\"`, `\"mean\"`, `\"median\"`, `\"max\"` or `\"min\"`.\n\n__Default value:__ `\"value\"`" }, "value": { "description": "The field value to use when the imputation `method` is `\"value\"`." } }, "required": [ "impute", "key" ], "type": "object" }, "InlineData": { "additionalProperties": false, "properties": { "format": { "$ref": "#/definitions/DataFormat", "description": "An object that specifies the format for parsing the data." }, "name": { "description": "Provide a placeholder name and bind data at runtime.", "type": "string" }, "values": { "$ref": "#/definitions/InlineDataset", "description": "The full data set, included inline. This can be an array of objects or primitive values, an object, or a string. Arrays of primitive values are ingested as objects with a `data` property. Strings are parsed according to the specified format type." 
} }, "required": [ "values" ], "type": "object" }, "InlineDataset": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "items": { "type": "string" }, "type": "array" }, { "items": { "type": "boolean" }, "type": "array" }, { "items": { "type": "object" }, "type": "array" }, { "type": "string" }, { "type": "object" } ] }, "Interpolate": { "enum": [ "basis", "basis-open", "basis-closed", "bundle", "cardinal", "cardinal-open", "cardinal-closed", "catmull-rom", "linear", "linear-closed", "monotone", "natural", "step", "step-before", "step-after" ], "type": "string" }, "IntervalSelectionConfig": { "additionalProperties": false, "properties": { "clear": { "anyOf": [ { "$ref": "#/definitions/Stream" }, { "type": "string" }, { "type": "boolean" } ], "description": "Clears the selection, emptying it of all values. This property can be a [Event Stream](https://vega.github.io/vega/docs/event-streams/) or `false` to disable clear.\n\n__Default value:__ `dblclick`.\n\n__See also:__ [`clear` examples ](https://vega.github.io/vega-lite/docs/selection.html#clear) in the documentation." }, "encodings": { "description": "An array of encoding channels. The corresponding data field values must match for a data tuple to fall within the selection.\n\n__See also:__ The [projection with `encodings` and `fields` section](https://vega.github.io/vega-lite/docs/selection.html#project) in the documentation.", "items": { "$ref": "#/definitions/SingleDefUnitChannel" }, "type": "array" }, "fields": { "description": "An array of field names whose values must match for a data tuple to fall within the selection.\n\n__See also:__ The [projection with `encodings` and `fields` section](https://vega.github.io/vega-lite/docs/selection.html#project) in the documentation.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "mark": { "$ref": "#/definitions/BrushConfig", "description": "An interval selection also adds a rectangle mark to depict the extents of the interval. 
The `mark` property can be used to customize the appearance of the mark.\n\n__See also:__ [`mark` examples](https://vega.github.io/vega-lite/docs/selection.html#mark) in the documentation." }, "on": { "anyOf": [ { "$ref": "#/definitions/Stream" }, { "type": "string" } ], "description": "A [Vega event stream](https://vega.github.io/vega/docs/event-streams/) (object or selector) that triggers the selection. For interval selections, the event stream must specify a [start and end](https://vega.github.io/vega/docs/event-streams/#between-filters).\n\n__See also:__ [`on` examples](https://vega.github.io/vega-lite/docs/selection.html#on) in the documentation." }, "resolve": { "$ref": "#/definitions/SelectionResolution", "description": "With layered and multi-view displays, a strategy that determines how selections' data queries are resolved when applied in a filter transform, conditional encoding rule, or scale domain.\n\nOne of:\n- `\"global\"` -- only one brush exists for the entire SPLOM. When the user begins to drag, any previous brushes are cleared, and a new one is constructed.\n- `\"union\"` -- each cell contains its own brush, and points are highlighted if they lie within _any_ of these individual brushes.\n- `\"intersect\"` -- each cell contains its own brush, and points are highlighted only if they fall within _all_ of these individual brushes.\n\n__Default value:__ `global`.\n\n__See also:__ [`resolve` examples](https://vega.github.io/vega-lite/docs/selection.html#resolve) in the documentation." }, "translate": { "description": "When truthy, allows a user to interactively move an interval selection back-and-forth. Can be `true`, `false` (to disable panning), or a [Vega event stream definition](https://vega.github.io/vega/docs/event-streams/) which must include a start and end event to trigger continuous panning. 
Discrete panning (e.g., pressing the left/right arrow keys) will be supported in future versions.\n\n__Default value:__ `true`, which corresponds to `[pointerdown, window:pointerup] > window:pointermove!`. This default allows users to clicks and drags within an interval selection to reposition it.\n\n__See also:__ [`translate` examples](https://vega.github.io/vega-lite/docs/selection.html#translate) in the documentation.", "type": [ "string", "boolean" ] }, "type": { "const": "interval", "description": "Determines the default event processing and data query for the selection. Vega-Lite currently supports two selection types:\n\n- `\"point\"` -- to select multiple discrete data values; the first value is selected on `click` and additional values toggled on shift-click.\n- `\"interval\"` -- to select a continuous range of data values on `drag`.", "type": "string" }, "zoom": { "description": "When truthy, allows a user to interactively resize an interval selection. Can be `true`, `false` (to disable zooming), or a [Vega event stream definition](https://vega.github.io/vega/docs/event-streams/). Currently, only `wheel` events are supported, but custom event streams can still be used to specify filters, debouncing, and throttling. Future versions will expand the set of events that can trigger this transformation.\n\n__Default value:__ `true`, which corresponds to `wheel!`. This default allows users to use the mouse wheel to resize an interval selection.\n\n__See also:__ [`zoom` examples](https://vega.github.io/vega-lite/docs/selection.html#zoom) in the documentation.", "type": [ "string", "boolean" ] } }, "required": [ "type" ], "type": "object" }, "IntervalSelectionConfigWithoutType": { "additionalProperties": false, "properties": { "clear": { "anyOf": [ { "$ref": "#/definitions/Stream" }, { "type": "string" }, { "type": "boolean" } ], "description": "Clears the selection, emptying it of all values. 
This property can be a [Event Stream](https://vega.github.io/vega/docs/event-streams/) or `false` to disable clear.\n\n__Default value:__ `dblclick`.\n\n__See also:__ [`clear` examples ](https://vega.github.io/vega-lite/docs/selection.html#clear) in the documentation." }, "encodings": { "description": "An array of encoding channels. The corresponding data field values must match for a data tuple to fall within the selection.\n\n__See also:__ The [projection with `encodings` and `fields` section](https://vega.github.io/vega-lite/docs/selection.html#project) in the documentation.", "items": { "$ref": "#/definitions/SingleDefUnitChannel" }, "type": "array" }, "fields": { "description": "An array of field names whose values must match for a data tuple to fall within the selection.\n\n__See also:__ The [projection with `encodings` and `fields` section](https://vega.github.io/vega-lite/docs/selection.html#project) in the documentation.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "mark": { "$ref": "#/definitions/BrushConfig", "description": "An interval selection also adds a rectangle mark to depict the extents of the interval. The `mark` property can be used to customize the appearance of the mark.\n\n__See also:__ [`mark` examples](https://vega.github.io/vega-lite/docs/selection.html#mark) in the documentation." }, "on": { "anyOf": [ { "$ref": "#/definitions/Stream" }, { "type": "string" } ], "description": "A [Vega event stream](https://vega.github.io/vega/docs/event-streams/) (object or selector) that triggers the selection. For interval selections, the event stream must specify a [start and end](https://vega.github.io/vega/docs/event-streams/#between-filters).\n\n__See also:__ [`on` examples](https://vega.github.io/vega-lite/docs/selection.html#on) in the documentation." 
}, "resolve": { "$ref": "#/definitions/SelectionResolution", "description": "With layered and multi-view displays, a strategy that determines how selections' data queries are resolved when applied in a filter transform, conditional encoding rule, or scale domain.\n\nOne of:\n- `\"global\"` -- only one brush exists for the entire SPLOM. When the user begins to drag, any previous brushes are cleared, and a new one is constructed.\n- `\"union\"` -- each cell contains its own brush, and points are highlighted if they lie within _any_ of these individual brushes.\n- `\"intersect\"` -- each cell contains its own brush, and points are highlighted only if they fall within _all_ of these individual brushes.\n\n__Default value:__ `global`.\n\n__See also:__ [`resolve` examples](https://vega.github.io/vega-lite/docs/selection.html#resolve) in the documentation." }, "translate": { "description": "When truthy, allows a user to interactively move an interval selection back-and-forth. Can be `true`, `false` (to disable panning), or a [Vega event stream definition](https://vega.github.io/vega/docs/event-streams/) which must include a start and end event to trigger continuous panning. Discrete panning (e.g., pressing the left/right arrow keys) will be supported in future versions.\n\n__Default value:__ `true`, which corresponds to `[pointerdown, window:pointerup] > window:pointermove!`. This default allows users to clicks and drags within an interval selection to reposition it.\n\n__See also:__ [`translate` examples](https://vega.github.io/vega-lite/docs/selection.html#translate) in the documentation.", "type": [ "string", "boolean" ] }, "zoom": { "description": "When truthy, allows a user to interactively resize an interval selection. Can be `true`, `false` (to disable zooming), or a [Vega event stream definition](https://vega.github.io/vega/docs/event-streams/). 
Currently, only `wheel` events are supported, but custom event streams can still be used to specify filters, debouncing, and throttling. Future versions will expand the set of events that can trigger this transformation.\n\n__Default value:__ `true`, which corresponds to `wheel!`. This default allows users to use the mouse wheel to resize an interval selection.\n\n__See also:__ [`zoom` examples](https://vega.github.io/vega-lite/docs/selection.html#zoom) in the documentation.", "type": [ "string", "boolean" ] } }, "type": "object" }, "JoinAggregateFieldDef": { "additionalProperties": false, "properties": { "as": { "$ref": "#/definitions/FieldName", "description": "The output name for the join aggregate operation." }, "field": { "$ref": "#/definitions/FieldName", "description": "The data field for which to compute the aggregate function. This can be omitted for functions that do not operate over a field such as `\"count\"`." }, "op": { "$ref": "#/definitions/AggregateOp", "description": "The aggregation operation to apply (e.g., `\"sum\"`, `\"average\"` or `\"count\"`). See the list of all supported operations [here](https://vega.github.io/vega-lite/docs/aggregate.html#ops)." } }, "required": [ "op", "as" ], "type": "object" }, "JoinAggregateTransform": { "additionalProperties": false, "properties": { "groupby": { "description": "The data fields for partitioning the data objects into separate groups. 
If unspecified, all data points will be in a single group.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "joinaggregate": { "description": "The definition of the fields in the join aggregate, and what calculations to use.", "items": { "$ref": "#/definitions/JoinAggregateFieldDef" }, "type": "array" } }, "required": [ "joinaggregate" ], "type": "object" }, "JsonDataFormat": { "additionalProperties": false, "properties": { "parse": { "anyOf": [ { "$ref": "#/definitions/Parse" }, { "type": "null" } ], "description": "If set to `null`, disable type inference based on the spec and only use type inference based on the data. Alternatively, a parsing directive object can be provided for explicit data types. Each property of the object corresponds to a field name, and the value to the desired data type (one of `\"number\"`, `\"boolean\"`, `\"date\"`, or null (do not parse the field)). For example, `\"parse\": {\"modified_on\": \"date\"}` parses the `modified_on` field in each input record a Date value.\n\nFor `\"date\"`, we parse data based using JavaScript's [`Date.parse()`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date/parse). For Specific date formats can be provided (e.g., `{foo: \"date:'%m%d%Y'\"}`), using the [d3-time-format syntax](https://github.com/d3/d3-time-format#locale_format). UTC date format parsing is supported similarly (e.g., `{foo: \"utc:'%m%d%Y'\"}`). See more about [UTC time](https://vega.github.io/vega-lite/docs/timeunit.html#utc)" }, "property": { "description": "The JSON property containing the desired data. This parameter can be used when the loaded JSON file may have surrounding structure or meta-data. 
For example `\"property\": \"values.features\"` is equivalent to retrieving `json.values.features` from the loaded JSON object.", "type": "string" }, "type": { "const": "json", "description": "Type of input data: `\"json\"`, `\"csv\"`, `\"tsv\"`, `\"dsv\"`.\n\n__Default value:__ The default format type is determined by the extension of the file URL. If no extension is detected, `\"json\"` will be used by default.", "type": "string" } }, "type": "object" }, "LabelOverlap": { "anyOf": [ { "type": "boolean" }, { "const": "parity", "type": "string" }, { "const": "greedy", "type": "string" } ] }, "LatLongDef": { "anyOf": [ { "$ref": "#/definitions/LatLongFieldDef" }, { "$ref": "#/definitions/DatumDef" } ] }, "LatLongFieldDef": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. 
You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. 
If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "const": "quantitative", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation.", "type": "string" } }, "type": "object" }, "LayerRepeatMapping": { "additionalProperties": false, "properties": { "column": { "description": "An array of fields to be repeated horizontally.", "items": { "type": "string" }, "type": "array" }, "layer": { "description": "An array of fields to be repeated as layers.", "items": { "type": "string" }, "type": "array" }, "row": { "description": "An array of fields to be repeated vertically.", "items": { "type": "string" }, "type": "array" } }, "required": [ "layer" ], "type": "object" }, "LayerRepeatSpec": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. 
The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. 
This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "repeat": { "$ref": "#/definitions/LayerRepeatMapping", "description": "Definition for fields to be repeated. One of: 1) An array of fields to be repeated. If `\"repeat\"` is an array, the field can be referred to as `{\"repeat\": \"repeat\"}`. The repeated views are laid out in a wrapped row. You can set the number of columns to control the wrapping. 2) An object that maps `\"row\"` and/or `\"column\"` to the listed fields to be repeated along the particular orientations. The objects `{\"repeat\": \"row\"}` and `{\"repeat\": \"column\"}` can be used to refer to the repeated field respectively." }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. 
An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "spec": { "anyOf": [ { "$ref": "#/definitions/LayerSpec" }, { "$ref": "#/definitions/UnitSpecWithFrame" } ], "description": "A specification of the view that gets repeated." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" } }, "required": [ "repeat", "spec" ], "type": "object" }, "LayerSpec": { "additionalProperties": false, "description": "A full layered plot specification, which may contains `encoding` and `projection` properties that will be applied to underlying unit (single-view) specifications.", "properties": { "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "encoding": { "$ref": "#/definitions/SharedEncoding", "description": "A shared key-value mapping between encoding channels and definition of fields in the underlying layers." 
}, "height": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The height of a visualization.\n\n- For a plot with a continuous y-field, height should be a number.\n- For a plot with either a discrete y-field or no y-field, height can be either a number indicating a fixed height or an object in the form of `{step: number}` defining the height per discrete step. (No y-field is equivalent to having one discrete step.)\n- To enable responsive sizing on height, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousHeight` for a plot with a continuous y-field and `config.view.discreteHeight` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the height of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`height`](https://vega.github.io/vega-lite/docs/size.html) documentation." }, "layer": { "description": "Layer or single view specifications to be layered.\n\n__Note__: Specifications inside `layer` cannot use `row` and `column` channels as layering facet specifications is not allowed. Instead, use the [facet operator](https://vega.github.io/vega-lite/docs/facet.html) and place a layer inside a facet.", "items": { "anyOf": [ { "$ref": "#/definitions/LayerSpec" }, { "$ref": "#/definitions/UnitSpec" } ] }, "type": "array" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "projection": { "$ref": "#/definitions/Projection", "description": "An object defining properties of the geographic projection shared by underlying layers." }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." 
}, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "view": { "$ref": "#/definitions/ViewBackground", "description": "An object defining the view background's fill and stroke.\n\n__Default value:__ none (transparent)" }, "width": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The width of a visualization.\n\n- For a plot with a continuous x-field, width should be a number.\n- For a plot with either a discrete x-field or no x-field, width can be either a number indicating a fixed width or an object in the form of `{step: number}` defining the width per discrete step. (No x-field is equivalent to having one discrete step.)\n- To enable responsive sizing on width, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousWidth` for a plot with a continuous x-field and `config.view.discreteWidth` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the width of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`width`](https://vega.github.io/vega-lite/docs/size.html) documentation." } }, "required": [ "layer" ], "type": "object" }, "LayoutAlign": { "enum": [ "all", "each", "none" ], "type": "string" }, "Legend": { "additionalProperties": false, "description": "Properties of a legend or boolean flag for determining whether to show it.", "properties": { "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). 
If `false`, the \"aria-hidden\" attribute will be set on the output SVG group, removing the legend from the ARIA accessibility tree.\n\n__Default value:__ `true`", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "clipHeight": { "anyOf": [ { "description": "The height in pixels to clip symbol legend entries and limit their size.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "columnPadding": { "anyOf": [ { "description": "The horizontal padding in pixels between symbol legend entries.\n\n__Default value:__ `10`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "columns": { "anyOf": [ { "description": "The number of columns in which to arrange symbol legend entries. A value of `0` or lower indicates a single row with one column per entry.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadius": { "anyOf": [ { "description": "Corner radius for the full legend.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of this legend for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If the `aria` property is true, for SVG output the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute) will be set to this description. If the description is unspecified it will be automatically generated.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "direction": { "$ref": "#/definitions/Orientation", "description": "The direction of the legend, one of `\"vertical\"` or `\"horizontal\"`.\n\n__Default value:__\n- For top-/bottom-`orient`ed legends, `\"horizontal\"`\n- For left-/right-`orient`ed legends, `\"vertical\"`\n- For top/bottom-left/right-`orient`ed legends, `\"horizontal\"` for gradient legends and `\"vertical\"` for symbol legends." 
}, "fillColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Background fill color for the full legend." }, { "$ref": "#/definitions/ExprRef" } ] }, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "gradientLength": { "anyOf": [ { "description": "The length in pixels of the primary axis of a color gradient. 
This value corresponds to the height of a vertical gradient or the width of a horizontal gradient.\n\n__Default value:__ `200`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientOpacity": { "anyOf": [ { "description": "Opacity of the color gradient.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientStrokeColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the gradient stroke, can be in hex color code or regular color name.\n\n__Default value:__ `\"lightGray\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientStrokeWidth": { "anyOf": [ { "description": "The width of the gradient stroke, in pixels.\n\n__Default value:__ `0`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientThickness": { "anyOf": [ { "description": "The thickness in pixels of the color gradient. This value corresponds to the width of a vertical gradient or the height of a horizontal gradient.\n\n__Default value:__ `16`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gridAlign": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign", "description": "The alignment to apply to symbol legends rows and columns. The supported string values are `\"all\"`, `\"each\"` (the default), and `none`. For more information, see the [grid layout documentation](https://vega.github.io/vega/docs/layout).\n\n__Default value:__ `\"each\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelAlign": { "anyOf": [ { "$ref": "#/definitions/Align", "description": "The alignment of the legend label, can be left, center, or right." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline", "description": "The position of the baseline of legend label, can be `\"top\"`, `\"middle\"`, `\"bottom\"`, or `\"alphabetic\"`.\n\n__Default value:__ `\"middle\"`." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "labelColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the legend label, can be in hex color code or regular color name." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelExpr": { "description": "[Vega expression](https://vega.github.io/vega/docs/expressions/) for customizing labels.\n\n__Note:__ The label text and value can be assessed via the `label` and `value` properties of the legend's backing `datum` object.", "type": "string" }, "labelFont": { "anyOf": [ { "description": "The font of the legend label.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFontSize": { "anyOf": [ { "description": "The font size of legend label.\n\n__Default value:__ `10`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style of legend label." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight of legend label." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelLimit": { "anyOf": [ { "description": "Maximum allowed pixel width of legend tick labels.\n\n__Default value:__ `160`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelOffset": { "anyOf": [ { "description": "The offset of the legend label.\n\n__Default value:__ `4`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelOpacity": { "anyOf": [ { "description": "Opacity of labels.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelOverlap": { "anyOf": [ { "$ref": "#/definitions/LabelOverlap", "description": "The strategy to use for resolving overlap of labels in gradient legends. If `false`, no overlap reduction is attempted. If set to `true` (default) or `\"parity\"`, a strategy of removing every other label is used. 
If set to `\"greedy\"`, a linear scan of the labels is performed, removing any label that overlaps with the last visible label (this often works better for log-scaled axes).\n\n__Default value:__ `true`." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelPadding": { "anyOf": [ { "description": "Padding in pixels between the legend and legend labels.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelSeparation": { "anyOf": [ { "description": "The minimum separation that must be between label bounding boxes for them to be considered non-overlapping (default `0`). This property is ignored if *labelOverlap* resolution is not enabled.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "legendX": { "anyOf": [ { "description": "Custom x-position for legend with orient \"none\".", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "legendY": { "anyOf": [ { "description": "Custom y-position for legend with orient \"none\".", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "offset": { "anyOf": [ { "description": "The offset in pixels by which to displace the legend from the data rectangle and axes.\n\n__Default value:__ `18`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "orient": { "$ref": "#/definitions/LegendOrient", "description": "The orientation of the legend, which determines how the legend is positioned within the scene. 
One of `\"left\"`, `\"right\"`, `\"top\"`, `\"bottom\"`, `\"top-left\"`, `\"top-right\"`, `\"bottom-left\"`, `\"bottom-right\"`, `\"none\"`.\n\n__Default value:__ `\"right\"`" }, "padding": { "anyOf": [ { "description": "The padding between the border and content of the legend group.\n\n__Default value:__ `0`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "rowPadding": { "anyOf": [ { "description": "The vertical padding in pixels between symbol legend entries.\n\n__Default value:__ `2`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Border stroke color for the full legend." }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolDash": { "anyOf": [ { "description": "An array of alternating [stroke, space] lengths for dashed symbol strokes.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolDashOffset": { "anyOf": [ { "description": "The pixel offset at which to start drawing with the symbol stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolFillColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the legend symbol," }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolLimit": { "anyOf": [ { "description": "The maximum number of allowed entries for a symbol legend. 
Additional entries will be dropped.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolOffset": { "anyOf": [ { "description": "Horizontal pixel offset for legend symbols.\n\n__Default value:__ `0`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolOpacity": { "anyOf": [ { "description": "Opacity of the legend symbols.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolSize": { "anyOf": [ { "description": "The size of the legend symbol, in pixels.\n\n__Default value:__ `100`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolStrokeColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Stroke color for legend symbols." }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolStrokeWidth": { "anyOf": [ { "description": "The width of the symbol's stroke.\n\n__Default value:__ `1.5`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolType": { "anyOf": [ { "$ref": "#/definitions/SymbolShape", "description": "The symbol shape. One of the plotting shapes `circle` (default), `square`, `cross`, `diamond`, `triangle-up`, `triangle-down`, `triangle-right`, or `triangle-left`, the line symbol `stroke`, or one of the centered directional shapes `arrow`, `wedge`, or `triangle`. Alternatively, a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) can be provided. For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.\n\n__Default value:__ `\"circle\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "tickCount": { "anyOf": [ { "$ref": "#/definitions/TickCount", "description": "The desired number of tick values for quantitative legends." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "tickMinStep": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The minimum desired step between legend ticks, in terms of scale domain values. For example, a value of `1` indicates that ticks should not be less than 1 unit apart. If `tickMinStep` is specified, the `tickCount` value will be adjusted, if necessary, to enforce the minimum step value.\n\n__Default value__: `undefined`" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "titleAlign": { "anyOf": [ { "$ref": "#/definitions/Align", "description": "Horizontal text alignment for legend titles.\n\n__Default value:__ `\"left\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleAnchor": { "anyOf": [ { "$ref": "#/definitions/TitleAnchor", "description": "Text anchor position for placing legend titles." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "titleBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline", "description": "Vertical text baseline for legend titles. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the *lineHeight* rather than *fontSize* alone.\n\n__Default value:__ `\"top\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the legend title, can be in hex color code or regular color name." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFont": { "anyOf": [ { "description": "The font of the legend title.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontSize": { "anyOf": [ { "description": "The font size of the legend title.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style of the legend title." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight of the legend title. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "titleLimit": { "anyOf": [ { "description": "Maximum allowed pixel width of legend titles.\n\n__Default value:__ `180`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleLineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line title text or title text with `\"line-top\"` or `\"line-bottom\"` baseline.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleOpacity": { "anyOf": [ { "description": "Opacity of the legend title.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleOrient": { "anyOf": [ { "$ref": "#/definitions/Orient", "description": "Orientation of the legend title." }, { "$ref": "#/definitions/ExprRef" } ] }, "titlePadding": { "anyOf": [ { "description": "The padding, in pixels, between title and legend.\n\n__Default value:__ `5`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "type": { "description": "The type of the legend. Use `\"symbol\"` to create a discrete legend and `\"gradient\"` for a continuous color gradient.\n\n__Default value:__ `\"gradient\"` for non-binned quantitative fields and temporal fields; `\"symbol\"` otherwise.", "enum": [ "symbol", "gradient" ], "type": "string" }, "values": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "items": { "type": "string" }, "type": "array" }, { "items": { "type": "boolean" }, "type": "array" }, { "items": { "$ref": "#/definitions/DateTime" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Explicitly set the visible legend values." }, "zindex": { "description": "A non-negative integer indicating the z-index of the legend. If zindex is 0, legend should be drawn behind all chart elements. 
To put them in front, use zindex = 1.", "minimum": 0, "type": "number" } }, "type": "object" }, "LegendBinding": { "anyOf": [ { "const": "legend", "type": "string" }, { "$ref": "#/definitions/LegendStreamBinding" } ] }, "LegendConfig": { "additionalProperties": false, "properties": { "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG group, removing the legend from the ARIA accessibility tree.\n\n__Default value:__ `true`", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "clipHeight": { "anyOf": [ { "description": "The height in pixels to clip symbol legend entries and limit their size.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "columnPadding": { "anyOf": [ { "description": "The horizontal padding in pixels between symbol legend entries.\n\n__Default value:__ `10`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "columns": { "anyOf": [ { "description": "The number of columns in which to arrange symbol legend entries. A value of `0` or lower indicates a single row with one column per entry.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadius": { "anyOf": [ { "description": "Corner radius for the full legend.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of this legend for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If the `aria` property is true, for SVG output the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute) will be set to this description. 
If the description is unspecified it will be automatically generated.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "direction": { "$ref": "#/definitions/Orientation", "description": "The direction of the legend, one of `\"vertical\"` or `\"horizontal\"`.\n\n__Default value:__\n- For top-/bottom-`orient`ed legends, `\"horizontal\"`\n- For left-/right-`orient`ed legends, `\"vertical\"`\n- For top/bottom-left/right-`orient`ed legends, `\"horizontal\"` for gradient legends and `\"vertical\"` for symbol legends." }, "disable": { "description": "Disable legend by default", "type": "boolean" }, "fillColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Background fill color for the full legend." }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientDirection": { "anyOf": [ { "$ref": "#/definitions/Orientation", "description": "The default direction (`\"horizontal\"` or `\"vertical\"`) for gradient legends.\n\n__Default value:__ `\"vertical\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientHorizontalMaxLength": { "description": "Max legend length for a horizontal gradient when `config.legend.gradientLength` is undefined.\n\n__Default value:__ `200`", "type": "number" }, "gradientHorizontalMinLength": { "description": "Min legend length for a horizontal gradient when `config.legend.gradientLength` is undefined.\n\n__Default value:__ `100`", "type": "number" }, "gradientLabelLimit": { "anyOf": [ { "description": "The maximum allowed length in pixels of color ramp gradient labels.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientLabelOffset": { "anyOf": [ { "description": "Vertical offset in pixels for color ramp gradient labels.\n\n__Default value:__ `2`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientLength": { "anyOf": [ { "description": "The length in pixels of the primary axis of a color gradient. 
This value corresponds to the height of a vertical gradient or the width of a horizontal gradient.\n\n__Default value:__ `200`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientOpacity": { "anyOf": [ { "description": "Opacity of the color gradient.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientStrokeColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the gradient stroke, can be in hex color code or regular color name.\n\n__Default value:__ `\"lightGray\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientStrokeWidth": { "anyOf": [ { "description": "The width of the gradient stroke, in pixels.\n\n__Default value:__ `0`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientThickness": { "anyOf": [ { "description": "The thickness in pixels of the color gradient. This value corresponds to the width of a vertical gradient or the height of a horizontal gradient.\n\n__Default value:__ `16`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "gradientVerticalMaxLength": { "description": "Max legend length for a vertical gradient when `config.legend.gradientLength` is undefined.\n\n__Default value:__ `200`", "type": "number" }, "gradientVerticalMinLength": { "description": "Min legend length for a vertical gradient when `config.legend.gradientLength` is undefined.\n\n__Default value:__ `100`", "type": "number" }, "gridAlign": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign", "description": "The alignment to apply to symbol legends rows and columns. The supported string values are `\"all\"`, `\"each\"` (the default), and `none`. For more information, see the [grid layout documentation](https://vega.github.io/vega/docs/layout).\n\n__Default value:__ `\"each\"`." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "labelAlign": { "anyOf": [ { "$ref": "#/definitions/Align", "description": "The alignment of the legend label, can be left, center, or right." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline", "description": "The position of the baseline of legend label, can be `\"top\"`, `\"middle\"`, `\"bottom\"`, or `\"alphabetic\"`.\n\n__Default value:__ `\"middle\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the legend label, can be in hex color code or regular color name." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFont": { "anyOf": [ { "description": "The font of the legend label.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFontSize": { "anyOf": [ { "description": "The font size of legend label.\n\n__Default value:__ `10`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style of legend label." }, { "$ref": "#/definitions/ExprRef" } ] }, "labelFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight of legend label." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "labelLimit": { "anyOf": [ { "description": "Maximum allowed pixel width of legend tick labels.\n\n__Default value:__ `160`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelOffset": { "anyOf": [ { "description": "The offset of the legend label.\n\n__Default value:__ `4`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelOpacity": { "anyOf": [ { "description": "Opacity of labels.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelOverlap": { "anyOf": [ { "$ref": "#/definitions/LabelOverlap" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The strategy to use for resolving overlap of labels in gradient legends. If `false`, no overlap reduction is attempted. If set to `true` or `\"parity\"`, a strategy of removing every other label is used. If set to `\"greedy\"`, a linear scan of the labels is performed, removing any label that overlaps with the last visible label (this often works better for log-scaled axes).\n\n__Default value:__ `\"greedy\"` for `log` scales; otherwise `true`." }, "labelPadding": { "anyOf": [ { "description": "Padding in pixels between the legend and legend labels.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "labelSeparation": { "anyOf": [ { "description": "The minimum separation that must be between label bounding boxes for them to be considered non-overlapping (default `0`). 
This property is ignored if *labelOverlap* resolution is not enabled.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "layout": { "$ref": "#/definitions/ExprRef" }, "legendX": { "anyOf": [ { "description": "Custom x-position for legend with orient \"none\".", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "legendY": { "anyOf": [ { "description": "Custom y-position for legend with orient \"none\".", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "offset": { "anyOf": [ { "description": "The offset in pixels by which to displace the legend from the data rectangle and axes.\n\n__Default value:__ `18`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "orient": { "$ref": "#/definitions/LegendOrient", "description": "The orientation of the legend, which determines how the legend is positioned within the scene. One of `\"left\"`, `\"right\"`, `\"top\"`, `\"bottom\"`, `\"top-left\"`, `\"top-right\"`, `\"bottom-left\"`, `\"bottom-right\"`, `\"none\"`.\n\n__Default value:__ `\"right\"`" }, "padding": { "anyOf": [ { "description": "The padding between the border and content of the legend group.\n\n__Default value:__ `0`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "rowPadding": { "anyOf": [ { "description": "The vertical padding in pixels between symbol legend entries.\n\n__Default value:__ `2`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Border stroke color for the full legend." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "Border stroke dash pattern for the full legend.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "Border stroke width for the full legend.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolBaseFillColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Default fill color for legend symbols. Only applied if there is no `\"fill\"` scale color encoding for the legend.\n\n__Default value:__ `\"transparent\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolBaseStrokeColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Default stroke color for legend symbols. Only applied if there is no `\"fill\"` scale color encoding for the legend.\n\n__Default value:__ `\"gray\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolDash": { "anyOf": [ { "description": "An array of alternating [stroke, space] lengths for dashed symbol strokes.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolDashOffset": { "anyOf": [ { "description": "The pixel offset at which to start drawing with the symbol stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolDirection": { "anyOf": [ { "$ref": "#/definitions/Orientation", "description": "The default direction (`\"horizontal\"` or `\"vertical\"`) for symbol legends.\n\n__Default value:__ `\"vertical\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolFillColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The fill color of the legend symbols." }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolLimit": { "anyOf": [ { "description": "The maximum number of allowed entries for a symbol legend. 
Additional entries will be dropped.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolOffset": { "anyOf": [ { "description": "Horizontal pixel offset for legend symbols.\n\n__Default value:__ `0`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolOpacity": { "anyOf": [ { "description": "Opacity of the legend symbols.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolSize": { "anyOf": [ { "description": "The size of the legend symbol, in pixels.\n\n__Default value:__ `100`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolStrokeColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Stroke color for legend symbols." }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolStrokeWidth": { "anyOf": [ { "description": "The width of the symbol's stroke.\n\n__Default value:__ `1.5`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "symbolType": { "anyOf": [ { "$ref": "#/definitions/SymbolShape", "description": "The symbol shape. One of the plotting shapes `circle` (default), `square`, `cross`, `diamond`, `triangle-up`, `triangle-down`, `triangle-right`, or `triangle-left`, the line symbol `stroke`, or one of the centered directional shapes `arrow`, `wedge`, or `triangle`. Alternatively, a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) can be provided. For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.\n\n__Default value:__ `\"circle\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "tickCount": { "anyOf": [ { "$ref": "#/definitions/TickCount", "description": "The desired number of tick values for quantitative legends." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "title": { "description": "Set to null to disable title for the axis, legend, or header.", "type": "null" }, "titleAlign": { "anyOf": [ { "$ref": "#/definitions/Align", "description": "Horizontal text alignment for legend titles.\n\n__Default value:__ `\"left\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleAnchor": { "anyOf": [ { "$ref": "#/definitions/TitleAnchor", "description": "Text anchor position for placing legend titles." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleBaseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline", "description": "Vertical text baseline for legend titles. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the *lineHeight* rather than *fontSize* alone.\n\n__Default value:__ `\"top\"`." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "The color of the legend title, can be in hex color code or regular color name." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFont": { "anyOf": [ { "description": "The font of the legend title.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontSize": { "anyOf": [ { "description": "The font size of the legend title.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style of the legend title." }, { "$ref": "#/definitions/ExprRef" } ] }, "titleFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight of the legend title. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "titleLimit": { "anyOf": [ { "description": "Maximum allowed pixel width of legend titles.\n\n__Default value:__ `180`.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleLineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line title text or title text with `\"line-top\"` or `\"line-bottom\"` baseline.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleOpacity": { "anyOf": [ { "description": "Opacity of the legend title.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "titleOrient": { "anyOf": [ { "$ref": "#/definitions/Orient", "description": "Orientation of the legend title." }, { "$ref": "#/definitions/ExprRef" } ] }, "titlePadding": { "anyOf": [ { "description": "The padding, in pixels, between title and legend.\n\n__Default value:__ `5`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "unselectedOpacity": { "description": "The opacity of unselected legend entries.\n\n__Default value:__ 0.35.", "type": "number" }, "zindex": { "anyOf": [ { "description": "The integer z-index indicating the layering of the legend group relative to other axis, mark, and legend groups.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] } }, "type": "object" }, "LegendOrient": { "enum": [ "none", "left", "right", "top", "bottom", "top-left", "top-right", "bottom-left", "bottom-right" ], "type": "string" }, "LegendResolveMap": { "additionalProperties": false, "properties": { "angle": { "$ref": "#/definitions/ResolveMode" }, "color": { "$ref": "#/definitions/ResolveMode" }, "fill": { "$ref": "#/definitions/ResolveMode" }, "fillOpacity": { "$ref": "#/definitions/ResolveMode" }, "opacity": { "$ref": "#/definitions/ResolveMode" }, "shape": { "$ref": "#/definitions/ResolveMode" }, "size": { "$ref": "#/definitions/ResolveMode" }, "stroke": { "$ref": "#/definitions/ResolveMode" }, "strokeDash": { "$ref": 
"#/definitions/ResolveMode" }, "strokeOpacity": { "$ref": "#/definitions/ResolveMode" }, "strokeWidth": { "$ref": "#/definitions/ResolveMode" } }, "type": "object" }, "LegendStreamBinding": { "additionalProperties": false, "properties": { "legend": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Stream" } ] } }, "required": [ "legend" ], "type": "object" }, "LineConfig": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The horizontal alignment of the text or ranged marks (area, bar, image, rect, rule). One of `\"left\"`, `\"right\"`, `\"center\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "angle": { "anyOf": [ { "description": "The rotation angle of the text, in degrees.", "maximum": 360, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG element, removing the mark item from the ARIA accessibility tree.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRole": { "anyOf": [ { "description": "Sets the type of user interface element of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"role\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRoleDescription": { "anyOf": [ { "description": "A human-readable, author-localized description for the role of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). 
If specified, this property determines the \"aria-roledescription\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "aspect": { "anyOf": [ { "description": "Whether to keep aspect ratio of image marks.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "baseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For text marks, the vertical text baseline. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, `\"line-bottom\"`, or an expression reference that provides one of the valid values. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `lineHeight` rather than `fontSize` alone.\n\nFor range marks, the vertical alignment of the marks. One of `\"top\"`, `\"middle\"`, `\"bottom\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "blend": { "anyOf": [ { "$ref": "#/definitions/Blend", "description": "The color blend mode for drawing an item on its current background. Any valid [CSS mix-blend-mode](https://developer.mozilla.org/en-US/docs/Web/CSS/mix-blend-mode) value can be used.\n\n__Default value: `\"source-over\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." 
}, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "anyOf": [ { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the mark. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "dir": { "anyOf": [ { "$ref": "#/definitions/TextDirection", "description": "The direction of the text. One of `\"ltr\"` (left-to-right) or `\"rtl\"` (right-to-left). 
This property determines on which side is truncated in response to the limit parameter.\n\n__Default value:__ `\"ltr\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "dx": { "anyOf": [ { "description": "The horizontal offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "The vertical offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ellipsis": { "anyOf": [ { "description": "The ellipsis string for text truncated in response to the limit parameter.\n\n__Default value:__ `\"…\"`", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "endAngle": { "anyOf": [ { "description": "The end angle in radians for arc marks. A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default fill color. This property has higher precedence than `config.color`. 
Set to `null` to remove fill.\n\n__Default value:__ (None)" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "filled": { "description": "Whether the mark's color should be used as fill color instead of stroke color.\n\n__Default value:__ `false` for all `point`, `line`, and `rule` marks as well as `geoshape` marks for [`graticule`](https://vega.github.io/vega-lite/docs/data.html#graticule) data sources; otherwise, `true`.\n\n__Note:__ This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).", "type": "boolean" }, "font": { "anyOf": [ { "description": "The typeface to set the text in (e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "The font size, in pixels.\n\n__Default value:__ `11`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style (e.g., `\"italic\"`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "height": { "anyOf": [ { "description": "Height of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "href": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "A URL to load upon mouse click. If defined, the mark acts as a hyperlink." }, { "$ref": "#/definitions/ExprRef" } ] }, "innerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner radius in pixels of arc marks. 
`innerRadius` is an alias for `radius2`.\n\n__Default value:__ `0`", "minimum": 0 }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method to use for line and area marks. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"step-before\"`: alternate between vertical and horizontal segments, as in a step function.\n- `\"step-after\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." }, { "$ref": "#/definitions/ExprRef" } ] }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "limit": { "anyOf": [ { "description": "The maximum length of the text mark in pixels. 
The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0` -- indicating no limit", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineBreak": { "anyOf": [ { "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. This property is ignored if the text is array-valued.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "The line height in pixels (the spacing between subsequent lines of text) for multi-line text marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "order": { "description": "For line and trail marks, this `order` property can be set to `null` or `false` to make the lines use the original order in the data sources.", "type": [ "null", "boolean" ] }, "orient": { "$ref": "#/definitions/Orientation", "description": "The orientation of a non-stacked bar, tick, area, and line charts. The value is either horizontal (default) or vertical.\n- For bar, rule and tick, this determines whether the size of the bar and tick should be applied to x or y dimension.\n- For area, this property determines the orient property of the Vega output.\n- For line and trail marks, this property determines the sort order of the points in the line if `config.sortLineBy` is not specified. For stacked charts, this is always determined by the orientation of the stack; therefore explicitly specified value will be ignored." }, "outerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer radius in pixels of arc marks. 
`outerRadius` is an alias for `radius`.\n\n__Default value:__ `0`", "minimum": 0 }, "padAngle": { "anyOf": [ { "description": "The angular padding applied to sides of the arc, in radians.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "point": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/OverlayMarkDef" }, { "const": "transparent", "type": "string" } ], "description": "A flag for overlaying points on top of line or area marks, or an object defining the properties of the overlayed points.\n\n- If this property is `\"transparent\"`, transparent points will be used (for enhancing tooltips and selections).\n\n- If this property is an empty object (`{}`) or `true`, filled points with default properties will be used.\n\n- If this property is `false`, no points would be automatically added to line or area marks.\n\n__Default value:__ `false`." }, "radius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For arc mark, the primary (outer) radius in pixels.\n\nFor text marks, polar coordinate radial offset, in pixels, of the text from the origin determined by the `x` and `y` properties.\n\n__Default value:__ `min(plot_width, plot_height)/2`", "minimum": 0 }, "radius2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The secondary (inner) radius in pixels of arc marks.\n\n__Default value:__ `0`", "minimum": 0 }, "shape": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/SymbolShape" }, { "type": "string" } ], "description": "Shape of the point marks. 
Supported values include:\n- plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`.\n- the line symbol `\"stroke\"`\n- centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"`\n- a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n__Default value:__ `\"circle\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "size": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default size for marks.\n- For `point`/`circle`/`square`, this represents the pixel area of the marks. Note that this value sets the area of the symbol; the side lengths will increase with the square root of this value.\n- For `bar`, this represents the band size of the bar, in pixels.\n- For `text`, this represents the font size, in pixels.\n\n__Default value:__\n- `30` for point, circle, square marks; width/height's `step`\n- `2` for bar marks with discrete dimensions;\n- `5` for bar marks with continuous dimensions;\n- `11` for text marks.", "minimum": 0 }, "smooth": { "anyOf": [ { "description": "A boolean flag (default true) indicating if the image should be smoothed when resized. If false, individual pixels should be scaled directly rather than interpolated with smoothing. For SVG rendering, this option may not work in some browsers due to lack of standardization.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "startAngle": { "anyOf": [ { "description": "The start angle in radians for arc marks. 
A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default stroke color. This property has higher precedence than `config.color`. Set to `null` to remove stroke.\n\n__Default value:__ (None)" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOffset": { "anyOf": [ { "description": "The offset in pixels at which to draw the group stroke and fill. 
If unspecified, the default behavior is to dynamically offset stroked groups such that 1 pixel stroke widths align with the pixel grid.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tension": { "anyOf": [ { "description": "Depending on the interpolation type, sets the tension parameter (for line and area marks).", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text", "description": "Placeholder text if the `text` channel is not specified" }, { "$ref": "#/definitions/ExprRef" } ] }, "theta": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians.", "maximum": 360, "minimum": 0 }, "theta2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "timeUnitBandPosition": { "description": "Default relative band position for a time unit. If set to `0`, the marks will be positioned at the beginning of the time unit band step. If set to `0.5`, the marks will be positioned in the middle of the time unit band step.", "type": "number" }, "timeUnitBandSize": { "description": "Default relative band size for a time unit. If set to `1`, the bandwidth of the marks will be equal to the time unit band step. 
If set to `0.5`, bandwidth of the marks will be half of the time unit band step.", "type": "number" }, "tooltip": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/TooltipContent" }, { "$ref": "#/definitions/ExprRef" }, { "type": "null" } ], "description": "The tooltip text string to show upon mouse hover or an object defining which fields should the tooltip be derived from.\n\n- If `tooltip` is `true` or `{\"content\": \"encoding\"}`, then all fields from `encoding` will be used.\n- If `tooltip` is `{\"content\": \"data\"}`, then all fields that appear in the highlighted data point will be used.\n- If set to `null` or `false`, then no tooltip will be used.\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite.\n\n__Default value:__ `null`" }, "url": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "The URL of the image file for image marks." }, { "$ref": "#/definitions/ExprRef" } ] }, "width": { "anyOf": [ { "description": "Width of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "x": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." 
}, "y": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." } }, "type": "object" }, "LineString": { "additionalProperties": false, "description": "LineString geometry object. https://tools.ietf.org/html/rfc7946#section-3.1.4", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. https://tools.ietf.org/html/rfc7946#section-5" }, "coordinates": { "items": { "$ref": "#/definitions/Position" }, "type": "array" }, "type": { "const": "LineString", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "coordinates", "type" ], "type": "object" }, "LinearGradient": { "additionalProperties": false, "properties": { "gradient": { "const": "linear", "description": "The type of gradient. 
Use `\"linear\"` for a linear gradient.", "type": "string" }, "id": { "type": "string" }, "stops": { "description": "An array of gradient stops defining the gradient color sequence.", "items": { "$ref": "#/definitions/GradientStop" }, "type": "array" }, "x1": { "description": "The starting x-coordinate, in normalized [0, 1] coordinates, of the linear gradient.\n\n__Default value:__ `0`", "type": "number" }, "x2": { "description": "The ending x-coordinate, in normalized [0, 1] coordinates, of the linear gradient.\n\n__Default value:__ `1`", "type": "number" }, "y1": { "description": "The starting y-coordinate, in normalized [0, 1] coordinates, of the linear gradient.\n\n__Default value:__ `0`", "type": "number" }, "y2": { "description": "The ending y-coordinate, in normalized [0, 1] coordinates, of the linear gradient.\n\n__Default value:__ `0`", "type": "number" } }, "required": [ "gradient", "stops" ], "type": "object" }, "LocalMultiTimeUnit": { "enum": [ "yearquarter", "yearquartermonth", "yearmonth", "yearmonthdate", "yearmonthdatehours", "yearmonthdatehoursminutes", "yearmonthdatehoursminutesseconds", "yearweek", "yearweekday", "yearweekdayhours", "yearweekdayhoursminutes", "yearweekdayhoursminutesseconds", "yeardayofyear", "quartermonth", "monthdate", "monthdatehours", "monthdatehoursminutes", "monthdatehoursminutesseconds", "weekday", "weeksdayhours", "weekdayhoursminutes", "weekdayhoursminutesseconds", "dayhours", "dayhoursminutes", "dayhoursminutesseconds", "hoursminutes", "hoursminutesseconds", "minutesseconds", "secondsmilliseconds" ], "type": "string" }, "LocalSingleTimeUnit": { "enum": [ "year", "quarter", "month", "week", "day", "dayofyear", "date", "hours", "minutes", "seconds", "milliseconds" ], "type": "string" }, "Locale": { "additionalProperties": false, "properties": { "number": { "$ref": "#/definitions/NumberLocale" }, "time": { "$ref": "#/definitions/TimeLocale" } }, "type": "object" }, "LoessTransform": { "additionalProperties": false, 
"properties": { "as": { "description": "The output field names for the smoothed points generated by the loess transform.\n\n__Default value:__ The field names of the input x and y values.", "items": { "$ref": "#/definitions/FieldName" }, "maxItems": 2, "minItems": 2, "type": "array" }, "bandwidth": { "description": "A bandwidth parameter in the range `[0, 1]` that determines the amount of smoothing.\n\n__Default value:__ `0.3`", "type": "number" }, "groupby": { "description": "The data fields to group by. If not specified, a single group containing all data objects will be used.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "loess": { "$ref": "#/definitions/FieldName", "description": "The data field of the dependent variable to smooth." }, "on": { "$ref": "#/definitions/FieldName", "description": "The data field of the independent variable to use a predictor." } }, "required": [ "loess", "on" ], "type": "object" }, "LogicalAnd": { "additionalProperties": false, "properties": { "and": { "items": { "$ref": "#/definitions/PredicateComposition" }, "type": "array" } }, "required": [ "and" ], "type": "object" }, "PredicateComposition": { "anyOf": [ { "$ref": "#/definitions/LogicalNot" }, { "$ref": "#/definitions/LogicalAnd" }, { "$ref": "#/definitions/LogicalOr" }, { "$ref": "#/definitions/Predicate" } ] }, "LogicalNot": { "additionalProperties": false, "properties": { "not": { "$ref": "#/definitions/PredicateComposition" } }, "required": [ "not" ], "type": "object" }, "LogicalOr": { "additionalProperties": false, "properties": { "or": { "items": { "$ref": "#/definitions/PredicateComposition" }, "type": "array" } }, "required": [ "or" ], "type": "object" }, "LookupData": { "additionalProperties": false, "properties": { "data": { "$ref": "#/definitions/Data", "description": "Secondary data source to lookup in." }, "fields": { "description": "Fields in foreign data or selection to lookup. 
If not specified, the entire object is queried.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "key": { "$ref": "#/definitions/FieldName", "description": "Key in data to lookup." } }, "required": [ "data", "key" ], "type": "object" }, "LookupSelection": { "additionalProperties": false, "properties": { "fields": { "description": "Fields in foreign data or selection to lookup. If not specified, the entire object is queried.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "key": { "$ref": "#/definitions/FieldName", "description": "Key in data to lookup." }, "param": { "$ref": "#/definitions/ParameterName", "description": "Selection parameter name to look up." } }, "required": [ "key", "param" ], "type": "object" }, "LookupTransform": { "additionalProperties": false, "properties": { "as": { "anyOf": [ { "$ref": "#/definitions/FieldName" }, { "items": { "$ref": "#/definitions/FieldName" }, "type": "array" } ], "description": "The output fields on which to store the looked up data values.\n\nFor data lookups, this property may be left blank if `from.fields` has been specified (those field names will be used); if `from.fields` has not been specified, `as` must be a string.\n\nFor selection lookups, this property is optional: if unspecified, looked up values will be stored under a property named for the selection; and if specified, it must correspond to `from.fields`." }, "default": { "description": "The default value to use if lookup fails.\n\n__Default value:__ `null`" }, "from": { "anyOf": [ { "$ref": "#/definitions/LookupData" }, { "$ref": "#/definitions/LookupSelection" } ], "description": "Data source or selection for secondary data reference." 
}, "lookup": { "description": "Key in primary data source.", "type": "string" } }, "required": [ "lookup", "from" ], "type": "object" }, "Mark": { "description": "All types of primitive marks.", "enum": [ "arc", "area", "bar", "image", "line", "point", "rect", "rule", "text", "tick", "trail", "circle", "square", "geoshape" ], "type": "string" }, "MarkConfig": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The horizontal alignment of the text or ranged marks (area, bar, image, rect, rule). One of `\"left\"`, `\"right\"`, `\"center\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "angle": { "anyOf": [ { "description": "The rotation angle of the text, in degrees.", "maximum": 360, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG element, removing the mark item from the ARIA accessibility tree.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRole": { "anyOf": [ { "description": "Sets the type of user interface element of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"role\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRoleDescription": { "anyOf": [ { "description": "A human-readable, author-localized description for the role of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). 
If specified, this property determines the \"aria-roledescription\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "aspect": { "anyOf": [ { "description": "Whether to keep aspect ratio of image marks.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "baseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For text marks, the vertical text baseline. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, `\"line-bottom\"`, or an expression reference that provides one of the valid values. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `lineHeight` rather than `fontSize` alone.\n\nFor range marks, the vertical alignment of the marks. One of `\"top\"`, `\"middle\"`, `\"bottom\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "blend": { "anyOf": [ { "$ref": "#/definitions/Blend", "description": "The color blend mode for drawing an item on its current background. Any valid [CSS mix-blend-mode](https://developer.mozilla.org/en-US/docs/Web/CSS/mix-blend-mode) value can be used.\n\n__Default value:__ `\"source-over\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`."
}, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "anyOf": [ { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the mark. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "dir": { "anyOf": [ { "$ref": "#/definitions/TextDirection", "description": "The direction of the text. One of `\"ltr\"` (left-to-right) or `\"rtl\"` (right-to-left). 
This property determines on which side is truncated in response to the limit parameter.\n\n__Default value:__ `\"ltr\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "dx": { "anyOf": [ { "description": "The horizontal offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "The vertical offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ellipsis": { "anyOf": [ { "description": "The ellipsis string for text truncated in response to the limit parameter.\n\n__Default value:__ `\"…\"`", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "endAngle": { "anyOf": [ { "description": "The end angle in radians for arc marks. A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default fill color. This property has higher precedence than `config.color`. 
Set to `null` to remove fill.\n\n__Default value:__ (None)" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "filled": { "description": "Whether the mark's color should be used as fill color instead of stroke color.\n\n__Default value:__ `false` for all `point`, `line`, and `rule` marks as well as `geoshape` marks for [`graticule`](https://vega.github.io/vega-lite/docs/data.html#graticule) data sources; otherwise, `true`.\n\n__Note:__ This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).", "type": "boolean" }, "font": { "anyOf": [ { "description": "The typeface to set the text in (e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "The font size, in pixels.\n\n__Default value:__ `11`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style (e.g., `\"italic\"`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "height": { "anyOf": [ { "description": "Height of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "href": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "A URL to load upon mouse click. If defined, the mark acts as a hyperlink." }, { "$ref": "#/definitions/ExprRef" } ] }, "innerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner radius in pixels of arc marks. 
`innerRadius` is an alias for `radius2`.\n\n__Default value:__ `0`", "minimum": 0 }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method to use for line and area marks. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"step-before\"`: alternate between vertical and horizontal segments, as in a step function.\n- `\"step-after\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." }, { "$ref": "#/definitions/ExprRef" } ] }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "limit": { "anyOf": [ { "description": "The maximum length of the text mark in pixels. 
The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0` -- indicating no limit", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineBreak": { "anyOf": [ { "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. This property is ignored if the text is array-valued.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "The line height in pixels (the spacing between subsequent lines of text) for multi-line text marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "order": { "description": "For line and trail marks, this `order` property can be set to `null` or `false` to make the lines use the original order in the data sources.", "type": [ "null", "boolean" ] }, "orient": { "$ref": "#/definitions/Orientation", "description": "The orientation of a non-stacked bar, tick, area, and line charts. The value is either horizontal (default) or vertical.\n- For bar, rule and tick, this determines whether the size of the bar and tick should be applied to x or y dimension.\n- For area, this property determines the orient property of the Vega output.\n- For line and trail marks, this property determines the sort order of the points in the line if `config.sortLineBy` is not specified. For stacked charts, this is always determined by the orientation of the stack; therefore explicitly specified value will be ignored." }, "outerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer radius in pixels of arc marks. 
`outerRadius` is an alias for `radius`.\n\n__Default value:__ `0`", "minimum": 0 }, "padAngle": { "anyOf": [ { "description": "The angular padding applied to sides of the arc, in radians.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "radius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For arc mark, the primary (outer) radius in pixels.\n\nFor text marks, polar coordinate radial offset, in pixels, of the text from the origin determined by the `x` and `y` properties.\n\n__Default value:__ `min(plot_width, plot_height)/2`", "minimum": 0 }, "radius2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The secondary (inner) radius in pixels of arc marks.\n\n__Default value:__ `0`", "minimum": 0 }, "shape": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/SymbolShape" }, { "type": "string" } ], "description": "Shape of the point marks. Supported values include:\n- plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`.\n- the line symbol `\"stroke\"`\n- centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"`\n- a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n__Default value:__ `\"circle\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "size": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default size for marks.\n- For `point`/`circle`/`square`, this represents the pixel area of the marks. 
Note that this value sets the area of the symbol; the side lengths will increase with the square root of this value.\n- For `bar`, this represents the band size of the bar, in pixels.\n- For `text`, this represents the font size, in pixels.\n\n__Default value:__\n- `30` for point, circle, square marks; width/height's `step`\n- `2` for bar marks with discrete dimensions;\n- `5` for bar marks with continuous dimensions;\n- `11` for text marks.", "minimum": 0 }, "smooth": { "anyOf": [ { "description": "A boolean flag (default true) indicating if the image should be smoothed when resized. If false, individual pixels should be scaled directly rather than interpolated with smoothing. For SVG rendering, this option may not work in some browsers due to lack of standardization.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "startAngle": { "anyOf": [ { "description": "The start angle in radians for arc marks. A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default stroke color. This property has higher precedence than `config.color`. Set to `null` to remove stroke.\n\n__Default value:__ (None)" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. 
One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOffset": { "anyOf": [ { "description": "The offset in pixels at which to draw the group stroke and fill. 
If unspecified, the default behavior is to dynamically offset stroked groups such that 1 pixel stroke widths align with the pixel grid.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tension": { "anyOf": [ { "description": "Depending on the interpolation type, sets the tension parameter (for line and area marks).", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text", "description": "Placeholder text if the `text` channel is not specified" }, { "$ref": "#/definitions/ExprRef" } ] }, "theta": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians.", "maximum": 360, "minimum": 0 }, "theta2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "timeUnitBandPosition": { "description": "Default relative band position for a time unit. If set to `0`, the marks will be positioned at the beginning of the time unit band step. If set to `0.5`, the marks will be positioned in the middle of the time unit band step.", "type": "number" }, "timeUnitBandSize": { "description": "Default relative band size for a time unit. If set to `1`, the bandwidth of the marks will be equal to the time unit band step. 
If set to `0.5`, bandwidth of the marks will be half of the time unit band step.", "type": "number" }, "tooltip": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/TooltipContent" }, { "$ref": "#/definitions/ExprRef" }, { "type": "null" } ], "description": "The tooltip text string to show upon mouse hover or an object defining which fields should the tooltip be derived from.\n\n- If `tooltip` is `true` or `{\"content\": \"encoding\"}`, then all fields from `encoding` will be used.\n- If `tooltip` is `{\"content\": \"data\"}`, then all fields that appear in the highlighted data point will be used.\n- If set to `null` or `false`, then no tooltip will be used.\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite.\n\n__Default value:__ `null`" }, "url": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "The URL of the image file for image marks." }, { "$ref": "#/definitions/ExprRef" } ] }, "width": { "anyOf": [ { "description": "Width of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "x": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." 
}, "y": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." } }, "type": "object" }, "MarkDef": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The horizontal alignment of the text or ranged marks (area, bar, image, rect, rule). One of `\"left\"`, `\"right\"`, `\"center\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "angle": { "anyOf": [ { "description": "The rotation angle of the text, in degrees.", "maximum": 360, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG element, removing the mark item from the ARIA accessibility tree.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRole": { "anyOf": [ { "description": "Sets the type of user interface element of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"role\" attribute. 
Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRoleDescription": { "anyOf": [ { "description": "A human-readable, author-localized description for the role of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"aria-roledescription\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "aspect": { "anyOf": [ { "description": "Whether to keep aspect ratio of image marks.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "bandSize": { "description": "The width of the ticks.\n\n__Default value:__ 3/4 of step (width step for horizontal ticks and height step for vertical ticks).", "minimum": 0, "type": "number" }, "baseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For text marks, the vertical text baseline. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, `\"line-bottom\"`, or an expression reference that provides one of the valid values. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `lineHeight` rather than `fontSize` alone.\n\nFor range marks, the vertical alignment of the marks. One of `\"top\"`, `\"middle\"`, `\"bottom\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "binSpacing": { "description": "Offset between bars for binned field. 
The ideal value for this is either 0 (preferred by statisticians) or 1 (Vega-Lite default, D3 example style).\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, "blend": { "anyOf": [ { "$ref": "#/definitions/Blend", "description": "The color blend mode for drawing an item on its current background. Any valid [CSS mix-blend-mode](https://developer.mozilla.org/en-US/docs/Web/CSS/mix-blend-mode) value can be used.\n\n__Default value: `\"source-over\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "clip": { "description": "Whether a mark be clipped to the enclosing group’s width and height.", "type": "boolean" }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." 
}, "continuousBandSize": { "description": "The default size of the bars on continuous scales.\n\n__Default value:__ `5`", "minimum": 0, "type": "number" }, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusEnd": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For vertical bars, top-left and top-right corner radius.\n\n- For horizontal bars, top-right and bottom-right corner radius." }, "cornerRadiusTopLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "anyOf": [ { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the mark. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). 
If specified, this property determines the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "dir": { "anyOf": [ { "$ref": "#/definitions/TextDirection", "description": "The direction of the text. One of `\"ltr\"` (left-to-right) or `\"rtl\"` (right-to-left). This property determines on which side is truncated in response to the limit parameter.\n\n__Default value:__ `\"ltr\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "discreteBandSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RelativeBandSize" } ], "description": "The default size of the bars with discrete dimensions. If unspecified, the default size is `step-2`, which provides 2 pixel offset between bars.", "minimum": 0 }, "dx": { "anyOf": [ { "description": "The horizontal offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "The vertical offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ellipsis": { "anyOf": [ { "description": "The ellipsis string for text truncated in response to the limit parameter.\n\n__Default value:__ `\"…\"`", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default fill color. This property has higher precedence than `config.color`. 
Set to `null` to remove fill.\n\n__Default value:__ (None)" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "filled": { "description": "Whether the mark's color should be used as fill color instead of stroke color.\n\n__Default value:__ `false` for all `point`, `line`, and `rule` marks as well as `geoshape` marks for [`graticule`](https://vega.github.io/vega-lite/docs/data.html#graticule) data sources; otherwise, `true`.\n\n__Note:__ This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).", "type": "boolean" }, "font": { "anyOf": [ { "description": "The typeface to set the text in (e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "The font size, in pixels.\n\n__Default value:__ `11`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style (e.g., `\"italic\"`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "height": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RelativeBandSize" } ], "description": "Height of the marks. One of:\n\n- A number representing a fixed pixel height.\n\n- A relative band size definition. For example, `{band: 0.5}` represents half of the band" }, "href": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "A URL to load upon mouse click. If defined, the mark acts as a hyperlink." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "innerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner radius in pixels of arc marks. `innerRadius` is an alias for `radius2`.\n\n__Default value:__ `0`", "minimum": 0 }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method to use for line and area marks. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"step-before\"`: alternate between vertical and horizontal segments, as in a step function.\n- `\"step-after\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." }, { "$ref": "#/definitions/ExprRef" } ] }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. 
In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "limit": { "anyOf": [ { "description": "The maximum length of the text mark in pixels. The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0` -- indicating no limit", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "line": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/OverlayMarkDef" } ], "description": "A flag for overlaying line on top of area marks, or an object defining the properties of the overlayed lines.\n\n- If this value is an empty object (`{}`) or `true`, lines with default properties will be used.\n\n- If this value is `false`, no lines would be automatically added to area marks.\n\n__Default value:__ `false`." }, "lineBreak": { "anyOf": [ { "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. This property is ignored if the text is array-valued.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "The line height in pixels (the spacing between subsequent lines of text) for multi-line text marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "minBandSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The minimum band size for bar and rectangle marks. 
__Default value:__ `0.25`" }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "order": { "description": "For line and trail marks, this `order` property can be set to `null` or `false` to make the lines use the original order in the data sources.", "type": [ "null", "boolean" ] }, "orient": { "$ref": "#/definitions/Orientation", "description": "The orientation of a non-stacked bar, tick, area, and line charts. The value is either horizontal (default) or vertical.\n- For bar, rule and tick, this determines whether the size of the bar and tick should be applied to x or y dimension.\n- For area, this property determines the orient property of the Vega output.\n- For line and trail marks, this property determines the sort order of the points in the line if `config.sortLineBy` is not specified. For stacked charts, this is always determined by the orientation of the stack; therefore explicitly specified value will be ignored." }, "outerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer radius in pixels of arc marks. 
`outerRadius` is an alias for `radius`.\n\n__Default value:__ `0`", "minimum": 0 }, "padAngle": { "anyOf": [ { "description": "The angular padding applied to sides of the arc, in radians.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "point": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/OverlayMarkDef" }, { "const": "transparent", "type": "string" } ], "description": "A flag for overlaying points on top of line or area marks, or an object defining the properties of the overlayed points.\n\n- If this property is `\"transparent\"`, transparent points will be used (for enhancing tooltips and selections).\n\n- If this property is an empty object (`{}`) or `true`, filled points with default properties will be used.\n\n- If this property is `false`, no points would be automatically added to line or area marks.\n\n__Default value:__ `false`." }, "radius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For arc mark, the primary (outer) radius in pixels.\n\nFor text marks, polar coordinate radial offset, in pixels, of the text from the origin determined by the `x` and `y` properties.\n\n__Default value:__ `min(plot_width, plot_height)/2`", "minimum": 0 }, "radius2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The secondary (inner) radius in pixels of arc marks.\n\n__Default value:__ `0`", "minimum": 0 }, "radius2Offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for radius2." }, "radiusOffset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for radius." }, "shape": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/SymbolShape" }, { "type": "string" } ], "description": "Shape of the point marks. 
Supported values include:\n- plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`.\n- the line symbol `\"stroke\"`\n- centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"`\n- a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n__Default value:__ `\"circle\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "size": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default size for marks.\n- For `point`/`circle`/`square`, this represents the pixel area of the marks. Note that this value sets the area of the symbol; the side lengths will increase with the square root of this value.\n- For `bar`, this represents the band size of the bar, in pixels.\n- For `text`, this represents the font size, in pixels.\n\n__Default value:__\n- `30` for point, circle, square marks; width/height's `step`\n- `2` for bar marks with discrete dimensions;\n- `5` for bar marks with continuous dimensions;\n- `11` for text marks.", "minimum": 0 }, "smooth": { "anyOf": [ { "description": "A boolean flag (default true) indicating if the image should be smoothed when resized. If false, individual pixels should be scaled directly rather than interpolated with smoothing. For SVG rendering, this option may not work in some browsers due to lack of standardization.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default stroke color. This property has higher precedence than `config.color`. 
Set to `null` to remove stroke.\n\n__Default value:__ (None)" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOffset": { "anyOf": [ { "description": "The offset in pixels at which to draw the group stroke and fill. If unspecified, the default behavior is to dynamically offset stroked groups such that 1 pixel stroke widths align with the pixel grid.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "style": { "anyOf": [ { "type": "string" }, { "items": { "type": "string" }, "type": "array" } ], "description": "A string or array of strings indicating the name of custom styles to apply to the mark. 
A style is a named collection of mark property defaults defined within the [style configuration](https://vega.github.io/vega-lite/docs/mark.html#style-config). If style is an array, later styles will override earlier styles. Any [mark properties](https://vega.github.io/vega-lite/docs/encoding.html#mark-prop) explicitly defined within the `encoding` will override a style default.\n\n__Default value:__ The mark's name. For example, a bar mark will have style `\"bar\"` by default. __Note:__ Any specified style will augment the default style. For example, a bar mark with `\"style\": \"foo\"` will receive from `config.style.bar` and `config.style.foo` (the specified style `\"foo\"` has higher precedence)." }, "tension": { "anyOf": [ { "description": "Depending on the interpolation type, sets the tension parameter (for line and area marks).", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text", "description": "Placeholder text if the `text` channel is not specified" }, { "$ref": "#/definitions/ExprRef" } ] }, "theta": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians.", "maximum": 360, "minimum": 0 }, "theta2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "theta2Offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for theta2." }, "thetaOffset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for theta." 
}, "thickness": { "description": "Thickness of the tick mark.\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, "timeUnitBandPosition": { "description": "Default relative band position for a time unit. If set to `0`, the marks will be positioned at the beginning of the time unit band step. If set to `0.5`, the marks will be positioned in the middle of the time unit band step.", "type": "number" }, "timeUnitBandSize": { "description": "Default relative band size for a time unit. If set to `1`, the bandwidth of the marks will be equal to the time unit band step. If set to `0.5`, bandwidth of the marks will be half of the time unit band step.", "type": "number" }, "tooltip": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/TooltipContent" }, { "$ref": "#/definitions/ExprRef" }, { "type": "null" } ], "description": "The tooltip text string to show upon mouse hover or an object defining which fields should the tooltip be derived from.\n\n- If `tooltip` is `true` or `{\"content\": \"encoding\"}`, then all fields from `encoding` will be used.\n- If `tooltip` is `{\"content\": \"data\"}`, then all fields that appear in the highlighted data point will be used.\n- If set to `null` or `false`, then no tooltip will be used.\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite.\n\n__Default value:__ `null`" }, "type": { "$ref": "#/definitions/Mark", "description": "The mark type. This could a primitive mark type (one of `\"bar\"`, `\"circle\"`, `\"square\"`, `\"tick\"`, `\"line\"`, `\"area\"`, `\"point\"`, `\"geoshape\"`, `\"rule\"`, and `\"text\"`) or a composite mark type (`\"boxplot\"`, `\"errorband\"`, `\"errorbar\"`)." }, "url": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "The URL of the image file for image marks." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "width": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RelativeBandSize" } ], "description": "Width of the marks. One of:\n\n- A number representing a fixed pixel width.\n\n- A relative band size definition. For example, `{band: 0.5}` represents half of the band." }, "x": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2Offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for x2-position." }, "xOffset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for x-position." }, "y": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." 
}, "y2Offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for y2-position." }, "yOffset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for y-position." } }, "required": [ "type" ], "type": "object" }, "MarkPropDef<(Gradient|string|null)>": { "anyOf": [ { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/ValueDefWithCondition" } ] }, "MarkPropDef<(string|null),TypeForShape>": { "anyOf": [ { "$ref": "#/definitions/FieldOrDatumDefWithCondition,(string|null)>" }, { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/ValueDefWithCondition,(string|null)>" } ] }, "MarkPropDef": { "anyOf": [ { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/ValueDefWithCondition" } ] }, "MarkPropDef": { "anyOf": [ { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/ValueDefWithCondition" } ] }, "MarkType": { "enum": [ "arc", "area", "image", "group", "line", "path", "rect", "rule", "shape", "symbol", "text", "trail" ], "type": "string" }, "MergedStream": { "additionalProperties": false, "properties": { "between": { "items": { "$ref": "#/definitions/Stream" }, "type": "array" }, "consume": { "type": "boolean" }, "debounce": { "type": "number" }, "filter": { "anyOf": [ { "$ref": "#/definitions/Expr" }, { "items": { "$ref": "#/definitions/Expr" }, "type": "array" } ] }, "markname": { "type": "string" }, "marktype": { "$ref": "#/definitions/MarkType" }, "merge": { "items": { "$ref": "#/definitions/Stream" }, "type": "array" }, "throttle": { "type": "number" } }, "required": [ "merge" ], "type": "object" }, "Month": { "maximum": 12, "minimum": 1, "type": "number" }, "MultiLineString": { 
"additionalProperties": false, "description": "MultiLineString geometry object. https://tools.ietf.org/html/rfc7946#section-3.1.5", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. https://tools.ietf.org/html/rfc7946#section-5" }, "coordinates": { "items": { "items": { "$ref": "#/definitions/Position" }, "type": "array" }, "type": "array" }, "type": { "const": "MultiLineString", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "coordinates", "type" ], "type": "object" }, "MultiPoint": { "additionalProperties": false, "description": "MultiPoint geometry object. https://tools.ietf.org/html/rfc7946#section-3.1.3", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. https://tools.ietf.org/html/rfc7946#section-5" }, "coordinates": { "items": { "$ref": "#/definitions/Position" }, "type": "array" }, "type": { "const": "MultiPoint", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "coordinates", "type" ], "type": "object" }, "MultiPolygon": { "additionalProperties": false, "description": "MultiPolygon geometry object. https://tools.ietf.org/html/rfc7946#section-3.1.7", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. 
https://tools.ietf.org/html/rfc7946#section-5" }, "coordinates": { "items": { "items": { "items": { "$ref": "#/definitions/Position" }, "type": "array" }, "type": "array" }, "type": "array" }, "type": { "const": "MultiPolygon", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "coordinates", "type" ], "type": "object" }, "MultiTimeUnit": { "anyOf": [ { "$ref": "#/definitions/LocalMultiTimeUnit" }, { "$ref": "#/definitions/UtcMultiTimeUnit" } ] }, "NamedData": { "additionalProperties": false, "properties": { "format": { "$ref": "#/definitions/DataFormat", "description": "An object that specifies the format for parsing the data." }, "name": { "description": "Provide a placeholder name and bind data at runtime.\n\nNew data may change the layout but Vega does not always resize the chart. To update the layout when the data updates, set [autosize](https://vega.github.io/vega-lite/docs/size.html#autosize) or explicitly use [view.resize](https://vega.github.io/vega/docs/api/view/#view_resize).", "type": "string" } }, "required": [ "name" ], "type": "object" }, "NonArgAggregateOp": { "enum": [ "average", "count", "distinct", "max", "mean", "median", "min", "missing", "product", "q1", "q3", "ci0", "ci1", "stderr", "stdev", "stdevp", "sum", "valid", "values", "variance", "variancep" ], "type": "string" }, "NonLayerRepeatSpec": { "additionalProperties": false, "description": "Base interface for a repeat specification.", "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. 
The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. 
This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "repeat": { "anyOf": [ { "items": { "type": "string" }, "type": "array" }, { "$ref": "#/definitions/RepeatMapping" } ], "description": "Definition for fields to be repeated. One of: 1) An array of fields to be repeated. If `\"repeat\"` is an array, the field can be referred to as `{\"repeat\": \"repeat\"}`. The repeated views are laid out in a wrapped row. You can set the number of columns to control the wrapping. 2) An object that maps `\"row\"` and/or `\"column\"` to the listed fields to be repeated along the particular orientations. The objects `{\"repeat\": \"row\"}` and `{\"repeat\": \"column\"}` can be used to refer to the repeated field respectively." }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. 
An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "spec": { "$ref": "#/definitions/NonNormalizedSpec", "description": "A specification of the view that gets repeated." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" } }, "required": [ "repeat", "spec" ], "type": "object" }, "NonNormalizedSpec": { "$ref": "#/definitions/Spec" }, "NumberLocale": { "additionalProperties": false, "description": "Locale definition for formatting numbers.", "properties": { "currency": { "$ref": "#/definitions/Vector2", "description": "The currency prefix and suffix (e.g., [\"$\", \"\"])." }, "decimal": { "description": "The decimal point (e.g., \".\").", "type": "string" }, "grouping": { "description": "The array of group sizes (e.g., [3]), cycled as needed.", "items": { "type": "number" }, "type": "array" }, "minus": { "description": "The minus sign (defaults to hyphen-minus, \"-\").", "type": "string" }, "nan": { "description": "The not-a-number value (defaults to \"NaN\").", "type": "string" }, "numerals": { "$ref": "#/definitions/Vector10", "description": "An array of ten strings to replace the numerals 0-9." 
}, "percent": { "description": "The percent sign (defaults to \"%\").", "type": "string" }, "thousands": { "description": "The group separator (e.g., \",\").", "type": "string" } }, "required": [ "decimal", "thousands", "grouping", "currency" ], "type": "object" }, "NumericArrayMarkPropDef": { "$ref": "#/definitions/MarkPropDef" }, "NumericMarkPropDef": { "$ref": "#/definitions/MarkPropDef" }, "OffsetDef": { "anyOf": [ { "$ref": "#/definitions/ScaleFieldDef" }, { "$ref": "#/definitions/ScaleDatumDef" }, { "$ref": "#/definitions/ValueDef" } ] }, "OrderFieldDef": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). 
The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "sort": { "$ref": "#/definitions/SortOrder", "description": "The sort order. One of `\"ascending\"` (default) or `\"descending\"`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. 
If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "OrderOnlyDef": { "additionalProperties": false, "properties": { "sort": { "$ref": "#/definitions/SortOrder", "description": "The sort order. One of `\"ascending\"` (default) or `\"descending\"`." 
} }, "type": "object" }, "OrderValueDef": { "additionalProperties": false, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef" }, { "items": { "$ref": "#/definitions/ConditionalValueDef" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "value" ], "type": "object" }, "Orient": { "enum": [ "left", "right", "top", "bottom" ], "type": "string" }, "Orientation": { "enum": [ "horizontal", "vertical" ], "type": "string" }, "OverlayMarkDef": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The horizontal alignment of the text or ranged marks (area, bar, image, rect, rule). One of `\"left\"`, `\"right\"`, `\"center\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "angle": { "anyOf": [ { "description": "The rotation angle of the text, in degrees.", "maximum": 360, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). 
If `false`, the \"aria-hidden\" attribute will be set on the output SVG element, removing the mark item from the ARIA accessibility tree.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRole": { "anyOf": [ { "description": "Sets the type of user interface element of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"role\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRoleDescription": { "anyOf": [ { "description": "A human-readable, author-localized description for the role of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"aria-roledescription\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "aspect": { "anyOf": [ { "description": "Whether to keep aspect ratio of image marks.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "baseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For text marks, the vertical text baseline. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, `\"line-bottom\"`, or an expression reference that provides one of the valid values. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `lineHeight` rather than `fontSize` alone.\n\nFor range marks, the vertical alignment of the marks. One of `\"top\"`, `\"middle\"`, `\"bottom\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." 
}, "blend": { "anyOf": [ { "$ref": "#/definitions/Blend", "description": "The color blend mode for drawing an item on its current background. Any valid [CSS mix-blend-mode](https://developer.mozilla.org/en-US/docs/Web/CSS/mix-blend-mode) value can be used.\n\n__Default value: `\"source-over\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "clip": { "description": "Whether a mark be clipped to the enclosing group’s width and height.", "type": "boolean" }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." }, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "anyOf": [ { "$ref": "#/definitions/Cursor", "description": "The 
mouse cursor used over the mark. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "dir": { "anyOf": [ { "$ref": "#/definitions/TextDirection", "description": "The direction of the text. One of `\"ltr\"` (left-to-right) or `\"rtl\"` (right-to-left). This property determines on which side is truncated in response to the limit parameter.\n\n__Default value:__ `\"ltr\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "dx": { "anyOf": [ { "description": "The horizontal offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "The vertical offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ellipsis": { "anyOf": [ { "description": "The ellipsis string for text truncated in response to the limit parameter.\n\n__Default value:__ `\"…\"`", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "endAngle": { "anyOf": [ { "description": "The end angle in radians for arc marks. 
A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default fill color. This property has higher precedence than `config.color`. Set to `null` to remove fill.\n\n__Default value:__ (None)" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "filled": { "description": "Whether the mark's color should be used as fill color instead of stroke color.\n\n__Default value:__ `false` for all `point`, `line`, and `rule` marks as well as `geoshape` marks for [`graticule`](https://vega.github.io/vega-lite/docs/data.html#graticule) data sources; otherwise, `true`.\n\n__Note:__ This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).", "type": "boolean" }, "font": { "anyOf": [ { "description": "The typeface to set the text in (e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "The font size, in pixels.\n\n__Default value:__ `11`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style (e.g., `\"italic\"`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "height": { "anyOf": [ { "description": "Height of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "href": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "A URL to load upon mouse click. If defined, the mark acts as a hyperlink." }, { "$ref": "#/definitions/ExprRef" } ] }, "innerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner radius in pixels of arc marks. `innerRadius` is an alias for `radius2`.\n\n__Default value:__ `0`", "minimum": 0 }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method to use for line and area marks. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"step-before\"`: alternate between vertical and horizontal segments, as in a step function.\n- `\"step-after\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "limit": { "anyOf": [ { "description": "The maximum length of the text mark in pixels. The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0` -- indicating no limit", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineBreak": { "anyOf": [ { "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. This property is ignored if the text is array-valued.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "The line height in pixels (the spacing between subsequent lines of text) for multi-line text marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "order": { "description": "For line and trail marks, this `order` property can be set to `null` or `false` to make the lines use the original order in the data sources.", "type": [ "null", "boolean" ] }, "orient": { "$ref": "#/definitions/Orientation", "description": "The orientation of a non-stacked bar, tick, area, and line charts. 
The value is either horizontal (default) or vertical.\n- For bar, rule and tick, this determines whether the size of the bar and tick should be applied to x or y dimension.\n- For area, this property determines the orient property of the Vega output.\n- For line and trail marks, this property determines the sort order of the points in the line if `config.sortLineBy` is not specified. For stacked charts, this is always determined by the orientation of the stack; therefore explicitly specified value will be ignored." }, "outerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer radius in pixels of arc marks. `outerRadius` is an alias for `radius`.\n\n__Default value:__ `0`", "minimum": 0 }, "padAngle": { "anyOf": [ { "description": "The angular padding applied to sides of the arc, in radians.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "radius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For arc mark, the primary (outer) radius in pixels.\n\nFor text marks, polar coordinate radial offset, in pixels, of the text from the origin determined by the `x` and `y` properties.\n\n__Default value:__ `min(plot_width, plot_height)/2`", "minimum": 0 }, "radius2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The secondary (inner) radius in pixels of arc marks.\n\n__Default value:__ `0`", "minimum": 0 }, "radius2Offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for radius2." }, "radiusOffset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for radius." }, "shape": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/SymbolShape" }, { "type": "string" } ], "description": "Shape of the point marks. 
Supported values include:\n- plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`.\n- the line symbol `\"stroke\"`\n- centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"`\n- a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n__Default value:__ `\"circle\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "size": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default size for marks.\n- For `point`/`circle`/`square`, this represents the pixel area of the marks. Note that this value sets the area of the symbol; the side lengths will increase with the square root of this value.\n- For `bar`, this represents the band size of the bar, in pixels.\n- For `text`, this represents the font size, in pixels.\n\n__Default value:__\n- `30` for point, circle, square marks; width/height's `step`\n- `2` for bar marks with discrete dimensions;\n- `5` for bar marks with continuous dimensions;\n- `11` for text marks.", "minimum": 0 }, "smooth": { "anyOf": [ { "description": "A boolean flag (default true) indicating if the image should be smoothed when resized. If false, individual pixels should be scaled directly rather than interpolated with smoothing. For SVG rendering, this option may not work in some browsers due to lack of standardization.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "startAngle": { "anyOf": [ { "description": "The start angle in radians for arc marks. 
A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default stroke color. This property has higher precedence than `config.color`. Set to `null` to remove stroke.\n\n__Default value:__ (None)" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOffset": { "anyOf": [ { "description": "The offset in pixels at which to draw the group stroke and fill. 
If unspecified, the default behavior is to dynamically offset stroked groups such that 1 pixel stroke widths align with the pixel grid.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "style": { "anyOf": [ { "type": "string" }, { "items": { "type": "string" }, "type": "array" } ], "description": "A string or array of strings indicating the name of custom styles to apply to the mark. A style is a named collection of mark property defaults defined within the [style configuration](https://vega.github.io/vega-lite/docs/mark.html#style-config). If style is an array, later styles will override earlier styles. Any [mark properties](https://vega.github.io/vega-lite/docs/encoding.html#mark-prop) explicitly defined within the `encoding` will override a style default.\n\n__Default value:__ The mark's name. For example, a bar mark will have style `\"bar\"` by default. __Note:__ Any specified style will augment the default style. For example, a bar mark with `\"style\": \"foo\"` will receive from `config.style.bar` and `config.style.foo` (the specified style `\"foo\"` has higher precedence)." 
}, "tension": { "anyOf": [ { "description": "Depending on the interpolation type, sets the tension parameter (for line and area marks).", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text", "description": "Placeholder text if the `text` channel is not specified" }, { "$ref": "#/definitions/ExprRef" } ] }, "theta": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians.", "maximum": 360, "minimum": 0 }, "theta2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "theta2Offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for theta2." }, "thetaOffset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for theta." }, "timeUnitBandPosition": { "description": "Default relative band position for a time unit. If set to `0`, the marks will be positioned at the beginning of the time unit band step. If set to `0.5`, the marks will be positioned in the middle of the time unit band step.", "type": "number" }, "timeUnitBandSize": { "description": "Default relative band size for a time unit. If set to `1`, the bandwidth of the marks will be equal to the time unit band step. 
If set to `0.5`, bandwidth of the marks will be half of the time unit band step.", "type": "number" }, "tooltip": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/TooltipContent" }, { "$ref": "#/definitions/ExprRef" }, { "type": "null" } ], "description": "The tooltip text string to show upon mouse hover or an object defining which fields should the tooltip be derived from.\n\n- If `tooltip` is `true` or `{\"content\": \"encoding\"}`, then all fields from `encoding` will be used.\n- If `tooltip` is `{\"content\": \"data\"}`, then all fields that appear in the highlighted data point will be used.\n- If set to `null` or `false`, then no tooltip will be used.\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite.\n\n__Default value:__ `null`" }, "url": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "The URL of the image file for image marks." }, { "$ref": "#/definitions/ExprRef" } ] }, "width": { "anyOf": [ { "description": "Width of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "x": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2Offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for x2-position." 
}, "xOffset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for x-position." }, "y": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2Offset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for y2-position." }, "yOffset": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Offset for y-position." } }, "type": "object" }, "Padding": { "anyOf": [ { "type": "number" }, { "additionalProperties": false, "properties": { "bottom": { "type": "number" }, "left": { "type": "number" }, "right": { "type": "number" }, "top": { "type": "number" } }, "type": "object" } ], "minimum": 0 }, "ParameterExtent": { "anyOf": [ { "additionalProperties": false, "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "If a selection parameter is specified, the field name to extract selected values for when the selection is [projected](https://vega.github.io/vega-lite/docs/selection.html#project) over multiple fields or encodings." }, "param": { "$ref": "#/definitions/ParameterName", "description": "The name of a parameter." 
} }, "required": [ "param" ], "type": "object" }, { "additionalProperties": false, "properties": { "encoding": { "$ref": "#/definitions/SingleDefUnitChannel", "description": "If a selection parameter is specified, the encoding channel to extract selected values for when a selection is [projected](https://vega.github.io/vega-lite/docs/selection.html#project) over multiple fields or encodings." }, "param": { "$ref": "#/definitions/ParameterName", "description": "The name of a parameter." } }, "required": [ "param" ], "type": "object" } ] }, "ParameterName": { "type": "string" }, "ParameterPredicate": { "additionalProperties": false, "properties": { "empty": { "description": "For selection parameters, the predicate of empty selections returns true by default. Override this behavior, by setting this property `empty: false`.", "type": "boolean" }, "param": { "$ref": "#/definitions/ParameterName", "description": "Filter using a parameter name." } }, "required": [ "param" ], "type": "object" }, "Parse": { "additionalProperties": { "$ref": "#/definitions/ParseValue" }, "type": "object" }, "ParseValue": { "anyOf": [ { "type": "null" }, { "type": "string" }, { "const": "string", "type": "string" }, { "const": "boolean", "type": "string" }, { "const": "date", "type": "string" }, { "const": "number", "type": "string" } ] }, "PivotTransform": { "additionalProperties": false, "properties": { "groupby": { "description": "The optional data fields to group by. If not specified, a single group containing all data objects will be used.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "limit": { "description": "An optional parameter indicating the maximum number of pivoted fields to generate. The default (`0`) applies no limit. The pivoted `pivot` names are sorted in ascending order prior to enforcing the limit. 
__Default value:__ `0`", "type": "number" }, "op": { "$ref": "#/definitions/AggregateOp", "description": "The aggregation operation to apply to grouped `value` field values. __Default value:__ `sum`" }, "pivot": { "$ref": "#/definitions/FieldName", "description": "The data field to pivot on. The unique values of this field become new field names in the output stream." }, "value": { "$ref": "#/definitions/FieldName", "description": "The data field to populate pivoted fields. The aggregate values of this field become the values of the new pivoted fields." } }, "required": [ "pivot", "value" ], "type": "object" }, "Point": { "additionalProperties": false, "description": "Point geometry object. https://tools.ietf.org/html/rfc7946#section-3.1.2", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. https://tools.ietf.org/html/rfc7946#section-5" }, "coordinates": { "$ref": "#/definitions/Position" }, "type": { "const": "Point", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "coordinates", "type" ], "type": "object" }, "PointSelectionConfig": { "additionalProperties": false, "properties": { "clear": { "anyOf": [ { "$ref": "#/definitions/Stream" }, { "type": "string" }, { "type": "boolean" } ], "description": "Clears the selection, emptying it of all values. This property can be a [Event Stream](https://vega.github.io/vega/docs/event-streams/) or `false` to disable clear.\n\n__Default value:__ `dblclick`.\n\n__See also:__ [`clear` examples ](https://vega.github.io/vega-lite/docs/selection.html#clear) in the documentation." }, "encodings": { "description": "An array of encoding channels. 
The corresponding data field values must match for a data tuple to fall within the selection.\n\n__See also:__ The [projection with `encodings` and `fields` section](https://vega.github.io/vega-lite/docs/selection.html#project) in the documentation.", "items": { "$ref": "#/definitions/SingleDefUnitChannel" }, "type": "array" }, "fields": { "description": "An array of field names whose values must match for a data tuple to fall within the selection.\n\n__See also:__ The [projection with `encodings` and `fields` section](https://vega.github.io/vega-lite/docs/selection.html#project) in the documentation.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "nearest": { "description": "When true, an invisible voronoi diagram is computed to accelerate discrete selection. The data value _nearest_ the mouse cursor is added to the selection.\n\n__Default value:__ `false`, which means that data values must be interacted with directly (e.g., clicked on) to be added to the selection.\n\n__See also:__ [`nearest` examples](https://vega.github.io/vega-lite/docs/selection.html#nearest) documentation.", "type": "boolean" }, "on": { "anyOf": [ { "$ref": "#/definitions/Stream" }, { "type": "string" } ], "description": "A [Vega event stream](https://vega.github.io/vega/docs/event-streams/) (object or selector) that triggers the selection. For interval selections, the event stream must specify a [start and end](https://vega.github.io/vega/docs/event-streams/#between-filters).\n\n__See also:__ [`on` examples](https://vega.github.io/vega-lite/docs/selection.html#on) in the documentation." }, "resolve": { "$ref": "#/definitions/SelectionResolution", "description": "With layered and multi-view displays, a strategy that determines how selections' data queries are resolved when applied in a filter transform, conditional encoding rule, or scale domain.\n\nOne of:\n- `\"global\"` -- only one brush exists for the entire SPLOM. 
When the user begins to drag, any previous brushes are cleared, and a new one is constructed.\n- `\"union\"` -- each cell contains its own brush, and points are highlighted if they lie within _any_ of these individual brushes.\n- `\"intersect\"` -- each cell contains its own brush, and points are highlighted only if they fall within _all_ of these individual brushes.\n\n__Default value:__ `global`.\n\n__See also:__ [`resolve` examples](https://vega.github.io/vega-lite/docs/selection.html#resolve) in the documentation." }, "toggle": { "description": "Controls whether data values should be toggled (inserted or removed from a point selection) or only ever inserted into point selections.\n\nOne of:\n- `true` -- the default behavior, which corresponds to `\"event.shiftKey\"`. As a result, data values are toggled when the user interacts with the shift-key pressed.\n- `false` -- disables toggling behaviour; the selection will only ever contain a single data value corresponding to the most recent interaction.\n- A [Vega expression](https://vega.github.io/vega/docs/expressions/) which is re-evaluated as the user interacts. If the expression evaluates to `true`, the data value is toggled into or out of the point selection. If the expression evaluates to `false`, the point selection is first cleared, and the data value is then inserted. For example, setting the value to the Vega expression `\"true\"` will toggle data values without the user pressing the shift-key.\n\n__Default value:__ `true`\n\n__See also:__ [`toggle` examples](https://vega.github.io/vega-lite/docs/selection.html#toggle) in the documentation.", "type": [ "string", "boolean" ] }, "type": { "const": "point", "description": "Determines the default event processing and data query for the selection. 
Vega-Lite currently supports two selection types:\n\n- `\"point\"` -- to select multiple discrete data values; the first value is selected on `click` and additional values toggled on shift-click.\n- `\"interval\"` -- to select a continuous range of data values on `drag`.", "type": "string" } }, "required": [ "type" ], "type": "object" }, "PointSelectionConfigWithoutType": { "additionalProperties": false, "properties": { "clear": { "anyOf": [ { "$ref": "#/definitions/Stream" }, { "type": "string" }, { "type": "boolean" } ], "description": "Clears the selection, emptying it of all values. This property can be a [Event Stream](https://vega.github.io/vega/docs/event-streams/) or `false` to disable clear.\n\n__Default value:__ `dblclick`.\n\n__See also:__ [`clear` examples ](https://vega.github.io/vega-lite/docs/selection.html#clear) in the documentation." }, "encodings": { "description": "An array of encoding channels. The corresponding data field values must match for a data tuple to fall within the selection.\n\n__See also:__ The [projection with `encodings` and `fields` section](https://vega.github.io/vega-lite/docs/selection.html#project) in the documentation.", "items": { "$ref": "#/definitions/SingleDefUnitChannel" }, "type": "array" }, "fields": { "description": "An array of field names whose values must match for a data tuple to fall within the selection.\n\n__See also:__ The [projection with `encodings` and `fields` section](https://vega.github.io/vega-lite/docs/selection.html#project) in the documentation.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "nearest": { "description": "When true, an invisible voronoi diagram is computed to accelerate discrete selection. 
The data value _nearest_ the mouse cursor is added to the selection.\n\n__Default value:__ `false`, which means that data values must be interacted with directly (e.g., clicked on) to be added to the selection.\n\n__See also:__ [`nearest` examples](https://vega.github.io/vega-lite/docs/selection.html#nearest) documentation.", "type": "boolean" }, "on": { "anyOf": [ { "$ref": "#/definitions/Stream" }, { "type": "string" } ], "description": "A [Vega event stream](https://vega.github.io/vega/docs/event-streams/) (object or selector) that triggers the selection. For interval selections, the event stream must specify a [start and end](https://vega.github.io/vega/docs/event-streams/#between-filters).\n\n__See also:__ [`on` examples](https://vega.github.io/vega-lite/docs/selection.html#on) in the documentation." }, "resolve": { "$ref": "#/definitions/SelectionResolution", "description": "With layered and multi-view displays, a strategy that determines how selections' data queries are resolved when applied in a filter transform, conditional encoding rule, or scale domain.\n\nOne of:\n- `\"global\"` -- only one brush exists for the entire SPLOM. When the user begins to drag, any previous brushes are cleared, and a new one is constructed.\n- `\"union\"` -- each cell contains its own brush, and points are highlighted if they lie within _any_ of these individual brushes.\n- `\"intersect\"` -- each cell contains its own brush, and points are highlighted only if they fall within _all_ of these individual brushes.\n\n__Default value:__ `global`.\n\n__See also:__ [`resolve` examples](https://vega.github.io/vega-lite/docs/selection.html#resolve) in the documentation." }, "toggle": { "description": "Controls whether data values should be toggled (inserted or removed from a point selection) or only ever inserted into point selections.\n\nOne of:\n- `true` -- the default behavior, which corresponds to `\"event.shiftKey\"`. 
As a result, data values are toggled when the user interacts with the shift-key pressed.\n- `false` -- disables toggling behaviour; the selection will only ever contain a single data value corresponding to the most recent interaction.\n- A [Vega expression](https://vega.github.io/vega/docs/expressions/) which is re-evaluated as the user interacts. If the expression evaluates to `true`, the data value is toggled into or out of the point selection. If the expression evaluates to `false`, the point selection is first cleared, and the data value is then inserted. For example, setting the value to the Vega expression `\"true\"` will toggle data values without the user pressing the shift-key.\n\n__Default value:__ `true`\n\n__See also:__ [`toggle` examples](https://vega.github.io/vega-lite/docs/selection.html#toggle) in the documentation.", "type": [ "string", "boolean" ] } }, "type": "object" }, "PolarDef": { "anyOf": [ { "$ref": "#/definitions/PositionFieldDefBase" }, { "$ref": "#/definitions/PositionDatumDefBase" }, { "$ref": "#/definitions/PositionValueDef" } ] }, "Polygon": { "additionalProperties": false, "description": "Polygon geometry object. https://tools.ietf.org/html/rfc7946#section-3.1.6", "properties": { "bbox": { "$ref": "#/definitions/BBox", "description": "Bounding box of the coordinate range of the object's Geometries, Features, or Feature Collections. https://tools.ietf.org/html/rfc7946#section-5" }, "coordinates": { "items": { "items": { "$ref": "#/definitions/Position" }, "type": "array" }, "type": "array" }, "type": { "const": "Polygon", "description": "Specifies the type of GeoJSON object.", "type": "string" } }, "required": [ "coordinates", "type" ], "type": "object" }, "Position": { "description": "A Position is an array of coordinates. https://tools.ietf.org/html/rfc7946#section-3.1.1 Array should contain between two and three elements. 
The previous GeoJSON specification allowed more elements (e.g., which could be used to represent M values), but the current specification only allows X, Y, and (optionally) Z to be defined.", "items": { "type": "number" }, "type": "array" }, "Position2Def": { "anyOf": [ { "$ref": "#/definitions/SecondaryFieldDef" }, { "$ref": "#/definitions/DatumDef" }, { "$ref": "#/definitions/PositionValueDef" } ] }, "PositionDatumDef": { "additionalProperties": false, "properties": { "axis": { "anyOf": [ { "$ref": "#/definitions/Axis" }, { "type": "null" } ], "description": "An object defining properties of axis's gridlines, ticks and labels. If `null`, the axis for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [axis properties](https://vega.github.io/vega-lite/docs/axis.html) are applied.\n\n__See also:__ [`axis`](https://vega.github.io/vega-lite/docs/axis.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "impute": { "anyOf": [ { "$ref": "#/definitions/ImputeParams" }, { "type": "null" } ], "description": "An object defining the properties of the Impute Operation to be applied. The field value of the other positional channel is taken as `key` of the `Impute` Operation. The field of the `color` channel if specified is used as `groupby` of the `Impute` Operation.\n\n__See also:__ [`impute`](https://vega.github.io/vega-lite/docs/impute.html) documentation." 
}, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "stack": { "anyOf": [ { "$ref": "#/definitions/StackOffset" }, { "type": "null" }, { "type": "boolean" } ], "description": "Type of stacking offset if the field should be stacked. `stack` is only applicable for `x`, `y`, `theta`, and `radius` channels with continuous domains. For example, `stack` of `y` can be used to customize stacking for a vertical bar chart.\n\n`stack` can be one of the following values:\n- `\"zero\"` or `true`: stacking with baseline offset at zero value of the scale (for creating typical stacked [bar](https://vega.github.io/vega-lite/docs/stack.html#bar) and [area](https://vega.github.io/vega-lite/docs/stack.html#area) chart).\n- `\"normalize\"` - stacking with normalized domain (for creating [normalized stacked bar and area charts](https://vega.github.io/vega-lite/docs/stack.html#normalized) and pie charts [with percentage tooltip](https://vega.github.io/vega-lite/docs/arc.html#tooltip)).
\n-`\"center\"` - stacking with center baseline (for [streamgraph](https://vega.github.io/vega-lite/docs/stack.html#streamgraph)).\n- `null` or `false` - No-stacking. This will produce layered [bar](https://vega.github.io/vega-lite/docs/stack.html#layered-bar-chart) and area chart.\n\n__Default value:__ `zero` for plots with all of the following conditions are true: (1) the mark is `bar`, `area`, or `arc`; (2) the stacked measure channel (x or y) has a linear scale; (3) At least one of non-position channels mapped to an unaggregated field that is different from x and y. Otherwise, `null` by default.\n\n__See also:__ [`stack`](https://vega.github.io/vega-lite/docs/stack.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, 
string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "PositionDatumDefBase": { "additionalProperties": false, "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "stack": { "anyOf": [ { "$ref": "#/definitions/StackOffset" }, { "type": "null" }, { "type": "boolean" } ], "description": "Type of stacking offset if the field should be stacked. `stack` is only applicable for `x`, `y`, `theta`, and `radius` channels with continuous domains. 
For example, `stack` of `y` can be used to customize stacking for a vertical bar chart.\n\n`stack` can be one of the following values:\n- `\"zero\"` or `true`: stacking with baseline offset at zero value of the scale (for creating typical stacked [bar](https://vega.github.io/vega-lite/docs/stack.html#bar) and [area](https://vega.github.io/vega-lite/docs/stack.html#area) chart).\n- `\"normalize\"` - stacking with normalized domain (for creating [normalized stacked bar and area charts](https://vega.github.io/vega-lite/docs/stack.html#normalized) and pie charts [with percentage tooltip](https://vega.github.io/vega-lite/docs/arc.html#tooltip)).
\n-`\"center\"` - stacking with center baseline (for [streamgraph](https://vega.github.io/vega-lite/docs/stack.html#streamgraph)).\n- `null` or `false` - No-stacking. This will produce layered [bar](https://vega.github.io/vega-lite/docs/stack.html#layered-bar-chart) and area chart.\n\n__Default value:__ `zero` for plots with all of the following conditions are true: (1) the mark is `bar`, `area`, or `arc`; (2) the stacked measure channel (x or y) has a linear scale; (3) At least one of non-position channels mapped to an unaggregated field that is different from x and y. Otherwise, `null` by default.\n\n__See also:__ [`stack`](https://vega.github.io/vega-lite/docs/stack.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, 
string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." 
} }, "type": "object" }, "PositionDef": { "anyOf": [ { "$ref": "#/definitions/PositionFieldDef" }, { "$ref": "#/definitions/PositionDatumDef" }, { "$ref": "#/definitions/PositionValueDef" } ] }, "PositionFieldDef": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "axis": { "anyOf": [ { "$ref": "#/definitions/Axis" }, { "type": "null" } ], "description": "An object defining properties of axis's gridlines, ticks and labels. If `null`, the axis for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [axis properties](https://vega.github.io/vega-lite/docs/axis.html) are applied.\n\n__See also:__ [`axis`](https://vega.github.io/vega-lite/docs/axis.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. 
You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "impute": { "anyOf": [ { "$ref": "#/definitions/ImputeParams" }, { "type": "null" } ], "description": "An object defining the properties of the Impute Operation to be applied. The field value of the other positional channel is taken as `key` of the `Impute` Operation. The field of the `color` channel if specified is used as `groupby` of the `Impute` Operation.\n\n__See also:__ [`impute`](https://vega.github.io/vega-lite/docs/impute.html) documentation." 
}, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. 
For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "stack": { "anyOf": [ { "$ref": "#/definitions/StackOffset" }, { "type": "null" }, { "type": "boolean" } ], "description": "Type of stacking offset if the field should be stacked. `stack` is only applicable for `x`, `y`, `theta`, and `radius` channels with continuous domains. For example, `stack` of `y` can be used to customize stacking for a vertical bar chart.\n\n`stack` can be one of the following values:\n- `\"zero\"` or `true`: stacking with baseline offset at zero value of the scale (for creating typical stacked [bar](https://vega.github.io/vega-lite/docs/stack.html#bar) and [area](https://vega.github.io/vega-lite/docs/stack.html#area) chart).\n- `\"normalize\"` - stacking with normalized domain (for creating [normalized stacked bar and area charts](https://vega.github.io/vega-lite/docs/stack.html#normalized) and pie charts [with percentage tooltip](https://vega.github.io/vega-lite/docs/arc.html#tooltip)).
\n-`\"center\"` - stacking with center baseline (for [streamgraph](https://vega.github.io/vega-lite/docs/stack.html#streamgraph)).\n- `null` or `false` - No-stacking. This will produce layered [bar](https://vega.github.io/vega-lite/docs/stack.html#layered-bar-chart) and area chart.\n\n__Default value:__ `zero` for plots with all of the following conditions are true: (1) the mark is `bar`, `area`, or `arc`; (2) the stacked measure channel (x or y) has a linear scale; (3) At least one of non-position channels mapped to an unaggregated field that is different from x and y. Otherwise, `null` by default.\n\n__See also:__ [`stack`](https://vega.github.io/vega-lite/docs/stack.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a 
[custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "PositionFieldDefBase": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. 
To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "stack": { "anyOf": [ { "$ref": "#/definitions/StackOffset" }, { "type": "null" }, { "type": "boolean" } ], "description": "Type of stacking offset if the field should be stacked. 
`stack` is only applicable for `x`, `y`, `theta`, and `radius` channels with continuous domains. For example, `stack` of `y` can be used to customize stacking for a vertical bar chart.\n\n`stack` can be one of the following values:\n- `\"zero\"` or `true`: stacking with baseline offset at zero value of the scale (for creating typical stacked [bar](https://vega.github.io/vega-lite/docs/stack.html#bar) and [area](https://vega.github.io/vega-lite/docs/stack.html#area) chart).\n- `\"normalize\"` - stacking with normalized domain (for creating [normalized stacked bar and area charts](https://vega.github.io/vega-lite/docs/stack.html#normalized) and pie charts [with percentage tooltip](https://vega.github.io/vega-lite/docs/arc.html#tooltip)).
\n-`\"center\"` - stacking with center baseline (for [streamgraph](https://vega.github.io/vega-lite/docs/stack.html#streamgraph)).\n- `null` or `false` - No-stacking. This will produce layered [bar](https://vega.github.io/vega-lite/docs/stack.html#layered-bar-chart) and area chart.\n\n__Default value:__ `zero` for plots with all of the following conditions are true: (1) the mark is `bar`, `area`, or `arc`; (2) the stacked measure channel (x or y) has a linear scale; (3) At least one of non-position channels mapped to an unaggregated field that is different from x and y. Otherwise, `null` by default.\n\n__See also:__ [`stack`](https://vega.github.io/vega-lite/docs/stack.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a 
[custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "PositionValueDef": { "$ref": "#/definitions/ValueDef<(number|\"width\"|\"height\"|ExprRef)>" }, "Predicate": { "anyOf": [ { "$ref": "#/definitions/FieldEqualPredicate" }, { "$ref": "#/definitions/FieldRangePredicate" }, { "$ref": "#/definitions/FieldOneOfPredicate" }, { "$ref": "#/definitions/FieldLTPredicate" }, { "$ref": "#/definitions/FieldGTPredicate" }, { "$ref": "#/definitions/FieldLTEPredicate" }, { "$ref": "#/definitions/FieldGTEPredicate" }, { "$ref": "#/definitions/FieldValidPredicate" }, { "$ref": "#/definitions/ParameterPredicate" }, { "type": "string" } ] }, "PrimitiveValue": { "type": [ "number", "string", "boolean", "null" ] }, "Projection": { "additionalProperties": false, "properties": { "center": { "anyOf": [ { "$ref": "#/definitions/Vector2", "description": "The projection's center, a two-element array of longitude and latitude in degrees.\n\n__Default value:__ `[0, 0]`" }, { "$ref": "#/definitions/ExprRef" } ] }, "clipAngle": { "anyOf": [ { "description": "The projection's clipping circle radius to the specified angle in degrees. If `null`, switches to [antimeridian](http://bl.ocks.org/mbostock/3788999) cutting rather than small-circle clipping.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "clipExtent": { "anyOf": [ { "$ref": "#/definitions/Vector2>", "description": "The projection's viewport clip extent to the specified bounds in pixels. The extent bounds are specified as an array `[[x0, y0], [x1, y1]]`, where `x0` is the left-side of the viewport, `y0` is the top, `x1` is the right and `y1` is the bottom. If `null`, no viewport clipping is performed." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "coefficient": { "anyOf": [ { "description": "The coefficient parameter for the `hammer` projection.\n\n__Default value:__ `2`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "distance": { "anyOf": [ { "description": "For the `satellite` projection, the distance from the center of the sphere to the point of view, as a proportion of the sphere’s radius. The recommended maximum clip angle for a given `distance` is acos(1 / distance) converted to degrees. If tilt is also applied, then more conservative clipping may be necessary.\n\n__Default value:__ `2.0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "extent": { "anyOf": [ { "$ref": "#/definitions/Vector2>" }, { "$ref": "#/definitions/ExprRef" } ] }, "fit": { "anyOf": [ { "$ref": "#/definitions/Fit" }, { "items": { "$ref": "#/definitions/Fit" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "fraction": { "anyOf": [ { "description": "The fraction parameter for the `bottomley` projection.\n\n__Default value:__ `0.5`, corresponding to a sin(ψ) where ψ = π/6.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lobes": { "anyOf": [ { "description": "The number of lobes in projections that support multi-lobe views: `berghaus`, `gingery`, or `healpix`. The default value varies based on the projection type.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "parallel": { "anyOf": [ { "description": "The parallel parameter for projections that support it: `armadillo`, `bonne`, `craig`, `cylindricalEqualArea`, `cylindricalStereographic`, `hammerRetroazimuthal`, `loximuthal`, or `rectangularPolyconic`. The default value varies based on the projection type.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "parallels": { "anyOf": [ { "description": "For conic projections, the [two standard parallels](https://en.wikipedia.org/wiki/Map_projection#Conic) that define the map layout. 
The default depends on the specific conic projection used.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "pointRadius": { "anyOf": [ { "description": "The default radius (in pixels) to use when drawing GeoJSON `Point` and `MultiPoint` geometries. This parameter sets a constant default value. To modify the point radius in response to data, see the corresponding parameter of the GeoPath and GeoShape transforms.\n\n__Default value:__ `4.5`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "precision": { "anyOf": [ { "description": "The threshold for the projection's [adaptive resampling](http://bl.ocks.org/mbostock/3795544) to the specified value in pixels. This value corresponds to the [Douglas–Peucker distance](http://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm). If precision is not specified, returns the projection's current resampling precision which defaults to `√0.5 ≅ 0.70710…`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "radius": { "anyOf": [ { "description": "The radius parameter for the `airy` or `gingery` projection. The default value varies based on the projection type.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ratio": { "anyOf": [ { "description": "The ratio parameter for the `hill`, `hufnagel`, or `wagner` projections. 
The default value varies based on the projection type.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "reflectX": { "anyOf": [ { "description": "Sets whether or not the x-dimension is reflected (negated) in the output.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "reflectY": { "anyOf": [ { "description": "Sets whether or not the y-dimension is reflected (negated) in the output.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "rotate": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/Vector2" }, { "$ref": "#/definitions/Vector3" } ], "description": "The projection's three-axis rotation to the specified angles, which must be a two- or three-element array of numbers [`lambda`, `phi`, `gamma`] specifying the rotation angles in degrees about each spherical axis. (These correspond to yaw, pitch and roll.)\n\n__Default value:__ `[0, 0, 0]`" }, { "$ref": "#/definitions/ExprRef" } ] }, "scale": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The projection’s scale (zoom) factor, overriding automatic fitting. The default scale is projection-specific. The scale factor corresponds linearly to the distance between projected points; however, scale factor values are not equivalent across projections." }, "size": { "anyOf": [ { "$ref": "#/definitions/Vector2", "description": "Used in conjunction with fit, provides the width and height in pixels of the area to which the projection should be automatically fit." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "spacing": { "anyOf": [ { "description": "The spacing parameter for the `lagrange` projection.\n\n__Default value:__ `0.5`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tilt": { "anyOf": [ { "description": "The tilt angle (in degrees) for the `satellite` projection.\n\n__Default value:__ `0`.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "translate": { "anyOf": [ { "$ref": "#/definitions/Vector2" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The projection’s translation offset as a two-element array `[tx, ty]`." }, "type": { "anyOf": [ { "$ref": "#/definitions/ProjectionType" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The cartographic projection to use. This value is case-insensitive, for example `\"albers\"` and `\"Albers\"` indicate the same projection type. You can find all valid projection types [in the documentation](https://vega.github.io/vega-lite/docs/projection.html#projection-types).\n\n__Default value:__ `equalEarth`" } }, "type": "object" }, "ProjectionConfig": { "$ref": "#/definitions/Projection", "description": "Any property of Projection can be in config" }, "ProjectionType": { "enum": [ "albers", "albersUsa", "azimuthalEqualArea", "azimuthalEquidistant", "conicConformal", "conicEqualArea", "conicEquidistant", "equalEarth", "equirectangular", "gnomonic", "identity", "mercator", "naturalEarth1", "orthographic", "stereographic", "transverseMercator" ], "type": "string" }, "QuantileTransform": { "additionalProperties": false, "properties": { "as": { "description": "The output field names for the probability and quantile values.\n\n__Default value:__ `[\"prob\", \"value\"]`", "items": { "$ref": "#/definitions/FieldName" }, "maxItems": 2, "minItems": 2, "type": "array" }, "groupby": { "description": "The data fields to group by. 
If not specified, a single group containing all data objects will be used.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "probs": { "description": "An array of probabilities in the range (0, 1) for which to compute quantile values. If not specified, the *step* parameter will be used.", "items": { "type": "number" }, "type": "array" }, "quantile": { "$ref": "#/definitions/FieldName", "description": "The data field for which to perform quantile estimation." }, "step": { "description": "A probability step size (default 0.01) for sampling quantile values. All values from one-half the step size up to 1 (exclusive) will be sampled. This parameter is only used if the *probs* parameter is not provided.", "type": "number" } }, "required": [ "quantile" ], "type": "object" }, "RadialGradient": { "additionalProperties": false, "properties": { "gradient": { "const": "radial", "description": "The type of gradient. Use `\"radial\"` for a radial gradient.", "type": "string" }, "id": { "type": "string" }, "r1": { "description": "The radius length, in normalized [0, 1] coordinates, of the inner circle for the gradient.\n\n__Default value:__ `0`", "type": "number" }, "r2": { "description": "The radius length, in normalized [0, 1] coordinates, of the outer circle for the gradient.\n\n__Default value:__ `0.5`", "type": "number" }, "stops": { "description": "An array of gradient stops defining the gradient color sequence.", "items": { "$ref": "#/definitions/GradientStop" }, "type": "array" }, "x1": { "description": "The x-coordinate, in normalized [0, 1] coordinates, for the center of the inner circle for the gradient.\n\n__Default value:__ `0.5`", "type": "number" }, "x2": { "description": "The x-coordinate, in normalized [0, 1] coordinates, for the center of the outer circle for the gradient.\n\n__Default value:__ `0.5`", "type": "number" }, "y1": { "description": "The y-coordinate, in normalized [0, 1] coordinates, for the center of the inner circle for the 
gradient.\n\n__Default value:__ `0.5`", "type": "number" }, "y2": { "description": "The y-coordinate, in normalized [0, 1] coordinates, for the center of the outer circle for the gradient.\n\n__Default value:__ `0.5`", "type": "number" } }, "required": [ "gradient", "stops" ], "type": "object" }, "RangeConfig": { "additionalProperties": { "anyOf": [ { "$ref": "#/definitions/RangeScheme" }, { "type": "array" } ] }, "properties": { "category": { "anyOf": [ { "$ref": "#/definitions/RangeScheme" }, { "items": { "$ref": "#/definitions/Color" }, "type": "array" } ], "description": "Default [color scheme](https://vega.github.io/vega/docs/schemes/) for categorical data." }, "diverging": { "anyOf": [ { "$ref": "#/definitions/RangeScheme" }, { "items": { "$ref": "#/definitions/Color" }, "type": "array" } ], "description": "Default [color scheme](https://vega.github.io/vega/docs/schemes/) for diverging quantitative ramps." }, "heatmap": { "anyOf": [ { "$ref": "#/definitions/RangeScheme" }, { "items": { "$ref": "#/definitions/Color" }, "type": "array" } ], "description": "Default [color scheme](https://vega.github.io/vega/docs/schemes/) for quantitative heatmaps." }, "ordinal": { "anyOf": [ { "$ref": "#/definitions/RangeScheme" }, { "items": { "$ref": "#/definitions/Color" }, "type": "array" } ], "description": "Default [color scheme](https://vega.github.io/vega/docs/schemes/) for rank-ordered data." }, "ramp": { "anyOf": [ { "$ref": "#/definitions/RangeScheme" }, { "items": { "$ref": "#/definitions/Color" }, "type": "array" } ], "description": "Default [color scheme](https://vega.github.io/vega/docs/schemes/) for sequential quantitative ramps." 
}, "symbol": { "description": "Array of [symbol](https://vega.github.io/vega/docs/marks/symbol/) names or paths for the default shape palette.", "items": { "$ref": "#/definitions/SymbolShape" }, "type": "array" } }, "type": "object" }, "RangeEnum": { "enum": [ "width", "height", "symbol", "category", "ordinal", "ramp", "diverging", "heatmap" ], "type": "string" }, "RangeRaw": { "items": { "anyOf": [ { "type": "null" }, { "type": "boolean" }, { "type": "string" }, { "type": "number" }, { "$ref": "#/definitions/RangeRawArray" } ] }, "type": "array" }, "RangeRawArray": { "items": { "type": "number" }, "type": "array" }, "RangeScheme": { "anyOf": [ { "$ref": "#/definitions/RangeEnum" }, { "$ref": "#/definitions/RangeRaw" }, { "additionalProperties": false, "properties": { "count": { "type": "number" }, "extent": { "items": { "type": "number" }, "type": "array" }, "scheme": { "anyOf": [ { "type": "string" }, { "items": { "type": "string" }, "type": "array" }, { "$ref": "#/definitions/ColorScheme" } ] } }, "required": [ "scheme" ], "type": "object" } ] }, "RectConfig": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The horizontal alignment of the text or ranged marks (area, bar, image, rect, rule). One of `\"left\"`, `\"right\"`, `\"center\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "angle": { "anyOf": [ { "description": "The rotation angle of the text, in degrees.", "maximum": 360, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). 
If `false`, the \"aria-hidden\" attribute will be set on the output SVG element, removing the mark item from the ARIA accessibility tree.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRole": { "anyOf": [ { "description": "Sets the type of user interface element of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"role\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRoleDescription": { "anyOf": [ { "description": "A human-readable, author-localized description for the role of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"aria-roledescription\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "aspect": { "anyOf": [ { "description": "Whether to keep aspect ratio of image marks.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "baseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For text marks, the vertical text baseline. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, `\"line-bottom\"`, or an expression reference that provides one of the valid values. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `lineHeight` rather than `fontSize` alone.\n\nFor range marks, the vertical alignment of the marks. One of `\"top\"`, `\"middle\"`, `\"bottom\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "binSpacing": { "description": "Offset between bars for binned field. 
The ideal value for this is either 0 (preferred by statisticians) or 1 (Vega-Lite default, D3 example style).\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, "blend": { "anyOf": [ { "$ref": "#/definitions/Blend", "description": "The color blend mode for drawing an item on its current background. Any valid [CSS mix-blend-mode](https://developer.mozilla.org/en-US/docs/Web/CSS/mix-blend-mode) value can be used.\n\n__Default value: `\"source-over\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." }, "continuousBandSize": { "description": "The default size of the bars on continuous scales.\n\n__Default value:__ `5`", "minimum": 0, "type": "number" }, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopRight": { "anyOf": [ { "description": "The radius in pixels of 
rounded rectangles' top left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "anyOf": [ { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the mark. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "dir": { "anyOf": [ { "$ref": "#/definitions/TextDirection", "description": "The direction of the text. One of `\"ltr\"` (left-to-right) or `\"rtl\"` (right-to-left). This property determines on which side is truncated in response to the limit parameter.\n\n__Default value:__ `\"ltr\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "discreteBandSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RelativeBandSize" } ], "description": "The default size of the bars with discrete dimensions. If unspecified, the default size is `step-2`, which provides 2 pixel offset between bars.", "minimum": 0 }, "dx": { "anyOf": [ { "description": "The horizontal offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "The vertical offset, in pixels, between the text label and its anchor point. 
The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ellipsis": { "anyOf": [ { "description": "The ellipsis string for text truncated in response to the limit parameter.\n\n__Default value:__ `\"…\"`", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "endAngle": { "anyOf": [ { "description": "The end angle in radians for arc marks. A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default fill color. This property has higher precedence than `config.color`. Set to `null` to remove fill.\n\n__Default value:__ (None)" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "filled": { "description": "Whether the mark's color should be used as fill color instead of stroke color.\n\n__Default value:__ `false` for all `point`, `line`, and `rule` marks as well as `geoshape` marks for [`graticule`](https://vega.github.io/vega-lite/docs/data.html#graticule) data sources; otherwise, `true`.\n\n__Note:__ This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).", "type": "boolean" }, "font": { "anyOf": [ { "description": "The typeface to set the text in (e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "The font size, in pixels.\n\n__Default value:__ `11`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style (e.g., `\"italic\"`)." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "height": { "anyOf": [ { "description": "Height of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "href": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "A URL to load upon mouse click. If defined, the mark acts as a hyperlink." }, { "$ref": "#/definitions/ExprRef" } ] }, "innerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner radius in pixels of arc marks. `innerRadius` is an alias for `radius2`.\n\n__Default value:__ `0`", "minimum": 0 }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method to use for line and area marks. 
One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"step-before\"`: alternate between vertical and horizontal segments, as in a step function.\n- `\"step-after\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." }, { "$ref": "#/definitions/ExprRef" } ] }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "limit": { "anyOf": [ { "description": "The maximum length of the text mark in pixels. The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0` -- indicating no limit", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineBreak": { "anyOf": [ { "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. 
This property is ignored if the text is array-valued.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "The line height in pixels (the spacing between subsequent lines of text) for multi-line text marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "minBandSize": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The minimum band size for bar and rectangle marks. __Default value:__ `0.25`" }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "order": { "description": "For line and trail marks, this `order` property can be set to `null` or `false` to make the lines use the original order in the data sources.", "type": [ "null", "boolean" ] }, "orient": { "$ref": "#/definitions/Orientation", "description": "The orientation of a non-stacked bar, tick, area, and line charts. The value is either horizontal (default) or vertical.\n- For bar, rule and tick, this determines whether the size of the bar and tick should be applied to x or y dimension.\n- For area, this property determines the orient property of the Vega output.\n- For line and trail marks, this property determines the sort order of the points in the line if `config.sortLineBy` is not specified. For stacked charts, this is always determined by the orientation of the stack; therefore explicitly specified value will be ignored." }, "outerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer radius in pixels of arc marks. 
`outerRadius` is an alias for `radius`.\n\n__Default value:__ `0`", "minimum": 0 }, "padAngle": { "anyOf": [ { "description": "The angular padding applied to sides of the arc, in radians.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "radius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For arc mark, the primary (outer) radius in pixels.\n\nFor text marks, polar coordinate radial offset, in pixels, of the text from the origin determined by the `x` and `y` properties.\n\n__Default value:__ `min(plot_width, plot_height)/2`", "minimum": 0 }, "radius2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The secondary (inner) radius in pixels of arc marks.\n\n__Default value:__ `0`", "minimum": 0 }, "shape": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/SymbolShape" }, { "type": "string" } ], "description": "Shape of the point marks. Supported values include:\n- plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`.\n- the line symbol `\"stroke\"`\n- centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"`\n- a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n__Default value:__ `\"circle\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "size": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default size for marks.\n- For `point`/`circle`/`square`, this represents the pixel area of the marks. 
Note that this value sets the area of the symbol; the side lengths will increase with the square root of this value.\n- For `bar`, this represents the band size of the bar, in pixels.\n- For `text`, this represents the font size, in pixels.\n\n__Default value:__\n- `30` for point, circle, square marks; width/height's `step`\n- `2` for bar marks with discrete dimensions;\n- `5` for bar marks with continuous dimensions;\n- `11` for text marks.", "minimum": 0 }, "smooth": { "anyOf": [ { "description": "A boolean flag (default true) indicating if the image should be smoothed when resized. If false, individual pixels should be scaled directly rather than interpolated with smoothing. For SVG rendering, this option may not work in some browsers due to lack of standardization.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "startAngle": { "anyOf": [ { "description": "The start angle in radians for arc marks. A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default stroke color. This property has higher precedence than `config.color`. Set to `null` to remove stroke.\n\n__Default value:__ (None)" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. 
One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOffset": { "anyOf": [ { "description": "The offset in pixels at which to draw the group stroke and fill. 
If unspecified, the default behavior is to dynamically offset stroked groups such that 1 pixel stroke widths align with the pixel grid.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tension": { "anyOf": [ { "description": "Depending on the interpolation type, sets the tension parameter (for line and area marks).", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text", "description": "Placeholder text if the `text` channel is not specified" }, { "$ref": "#/definitions/ExprRef" } ] }, "theta": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians.", "maximum": 360, "minimum": 0 }, "theta2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "timeUnitBandPosition": { "description": "Default relative band position for a time unit. If set to `0`, the marks will be positioned at the beginning of the time unit band step. If set to `0.5`, the marks will be positioned in the middle of the time unit band step.", "type": "number" }, "timeUnitBandSize": { "description": "Default relative band size for a time unit. If set to `1`, the bandwidth of the marks will be equal to the time unit band step. 
If set to `0.5`, bandwidth of the marks will be half of the time unit band step.", "type": "number" }, "tooltip": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/TooltipContent" }, { "$ref": "#/definitions/ExprRef" }, { "type": "null" } ], "description": "The tooltip text string to show upon mouse hover or an object defining which fields should the tooltip be derived from.\n\n- If `tooltip` is `true` or `{\"content\": \"encoding\"}`, then all fields from `encoding` will be used.\n- If `tooltip` is `{\"content\": \"data\"}`, then all fields that appear in the highlighted data point will be used.\n- If set to `null` or `false`, then no tooltip will be used.\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite.\n\n__Default value:__ `null`" }, "url": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "The URL of the image file for image marks." }, { "$ref": "#/definitions/ExprRef" } ] }, "width": { "anyOf": [ { "description": "Width of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "x": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "x2": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." 
}, "y": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." } }, "type": "object" }, "RegressionTransform": { "additionalProperties": false, "properties": { "as": { "description": "The output field names for the smoothed points generated by the regression transform.\n\n__Default value:__ The field names of the input x and y values.", "items": { "$ref": "#/definitions/FieldName" }, "maxItems": 2, "minItems": 2, "type": "array" }, "extent": { "description": "A [min, max] domain over the independent (x) field for the starting and ending points of the generated trend line.", "items": { "type": "number" }, "maxItems": 2, "minItems": 2, "type": "array" }, "groupby": { "description": "The data fields to group by. If not specified, a single group containing all data objects will be used.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "method": { "description": "The functional form of the regression model. One of `\"linear\"`, `\"log\"`, `\"exp\"`, `\"pow\"`, `\"quad\"`, or `\"poly\"`.\n\n__Default value:__ `\"linear\"`", "enum": [ "linear", "log", "exp", "pow", "quad", "poly" ], "type": "string" }, "on": { "$ref": "#/definitions/FieldName", "description": "The data field of the independent variable to use a predictor." 
}, "order": { "description": "The polynomial order (number of coefficients) for the 'poly' method.\n\n__Default value:__ `3`", "type": "number" }, "params": { "description": "A boolean flag indicating if the transform should return the regression model parameters (one object per group), rather than trend line points. The resulting objects include a `coef` array of fitted coefficient values (starting with the intercept term and then including terms of increasing order) and an `rSquared` value (indicating the total variance explained by the model).\n\n__Default value:__ `false`", "type": "boolean" }, "regression": { "$ref": "#/definitions/FieldName", "description": "The data field of the dependent variable to predict." } }, "required": [ "regression", "on" ], "type": "object" }, "RelativeBandSize": { "additionalProperties": false, "properties": { "band": { "description": "The relative band size. For example `0.5` means half of the band scale's band width.", "type": "number" } }, "required": [ "band" ], "type": "object" }, "RepeatMapping": { "additionalProperties": false, "properties": { "column": { "description": "An array of fields to be repeated horizontally.", "items": { "type": "string" }, "type": "array" }, "row": { "description": "An array of fields to be repeated vertically.", "items": { "type": "string" }, "type": "array" } }, "type": "object" }, "RepeatRef": { "additionalProperties": false, "description": "Reference to a repeated value.", "properties": { "repeat": { "enum": [ "row", "column", "repeat", "layer" ], "type": "string" } }, "required": [ "repeat" ], "type": "object" }, "RepeatSpec": { "anyOf": [ { "$ref": "#/definitions/NonLayerRepeatSpec" }, { "$ref": "#/definitions/LayerRepeatSpec" } ] }, "Resolve": { "additionalProperties": false, "description": "Defines how scales, axes, and legends from different specs should be combined. Resolve is a mapping from `scale`, `axis`, and `legend` to a mapping from channels to resolutions. 
Scales and guides can be resolved to be `\"independent\"` or `\"shared\"`.", "properties": { "axis": { "$ref": "#/definitions/AxisResolveMap" }, "legend": { "$ref": "#/definitions/LegendResolveMap" }, "scale": { "$ref": "#/definitions/ScaleResolveMap" } }, "type": "object" }, "ResolveMode": { "enum": [ "independent", "shared" ], "type": "string" }, "RowCol": { "additionalProperties": false, "properties": { "column": { "$ref": "#/definitions/LayoutAlign" }, "row": { "$ref": "#/definitions/LayoutAlign" } }, "type": "object" }, "RowCol": { "additionalProperties": false, "properties": { "column": { "type": "boolean" }, "row": { "type": "boolean" } }, "type": "object" }, "RowCol": { "additionalProperties": false, "properties": { "column": { "type": "number" }, "row": { "type": "number" } }, "type": "object" }, "RowColumnEncodingFieldDef": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "align": { "$ref": "#/definitions/LayoutAlign", "description": "The alignment to apply to row/column facet's subplot. The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\n__Default value:__ `\"all\"`." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "center": { "description": "Boolean flag indicating if facet's subviews should be centered relative to their respective rows or columns.\n\n__Default value:__ `false`", "type": "boolean" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). 
If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "header": { "anyOf": [ { "$ref": "#/definitions/Header" }, { "type": "null" } ], "description": "An object defining properties of a facet's header." }, "sort": { "anyOf": [ { "$ref": "#/definitions/SortArray" }, { "$ref": "#/definitions/SortOrder" }, { "$ref": "#/definitions/EncodingSortField" }, { "type": "null" } ], "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` is not supported for `row` and `column`." 
}, "spacing": { "description": "The spacing in pixels between facet's sub-views.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)", "type": "number" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types 
(number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." 
} }, "type": "object" }, "SampleTransform": { "additionalProperties": false, "properties": { "sample": { "description": "The maximum number of data objects to include in the sample.\n\n__Default value:__ `1000`", "type": "number" } }, "required": [ "sample" ], "type": "object" }, "Scale": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The alignment of the steps within the scale range.\n\nThis value must lie in the range `[0,1]`. A value of `0.5` indicates that the steps should be centered within the range. A value of `0` or `1` may be used to shift the bands to one side, say to position them adjacent to an axis.\n\n__Default value:__ `0.5`" }, "base": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The logarithm base of the `log` scale (default `10`)." }, "bins": { "$ref": "#/definitions/ScaleBins", "description": "Bin boundaries can be provided to scales as either an explicit array of bin boundaries or as a bin specification object. The legal values are:\n- An [array](../types/#Array) literal of bin boundary values. For example, `[0, 5, 10, 15, 20]`. The array must include both starting and ending boundaries. The previous example uses five values to indicate a total of four bin intervals: [0-5), [5-10), [10-15), [15-20]. Array literals may include signal references as elements.\n- A [bin specification object](https://vega.github.io/vega-lite/docs/scale.html#bins) that indicates the bin _step_ size, and optionally the _start_ and _stop_ boundaries.\n- An array of bin boundaries over the scale domain. If provided, axes and legends will use the bin boundaries to inform the choice of tick marks and text labels." 
}, "clamp": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ], "description": "If `true`, values that exceed the data domain are clamped to either the minimum or maximum range value\n\n__Default value:__ derived from the [scale config](https://vega.github.io/vega-lite/docs/config.html#scale-config)'s `clamp` (`true` by default)." }, "constant": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant determining the slope of the symlog function around zero. Only used for `symlog` scales.\n\n__Default value:__ `1`" }, "domain": { "anyOf": [ { "items": { "anyOf": [ { "type": "null" }, { "type": "string" }, { "type": "number" }, { "type": "boolean" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" } ] }, "type": "array" }, { "const": "unaggregated", "type": "string" }, { "$ref": "#/definitions/ParameterExtent" }, { "$ref": "#/definitions/DomainUnionWith" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Customized domain values in the form of constant values or dynamic values driven by a parameter.\n\n1) Constant `domain` for _quantitative_ fields can take one of the following forms:\n\n- A two-element array with minimum and maximum values. 
To create a diverging scale, this two-element array can be combined with the `domainMid` property.\n- An array with more than two entries, for [Piecewise quantitative scales](https://vega.github.io/vega-lite/docs/scale.html#piecewise).\n- A string value `\"unaggregated\"`, if the input field is aggregated, to indicate that the domain should include the raw data values prior to the aggregation.\n\n2) Constant `domain` for _temporal_ fields can be a two-element array with minimum and maximum values, in the form of either timestamps or the [DateTime definition objects](https://vega.github.io/vega-lite/docs/types.html#datetime).\n\n3) Constant `domain` for _ordinal_ and _nominal_ fields can be an array that lists valid input values.\n\n4) To combine (union) specified constant domain with the field's values, `domain` can be an object with a `unionWith` property that specify constant domain to be combined. For example, `domain: {unionWith: [0, 100]}` for a quantitative scale means that the scale domain always includes `[0, 100]`, but will include other values in the fields beyond `[0, 100]`.\n\n5) Domain can also takes an object defining a field or encoding of a parameter that [interactively determines](https://vega.github.io/vega-lite/docs/selection.html#scale-domains) the scale domain." }, "domainMax": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Sets the maximum value in the scale domain, overriding the `domain` property. This property is only intended for use with scales having continuous domains." }, "domainMid": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Inserts a single mid-point value into a two-element domain. The mid-point value must lie between the domain minimum and maximum values. This property can be useful for setting a midpoint for [diverging color scales](https://vega.github.io/vega-lite/docs/scale.html#piecewise). 
The domainMid property is only intended for use with scales supporting continuous, piecewise domains." }, "domainMin": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Sets the minimum value in the scale domain, overriding the domain property. This property is only intended for use with scales having continuous domains." }, "domainRaw": { "$ref": "#/definitions/ExprRef", "description": "An expression for an array of raw values that, if non-null, directly overrides the _domain_ property. This is useful for supporting interactions such as panning or zooming a scale. The scale may be initially determined using a data-driven domain, then modified in response to user input by setting the rawDomain value." }, "exponent": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The exponent of the `pow` scale." }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/ScaleInterpolateEnum" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/ScaleInterpolateParams" } ], "description": "The interpolation method for range values. By default, a general interpolator for numbers, dates, strings and colors (in HCL space) is used. For color ranges, this property allows interpolation in alternative color spaces. Legal values include `rgb`, `hsl`, `hsl-long`, `lab`, `hcl`, `hcl-long`, `cubehelix` and `cubehelix-long` ('-long' variants use longer paths in polar coordinate spaces). If object-valued, this property accepts an object with a string-valued _type_ property and an optional numeric _gamma_ property applicable to rgb and cubehelix interpolators. 
For more, see the [d3-interpolate documentation](https://github.com/d3/d3-interpolate).\n\n* __Default value:__ `hcl`" }, "nice": { "anyOf": [ { "type": "boolean" }, { "type": "number" }, { "$ref": "#/definitions/TimeInterval" }, { "$ref": "#/definitions/TimeIntervalStep" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Extending the domain so that it starts and ends on nice round values. This method typically modifies the scale’s domain, and may only extend the bounds to the nearest round value. Nicing is useful if the domain is computed from data and may be irregular. For example, for a domain of _[0.201479…, 0.996679…]_, a nice domain might be _[0.2, 1.0]_.\n\nFor quantitative scales such as linear, `nice` can be either a boolean flag or a number. If `nice` is a number, it will represent a desired tick count. This allows greater control over the step size used to extend the bounds, guaranteeing that the returned ticks will exactly cover the domain.\n\nFor temporal fields with time and utc scales, the `nice` value can be a string indicating the desired time interval. Legal values are `\"millisecond\"`, `\"second\"`, `\"minute\"`, `\"hour\"`, `\"day\"`, `\"week\"`, `\"month\"`, and `\"year\"`. Alternatively, `time` and `utc` scales can accept an object-valued interval specifier of the form `{\"interval\": \"month\", \"step\": 3}`, which includes a desired number of interval steps. Here, the domain would snap to quarter (Jan, Apr, Jul, Oct) boundaries.\n\n__Default value:__ `true` for unbinned _quantitative_ fields without explicit domain bounds; `false` otherwise." }, "padding": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For _[continuous](https://vega.github.io/vega-lite/docs/scale.html#continuous)_ scales, expands the scale domain to accommodate the specified number of pixels on each of the scale range. The scale range must represent pixels for this parameter to function as intended. 
Padding adjustment is performed prior to all other adjustments, including the effects of the `zero`, `nice`, `domainMin`, and `domainMax` properties.\n\nFor _[band](https://vega.github.io/vega-lite/docs/scale.html#band)_ scales, shortcut for setting `paddingInner` and `paddingOuter` to the same value.\n\nFor _[point](https://vega.github.io/vega-lite/docs/scale.html#point)_ scales, alias for `paddingOuter`.\n\n__Default value:__ For _continuous_ scales, derived from the [scale config](https://vega.github.io/vega-lite/docs/scale.html#config)'s `continuousPadding`. For _band and point_ scales, see `paddingInner` and `paddingOuter`. By default, Vega-Lite sets padding such that _width/height = number of unique values * step_.", "minimum": 0 }, "paddingInner": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner padding (spacing) within each band step of band scales, as a fraction of the step size. This value must lie in the range [0,1].\n\nFor point scale, this property is invalid as point scales do not have internal band widths (only step sizes between bands).\n\n__Default value:__ derived from the [scale config](https://vega.github.io/vega-lite/docs/scale.html#config)'s `bandPaddingInner`.", "maximum": 1, "minimum": 0 }, "paddingOuter": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer padding (spacing) at the ends of the range of band and point scales, as a fraction of the step size. This value must lie in the range [0,1].\n\n__Default value:__ derived from the [scale config](https://vega.github.io/vega-lite/docs/scale.html#config)'s `bandPaddingOuter` for band scales and `pointPadding` for point scales. 
By default, Vega-Lite sets outer padding such that _width/height = number of unique values * step_.", "maximum": 1, "minimum": 0 }, "range": { "anyOf": [ { "$ref": "#/definitions/RangeEnum" }, { "items": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "type": "array" }, { "$ref": "#/definitions/FieldRange" } ], "description": "The range of the scale. One of:\n\n- A string indicating a [pre-defined named scale range](https://vega.github.io/vega-lite/docs/scale.html#range-config) (e.g., example, `\"symbol\"`, or `\"diverging\"`).\n\n- For [continuous scales](https://vega.github.io/vega-lite/docs/scale.html#continuous), two-element array indicating minimum and maximum values, or an array with more than two entries for specifying a [piecewise scale](https://vega.github.io/vega-lite/docs/scale.html#piecewise).\n\n- For [discrete](https://vega.github.io/vega-lite/docs/scale.html#discrete) and [discretizing](https://vega.github.io/vega-lite/docs/scale.html#discretizing) scales, an array of desired output values or an object with a `field` property representing the range values. For example, if a field `color` contains CSS color names, we can set `range` to `{field: \"color\"}`.\n\n__Notes:__\n\n1) For color scales you can also specify a color [`scheme`](https://vega.github.io/vega-lite/docs/scale.html#scheme) instead of `range`.\n\n2) Any directly specified `range` for `x` and `y` channels will be ignored. Range can be customized via the view's corresponding [size](https://vega.github.io/vega-lite/docs/size.html) (`width` and `height`)." }, "rangeMax": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Sets the maximum value in the scale range, overriding the `range` property or the default range. This property is only intended for use with scales having continuous ranges." 
}, "rangeMin": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Sets the minimum value in the scale range, overriding the `range` property or the default range. This property is only intended for use with scales having continuous ranges." }, "reverse": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ], "description": "If true, reverses the order of the scale range. __Default value:__ `false`." }, "round": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ], "description": "If `true`, rounds numeric output values to integers. This can be helpful for snapping to the pixel grid.\n\n__Default value:__ `false`." }, "scheme": { "anyOf": [ { "$ref": "#/definitions/ColorScheme" }, { "$ref": "#/definitions/SchemeParams" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A string indicating a color [scheme](https://vega.github.io/vega-lite/docs/scale.html#scheme) name (e.g., `\"category10\"` or `\"blues\"`) or a [scheme parameter object](https://vega.github.io/vega-lite/docs/scale.html#scheme-params).\n\nDiscrete color schemes may be used with [discrete](https://vega.github.io/vega-lite/docs/scale.html#discrete) or [discretizing](https://vega.github.io/vega-lite/docs/scale.html#discretizing) scales. Continuous color schemes are intended for use with color scales.\n\nFor the full list of supported schemes, please refer to the [Vega Scheme](https://vega.github.io/vega/docs/schemes/#reference) reference." }, "type": { "$ref": "#/definitions/ScaleType", "description": "The type of scale. 
Vega-Lite supports the following categories of scale types:\n\n1) [**Continuous Scales**](https://vega.github.io/vega-lite/docs/scale.html#continuous) -- mapping continuous domains to continuous output ranges ([`\"linear\"`](https://vega.github.io/vega-lite/docs/scale.html#linear), [`\"pow\"`](https://vega.github.io/vega-lite/docs/scale.html#pow), [`\"sqrt\"`](https://vega.github.io/vega-lite/docs/scale.html#sqrt), [`\"symlog\"`](https://vega.github.io/vega-lite/docs/scale.html#symlog), [`\"log\"`](https://vega.github.io/vega-lite/docs/scale.html#log), [`\"time\"`](https://vega.github.io/vega-lite/docs/scale.html#time), [`\"utc\"`](https://vega.github.io/vega-lite/docs/scale.html#utc).\n\n2) [**Discrete Scales**](https://vega.github.io/vega-lite/docs/scale.html#discrete) -- mapping discrete domains to discrete ([`\"ordinal\"`](https://vega.github.io/vega-lite/docs/scale.html#ordinal)) or continuous ([`\"band\"`](https://vega.github.io/vega-lite/docs/scale.html#band) and [`\"point\"`](https://vega.github.io/vega-lite/docs/scale.html#point)) output ranges.\n\n3) [**Discretizing Scales**](https://vega.github.io/vega-lite/docs/scale.html#discretizing) -- mapping continuous domains to discrete output ranges [`\"bin-ordinal\"`](https://vega.github.io/vega-lite/docs/scale.html#bin-ordinal), [`\"quantile\"`](https://vega.github.io/vega-lite/docs/scale.html#quantile), [`\"quantize\"`](https://vega.github.io/vega-lite/docs/scale.html#quantize) and [`\"threshold\"`](https://vega.github.io/vega-lite/docs/scale.html#threshold).\n\n__Default value:__ please see the [scale type table](https://vega.github.io/vega-lite/docs/scale.html#type)." 
}, "zero": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ], "description": "If `true`, ensures that a zero baseline value is included in the scale domain.\n\n__Default value:__ `true` for x and y channels if the quantitative field is not binned and no custom `domain` is provided; `false` otherwise.\n\n__Note:__ Log, time, and utc scales do not support `zero`." } }, "type": "object" }, "ScaleBinParams": { "additionalProperties": false, "properties": { "start": { "description": "The starting (lowest-valued) bin boundary.\n\n__Default value:__ The lowest value of the scale domain will be used.", "type": "number" }, "step": { "description": "The step size defining the bin interval width.", "type": "number" }, "stop": { "description": "The stopping (highest-valued) bin boundary.\n\n__Default value:__ The highest value of the scale domain will be used.", "type": "number" } }, "required": [ "step" ], "type": "object" }, "ScaleBins": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ScaleBinParams" } ] }, "ScaleConfig": { "additionalProperties": false, "properties": { "bandPaddingInner": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default inner padding for `x` and `y` band scales.\n\n__Default value:__\n- `nestedOffsetPaddingInner` for x/y scales with nested x/y offset scales.\n- `barBandPaddingInner` for bar marks (`0.1` by default)\n- `rectBandPaddingInner` for rect and other marks (`0` by default)", "maximum": 1, "minimum": 0 }, "bandPaddingOuter": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default outer padding for `x` and `y` band scales.\n\n__Default value:__ `paddingInner/2` (which makes _width/height = number of unique values * step_)", "maximum": 1, "minimum": 0 }, "bandWithNestedOffsetPaddingInner": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default inner 
padding for `x` and `y` band scales with nested `xOffset` and `yOffset` encoding.\n\n__Default value:__ `0.2`", "maximum": 1, "minimum": 0 }, "bandWithNestedOffsetPaddingOuter": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default outer padding for `x` and `y` band scales with nested `xOffset` and `yOffset` encoding.\n\n__Default value:__ `0.2`", "maximum": 1, "minimum": 0 }, "barBandPaddingInner": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default inner padding for `x` and `y` band-ordinal scales of `\"bar\"` marks.\n\n__Default value:__ `0.1`", "maximum": 1, "minimum": 0 }, "clamp": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ], "description": "If true, values that exceed the data domain are clamped to either the minimum or maximum range value" }, "continuousPadding": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default padding for continuous x/y scales.\n\n__Default:__ The bar width for continuous x-scale of a vertical bar and continuous y-scale of a horizontal bar.; `0` otherwise.", "minimum": 0 }, "maxBandSize": { "description": "The default max value for mapping quantitative fields to bar's size/bandSize.\n\nIf undefined (default), we will use the axis's size (width or height) - 1.", "minimum": 0, "type": "number" }, "maxFontSize": { "description": "The default max value for mapping quantitative fields to text's size/fontSize.\n\n__Default value:__ `40`", "minimum": 0, "type": "number" }, "maxOpacity": { "description": "Default max opacity for mapping a field to opacity.\n\n__Default value:__ `0.8`", "maximum": 1, "minimum": 0, "type": "number" }, "maxSize": { "description": "Default max value for point size scale.", "minimum": 0, "type": "number" }, "maxStrokeWidth": { "description": "Default max strokeWidth for the scale of strokeWidth for rule and line marks and of size for trail 
marks.\n\n__Default value:__ `4`", "minimum": 0, "type": "number" }, "minBandSize": { "description": "The default min value for mapping quantitative fields to bar and tick's size/bandSize scale with zero=false.\n\n__Default value:__ `2`", "minimum": 0, "type": "number" }, "minFontSize": { "description": "The default min value for mapping quantitative fields to tick's size/fontSize scale with zero=false\n\n__Default value:__ `8`", "minimum": 0, "type": "number" }, "minOpacity": { "description": "Default minimum opacity for mapping a field to opacity.\n\n__Default value:__ `0.3`", "maximum": 1, "minimum": 0, "type": "number" }, "minSize": { "description": "Default minimum value for point size scale with zero=false.\n\n__Default value:__ `9`", "minimum": 0, "type": "number" }, "minStrokeWidth": { "description": "Default minimum strokeWidth for the scale of strokeWidth for rule and line marks and of size for trail marks with zero=false.\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, "offsetBandPaddingInner": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default padding inner for xOffset/yOffset's band scales.\n\n__Default Value:__ `0`" }, "offsetBandPaddingOuter": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default padding outer for xOffset/yOffset's band scales.\n\n__Default Value:__ `0`" }, "pointPadding": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default outer padding for `x` and `y` point-ordinal scales.\n\n__Default value:__ `0.5` (which makes _width/height = number of unique values * step_)", "maximum": 1, "minimum": 0 }, "quantileCount": { "description": "Default range cardinality for [`quantile`](https://vega.github.io/vega-lite/docs/scale.html#quantile) scale.\n\n__Default value:__ `4`", "minimum": 0, "type": "number" }, "quantizeCount": { "description": "Default range cardinality for 
[`quantize`](https://vega.github.io/vega-lite/docs/scale.html#quantize) scale.\n\n__Default value:__ `4`", "minimum": 0, "type": "number" }, "rectBandPaddingInner": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default inner padding for `x` and `y` band-ordinal scales of `\"rect\"` marks.\n\n__Default value:__ `0`", "maximum": 1, "minimum": 0 }, "round": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ], "description": "If true, rounds numeric output values to integers. This can be helpful for snapping to the pixel grid. (Only available for `x`, `y`, and `size` scales.)" }, "useUnaggregatedDomain": { "description": "Use the source data range before aggregation as scale domain instead of aggregated data for aggregate axis.\n\nThis is equivalent to setting `domain` to `\"unaggregate\"` for aggregated _quantitative_ fields by default.\n\nThis property only works with aggregate functions that produce values within the raw data domain (`\"mean\"`, `\"average\"`, `\"median\"`, `\"q1\"`, `\"q3\"`, `\"min\"`, `\"max\"`). For other aggregations that produce values outside of the raw data domain (e.g. `\"count\"`, `\"sum\"`), this property is ignored.\n\n__Default value:__ `false`", "type": "boolean" }, "xReverse": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Reverse x-scale by default (useful for right-to-left charts)." }, "zero": { "description": "Default `scale.zero` for [`continuous`](https://vega.github.io/vega-lite/docs/scale.html#continuous) scales except for (1) x/y-scales of non-ranged bar or area charts and (2) size scales.\n\n__Default value:__ `true`", "type": "boolean" } }, "type": "object" }, "ScaleDatumDef": { "additionalProperties": false, "properties": { "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "ScaleFieldDef": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. 
To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." 
} }, "type": "object" }, "ScaleInterpolateEnum": { "enum": [ "rgb", "lab", "hcl", "hsl", "hsl-long", "hcl-long", "cubehelix", "cubehelix-long" ], "type": "string" }, "ScaleInterpolateParams": { "additionalProperties": false, "properties": { "gamma": { "type": "number" }, "type": { "enum": [ "rgb", "cubehelix", "cubehelix-long" ], "type": "string" } }, "required": [ "type" ], "type": "object" }, "ScaleResolveMap": { "additionalProperties": false, "properties": { "angle": { "$ref": "#/definitions/ResolveMode" }, "color": { "$ref": "#/definitions/ResolveMode" }, "fill": { "$ref": "#/definitions/ResolveMode" }, "fillOpacity": { "$ref": "#/definitions/ResolveMode" }, "opacity": { "$ref": "#/definitions/ResolveMode" }, "radius": { "$ref": "#/definitions/ResolveMode" }, "shape": { "$ref": "#/definitions/ResolveMode" }, "size": { "$ref": "#/definitions/ResolveMode" }, "stroke": { "$ref": "#/definitions/ResolveMode" }, "strokeDash": { "$ref": "#/definitions/ResolveMode" }, "strokeOpacity": { "$ref": "#/definitions/ResolveMode" }, "strokeWidth": { "$ref": "#/definitions/ResolveMode" }, "theta": { "$ref": "#/definitions/ResolveMode" }, "x": { "$ref": "#/definitions/ResolveMode" }, "xOffset": { "$ref": "#/definitions/ResolveMode" }, "y": { "$ref": "#/definitions/ResolveMode" }, "yOffset": { "$ref": "#/definitions/ResolveMode" } }, "type": "object" }, "ScaleType": { "enum": [ "linear", "log", "pow", "sqrt", "symlog", "identity", "sequential", "time", "utc", "quantile", "quantize", "threshold", "bin-ordinal", "ordinal", "point", "band" ], "type": "string" }, "SchemeParams": { "additionalProperties": false, "properties": { "count": { "description": "The number of colors to use in the scheme. This can be useful for scale types such as `\"quantize\"`, which use the length of the scale range to determine the number of discrete bins for the scale domain.", "type": "number" }, "extent": { "description": "The extent of the color range to use. 
For example `[0.2, 1]` will rescale the color scheme such that color values in the range _[0, 0.2)_ are excluded from the scheme.", "items": { "type": "number" }, "type": "array" }, "name": { "$ref": "#/definitions/ColorScheme", "description": "A color scheme name for ordinal scales (e.g., `\"category10\"` or `\"blues\"`).\n\nFor the full list of supported schemes, please refer to the [Vega Scheme](https://vega.github.io/vega/docs/schemes/#reference) reference." } }, "required": [ "name" ], "type": "object" }, "SecondaryFieldDef": { "additionalProperties": false, "description": "A field definition of a secondary channel that shares a scale with another primary channel. For example, `x2`, `xError` and `xError2` share the same scale with `x`.", "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). 
The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). 
If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." } }, "type": "object" }, "SelectionConfig": { "additionalProperties": false, "properties": { "interval": { "$ref": "#/definitions/IntervalSelectionConfigWithoutType", "description": "The default definition for an [`interval`](https://vega.github.io/vega-lite/docs/parameter.html#select) selection. All properties and transformations for an interval selection definition (except `type`) may be specified here.\n\nFor instance, setting `interval` to `{\"translate\": false}` disables the ability to move interval selections by default." }, "point": { "$ref": "#/definitions/PointSelectionConfigWithoutType", "description": "The default definition for a [`point`](https://vega.github.io/vega-lite/docs/parameter.html#select) selection. All properties and transformations for a point selection definition (except `type`) may be specified here.\n\nFor instance, setting `point` to `{\"on\": \"dblclick\"}` populates point selections on double-click by default." 
} }, "type": "object" }, "SelectionInit": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" } ] }, "SelectionInitInterval": { "anyOf": [ { "$ref": "#/definitions/Vector2" }, { "$ref": "#/definitions/Vector2" }, { "$ref": "#/definitions/Vector2" }, { "$ref": "#/definitions/Vector2" } ] }, "SelectionInitIntervalMapping": { "$ref": "#/definitions/Dict" }, "SelectionInitMapping": { "$ref": "#/definitions/Dict" }, "SelectionParameter": { "additionalProperties": false, "properties": { "bind": { "anyOf": [ { "$ref": "#/definitions/Binding" }, { "additionalProperties": { "$ref": "#/definitions/Binding" }, "type": "object" }, { "$ref": "#/definitions/LegendBinding" }, { "const": "scales", "type": "string" } ], "description": "When set, a selection is populated by input elements (also known as dynamic query widgets) or by interacting with the corresponding legend. Direct manipulation interaction is disabled by default; to re-enable it, set the selection's [`on`](https://vega.github.io/vega-lite/docs/selection.html#common-selection-properties) property.\n\nLegend bindings are restricted to selections that only specify a single field or encoding.\n\nQuery widget binding takes the form of Vega's [input element binding definition](https://vega.github.io/vega/docs/signals/#bind) or can be a mapping between projected field/encodings and binding definitions.\n\n__See also:__ [`bind`](https://vega.github.io/vega-lite/docs/bind.html) documentation." }, "name": { "$ref": "#/definitions/ParameterName", "description": "Required. A unique name for the selection parameter. Selection names should be valid JavaScript identifiers: they should contain only alphanumeric characters (or \"$\", or \"_\") and may not start with a digit. Reserved keywords that may not be used as parameter names are \"datum\", \"event\", \"item\", and \"parent\"." 
}, "select": { "anyOf": [ { "$ref": "#/definitions/SelectionType" }, { "$ref": "#/definitions/PointSelectionConfig" }, { "$ref": "#/definitions/IntervalSelectionConfig" } ], "description": "Determines the default event processing and data query for the selection. Vega-Lite currently supports two selection types:\n\n- `\"point\"` -- to select multiple discrete data values; the first value is selected on `click` and additional values toggled on shift-click.\n- `\"interval\"` -- to select a continuous range of data values on `drag`." }, "value": { "anyOf": [ { "$ref": "#/definitions/SelectionInit" }, { "items": { "$ref": "#/definitions/SelectionInitMapping" }, "type": "array" }, { "$ref": "#/definitions/SelectionInitIntervalMapping" } ], "description": "Initialize the selection with a mapping between [projected channels or field names](https://vega.github.io/vega-lite/docs/selection.html#project) and initial values.\n\n__See also:__ [`init`](https://vega.github.io/vega-lite/docs/value.html) documentation." } }, "required": [ "name", "select" ], "type": "object" }, "SelectionResolution": { "enum": [ "global", "union", "intersect" ], "type": "string" }, "SelectionType": { "enum": [ "point", "interval" ], "type": "string" }, "SequenceGenerator": { "additionalProperties": false, "properties": { "name": { "description": "Provide a placeholder name and bind data at runtime.", "type": "string" }, "sequence": { "$ref": "#/definitions/SequenceParams", "description": "Generate a sequence of numbers." 
} }, "required": [ "sequence" ], "type": "object" }, "SequenceParams": { "additionalProperties": false, "properties": { "as": { "$ref": "#/definitions/FieldName", "description": "The name of the generated sequence field.\n\n__Default value:__ `\"data\"`" }, "start": { "description": "The starting value of the sequence (inclusive).", "type": "number" }, "step": { "description": "The step value between sequence entries.\n\n__Default value:__ `1`", "type": "number" }, "stop": { "description": "The ending value of the sequence (exclusive).", "type": "number" } }, "required": [ "start", "stop" ], "type": "object" }, "SequentialMultiHue": { "enum": [ "turbo", "viridis", "inferno", "magma", "plasma", "cividis", "bluegreen", "bluegreen-3", "bluegreen-4", "bluegreen-5", "bluegreen-6", "bluegreen-7", "bluegreen-8", "bluegreen-9", "bluepurple", "bluepurple-3", "bluepurple-4", "bluepurple-5", "bluepurple-6", "bluepurple-7", "bluepurple-8", "bluepurple-9", "goldgreen", "goldgreen-3", "goldgreen-4", "goldgreen-5", "goldgreen-6", "goldgreen-7", "goldgreen-8", "goldgreen-9", "goldorange", "goldorange-3", "goldorange-4", "goldorange-5", "goldorange-6", "goldorange-7", "goldorange-8", "goldorange-9", "goldred", "goldred-3", "goldred-4", "goldred-5", "goldred-6", "goldred-7", "goldred-8", "goldred-9", "greenblue", "greenblue-3", "greenblue-4", "greenblue-5", "greenblue-6", "greenblue-7", "greenblue-8", "greenblue-9", "orangered", "orangered-3", "orangered-4", "orangered-5", "orangered-6", "orangered-7", "orangered-8", "orangered-9", "purplebluegreen", "purplebluegreen-3", "purplebluegreen-4", "purplebluegreen-5", "purplebluegreen-6", "purplebluegreen-7", "purplebluegreen-8", "purplebluegreen-9", "purpleblue", "purpleblue-3", "purpleblue-4", "purpleblue-5", "purpleblue-6", "purpleblue-7", "purpleblue-8", "purpleblue-9", "purplered", "purplered-3", "purplered-4", "purplered-5", "purplered-6", "purplered-7", "purplered-8", "purplered-9", "redpurple", "redpurple-3", "redpurple-4", 
"redpurple-5", "redpurple-6", "redpurple-7", "redpurple-8", "redpurple-9", "yellowgreenblue", "yellowgreenblue-3", "yellowgreenblue-4", "yellowgreenblue-5", "yellowgreenblue-6", "yellowgreenblue-7", "yellowgreenblue-8", "yellowgreenblue-9", "yellowgreen", "yellowgreen-3", "yellowgreen-4", "yellowgreen-5", "yellowgreen-6", "yellowgreen-7", "yellowgreen-8", "yellowgreen-9", "yelloworangebrown", "yelloworangebrown-3", "yelloworangebrown-4", "yelloworangebrown-5", "yelloworangebrown-6", "yelloworangebrown-7", "yelloworangebrown-8", "yelloworangebrown-9", "yelloworangered", "yelloworangered-3", "yelloworangered-4", "yelloworangered-5", "yelloworangered-6", "yelloworangered-7", "yelloworangered-8", "yelloworangered-9", "darkblue", "darkblue-3", "darkblue-4", "darkblue-5", "darkblue-6", "darkblue-7", "darkblue-8", "darkblue-9", "darkgold", "darkgold-3", "darkgold-4", "darkgold-5", "darkgold-6", "darkgold-7", "darkgold-8", "darkgold-9", "darkgreen", "darkgreen-3", "darkgreen-4", "darkgreen-5", "darkgreen-6", "darkgreen-7", "darkgreen-8", "darkgreen-9", "darkmulti", "darkmulti-3", "darkmulti-4", "darkmulti-5", "darkmulti-6", "darkmulti-7", "darkmulti-8", "darkmulti-9", "darkred", "darkred-3", "darkred-4", "darkred-5", "darkred-6", "darkred-7", "darkred-8", "darkred-9", "lightgreyred", "lightgreyred-3", "lightgreyred-4", "lightgreyred-5", "lightgreyred-6", "lightgreyred-7", "lightgreyred-8", "lightgreyred-9", "lightgreyteal", "lightgreyteal-3", "lightgreyteal-4", "lightgreyteal-5", "lightgreyteal-6", "lightgreyteal-7", "lightgreyteal-8", "lightgreyteal-9", "lightmulti", "lightmulti-3", "lightmulti-4", "lightmulti-5", "lightmulti-6", "lightmulti-7", "lightmulti-8", "lightmulti-9", "lightorange", "lightorange-3", "lightorange-4", "lightorange-5", "lightorange-6", "lightorange-7", "lightorange-8", "lightorange-9", "lighttealblue", "lighttealblue-3", "lighttealblue-4", "lighttealblue-5", "lighttealblue-6", "lighttealblue-7", "lighttealblue-8", "lighttealblue-9" ], "type": 
"string" }, "SequentialSingleHue": { "enum": [ "blues", "tealblues", "teals", "greens", "browns", "greys", "purples", "warmgreys", "reds", "oranges" ], "type": "string" }, "ShapeDef": { "$ref": "#/definitions/MarkPropDef<(string|null),TypeForShape>" }, "SharedEncoding": { "additionalProperties": false, "properties": { "angle": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. 
To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "color": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "$ref": "#/definitions/Gradient" }, { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "description": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(string|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." 
}, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. 
or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "detail": { "anyOf": [ { "$ref": "#/definitions/FieldDefWithoutScale" }, { "items": { "$ref": "#/definitions/FieldDefWithoutScale" }, "type": "array" } ], "description": "Additional levels of detail for grouping data in aggregate views and in line, trail, and area marks without mapping data to a specific visual channel." 
}, "fill": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." 
}, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "$ref": "#/definitions/Gradient" }, { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "fillOpacity": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "href": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(string|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." 
}, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. 
or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "key": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). 
If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types 
(number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "latitude": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "const": "quantitative", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a 
[custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation.", "type": "string" }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data 
domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." 
} ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). 
The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "latitude2": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "longitude": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. 
You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. 
or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "const": "quantitative", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation.", "type": "string" }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "longitude2": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). 
If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, 
string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "type": "object" }, "opacity": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." 
}, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "type": "object" }, "order": { "anyOf": [ { "$ref": "#/definitions/OrderFieldDef" }, { "items": { "$ref": "#/definitions/OrderFieldDef" }, "type": "array" }, { "$ref": "#/definitions/OrderValueDef" }, { "$ref": "#/definitions/OrderOnlyDef" } ], "description": "Order of the marks.\n- For stacked marks, this `order` channel encodes [stack order](https://vega.github.io/vega-lite/docs/stack.html#order).\n- For line and trail marks, this `order` channel encodes order of data points in the lines. This can be useful for creating [a connected scatterplot](https://vega.github.io/vega-lite/examples/connected_scatterplot.html). Setting `order` to `{\"value\": null}` makes the line marks use the original order in the data sources.\n- Otherwise, this `order` channel encodes layer order of the marks.\n\n__Note__: In aggregate plots, `order` field should be `aggregate`d to avoid creating additional aggregation grouping." }, "radius": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "stack": { "anyOf": [ { "$ref": "#/definitions/StackOffset" }, { "type": "null" }, { "type": "boolean" } ], "description": "Type of stacking offset if the field should be stacked. 
`stack` is only applicable for `x`, `y`, `theta`, and `radius` channels with continuous domains. For example, `stack` of `y` can be used to customize stacking for a vertical bar chart.\n\n`stack` can be one of the following values:\n- `\"zero\"` or `true`: stacking with baseline offset at zero value of the scale (for creating typical stacked [bar](https://vega.github.io/vega-lite/docs/stack.html#bar) and [area](https://vega.github.io/vega-lite/docs/stack.html#area) chart).\n- `\"normalize\"` - stacking with normalized domain (for creating [normalized stacked bar and area charts](https://vega.github.io/vega-lite/docs/stack.html#normalized) and pie charts [with percentage tooltip](https://vega.github.io/vega-lite/docs/arc.html#tooltip)).
\n-`\"center\"` - stacking with center baseline (for [streamgraph](https://vega.github.io/vega-lite/docs/stack.html#streamgraph)).\n- `null` or `false` - No-stacking. This will produce layered [bar](https://vega.github.io/vega-lite/docs/stack.html#layered-bar-chart) and area chart.\n\n__Default value:__ `zero` for plots with all of the following conditions are true: (1) the mark is `bar`, `area`, or `arc`; (2) the stacked measure channel (x or y) has a linear scale; (3) At least one of non-position channels mapped to an unaggregated field that is different from x and y. Otherwise, `null` by default.\n\n__See also:__ [`stack`](https://vega.github.io/vega-lite/docs/stack.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field 
contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain 
(`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "type": "object" }, "radius2": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "shape": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/TypeForShape", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "size": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "stroke": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "$ref": "#/definitions/Gradient" }, { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "strokeDash": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "strokeOpacity": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "strokeWidth": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "legend": { "anyOf": [ { "$ref": "#/definitions/Legend" }, { "type": "null" } ], "description": "An object defining properties of the legend. 
If `null`, the legend for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [legend properties](https://vega.github.io/vega-lite/docs/legend.html) are applied.\n\n__See also:__ [`legend`](https://vega.github.io/vega-lite/docs/legend.html) documentation." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). 
For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). 
If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "text": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalStringFieldDef" }, { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." 
}, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. 
or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "theta": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "stack": { "anyOf": [ { "$ref": "#/definitions/StackOffset" }, { "type": "null" }, { "type": "boolean" } ], "description": "Type of stacking offset if the field should be stacked. 
`stack` is only applicable for `x`, `y`, `theta`, and `radius` channels with continuous domains. For example, `stack` of `y` can be used to customize stacking for a vertical bar chart.\n\n`stack` can be one of the following values:\n- `\"zero\"` or `true`: stacking with baseline offset at zero value of the scale (for creating typical stacked [bar](https://vega.github.io/vega-lite/docs/stack.html#bar) and [area](https://vega.github.io/vega-lite/docs/stack.html#area) chart).\n- `\"normalize\"` - stacking with normalized domain (for creating [normalized stacked bar and area charts](https://vega.github.io/vega-lite/docs/stack.html#normalized) and pie charts [with percentage tooltip](https://vega.github.io/vega-lite/docs/arc.html#tooltip)).
\n-`\"center\"` - stacking with center baseline (for [streamgraph](https://vega.github.io/vega-lite/docs/stack.html#streamgraph)).\n- `null` or `false` - No-stacking. This will produce layered [bar](https://vega.github.io/vega-lite/docs/stack.html#layered-bar-chart) and area chart.\n\n__Default value:__ `zero` for plots with all of the following conditions are true: (1) the mark is `bar`, `area`, or `arc`; (2) the stacked measure channel (x or y) has a linear scale; (3) At least one of non-position channels mapped to an unaggregated field that is different from x and y. Otherwise, `null` by default.\n\n__See also:__ [`stack`](https://vega.github.io/vega-lite/docs/stack.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field 
contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain 
(`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "type": "object" }, "theta2": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "tooltip": { "anyOf": [ { "$ref": "#/definitions/StringFieldDefWithCondition" }, { "$ref": "#/definitions/StringValueDefWithCondition" }, { "items": { "$ref": "#/definitions/StringFieldDef" }, "type": "array" }, { "type": "null" } ], "description": "The tooltip text to show upon mouse hover. Specifying `tooltip` encoding overrides [the `tooltip` property in the mark definition](https://vega.github.io/vega-lite/docs/mark.html#mark-def).\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite." }, "url": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "condition": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/ConditionalValueDef<(string|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|ExprRef)>" }, "type": "array" } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." 
}, { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." } ], "description": "One or more value definition(s) with [a parameter or a test predicate](https://vega.github.io/vega-lite/docs/condition.html).\n\n__Note:__ A field definition's `condition` property can only contain [conditional value definitions](https://vega.github.io/vega-lite/docs/condition.html#value) since Vega-Lite only allows at most one encoded field per encoding channel." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." 
}, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. 
or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "x": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "axis": { "anyOf": [ { "$ref": "#/definitions/Axis" }, { "type": "null" } ], "description": "An object defining properties of axis's gridlines, ticks and labels. If `null`, the axis for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [axis properties](https://vega.github.io/vega-lite/docs/axis.html) are applied.\n\n__See also:__ [`axis`](https://vega.github.io/vega-lite/docs/axis.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." 
}, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "impute": { "anyOf": [ { "$ref": "#/definitions/ImputeParams" }, { "type": "null" } ], "description": "An object defining the properties of the Impute Operation to be applied. The field value of the other positional channel is taken as `key` of the `Impute` Operation. The field of the `color` channel if specified is used as `groupby` of the `Impute` Operation.\n\n__See also:__ [`impute`](https://vega.github.io/vega-lite/docs/impute.html) documentation." 
}, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. 
For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "stack": { "anyOf": [ { "$ref": "#/definitions/StackOffset" }, { "type": "null" }, { "type": "boolean" } ], "description": "Type of stacking offset if the field should be stacked. `stack` is only applicable for `x`, `y`, `theta`, and `radius` channels with continuous domains. For example, `stack` of `y` can be used to customize stacking for a vertical bar chart.\n\n`stack` can be one of the following values:\n- `\"zero\"` or `true`: stacking with baseline offset at zero value of the scale (for creating typical stacked [bar](https://vega.github.io/vega-lite/docs/stack.html#bar) and [area](https://vega.github.io/vega-lite/docs/stack.html#area) chart).\n- `\"normalize\"` - stacking with normalized domain (for creating [normalized stacked bar and area charts](https://vega.github.io/vega-lite/docs/stack.html#normalized) and pie charts [with percentage tooltip](https://vega.github.io/vega-lite/docs/arc.html#tooltip)).
\n-`\"center\"` - stacking with center baseline (for [streamgraph](https://vega.github.io/vega-lite/docs/stack.html#streamgraph)).\n- `null` or `false` - No-stacking. This will produce layered [bar](https://vega.github.io/vega-lite/docs/stack.html#layered-bar-chart) and area chart.\n\n__Default value:__ `zero` for plots with all of the following conditions are true: (1) the mark is `bar`, `area`, or `arc`; (2) the stacked measure channel (x or y) has a linear scale; (3) At least one of non-position channels mapped to an unaggregated field that is different from x and y. Otherwise, `null` by default.\n\n__See also:__ [`stack`](https://vega.github.io/vega-lite/docs/stack.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field 
contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain 
(`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "type": "object" }, "x2": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "xError": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. 
You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. 
If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "type": "object" }, "xError2": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "type": "object" }, "xOffset": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." 
}, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "type": "object" }, "y": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." 
}, "axis": { "anyOf": [ { "$ref": "#/definitions/Axis" }, { "type": "null" } ], "description": "An object defining properties of axis's gridlines, ticks and labels. If `null`, the axis for the encoding channel will be removed.\n\n__Default value:__ If undefined, default [axis properties](https://vega.github.io/vega-lite/docs/axis.html) are applied.\n\n__See also:__ [`axis`](https://vega.github.io/vega-lite/docs/axis.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." 
}, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "impute": { "anyOf": [ { "$ref": "#/definitions/ImputeParams" }, { "type": "null" } ], "description": "An object defining the properties of the Impute Operation to be applied. The field value of the other positional channel is taken as `key` of the `Impute` Operation. The field of the `color` channel if specified is used as `groupby` of the `Impute` Operation.\n\n__See also:__ [`impute`](https://vega.github.io/vega-lite/docs/impute.html) documentation." 
}, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." }, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. 
For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." }, "stack": { "anyOf": [ { "$ref": "#/definitions/StackOffset" }, { "type": "null" }, { "type": "boolean" } ], "description": "Type of stacking offset if the field should be stacked. `stack` is only applicable for `x`, `y`, `theta`, and `radius` channels with continuous domains. For example, `stack` of `y` can be used to customize stacking for a vertical bar chart.\n\n`stack` can be one of the following values:\n- `\"zero\"` or `true`: stacking with baseline offset at zero value of the scale (for creating typical stacked [bar](https://vega.github.io/vega-lite/docs/stack.html#bar) and [area](https://vega.github.io/vega-lite/docs/stack.html#area) chart).\n- `\"normalize\"` - stacking with normalized domain (for creating [normalized stacked bar and area charts](https://vega.github.io/vega-lite/docs/stack.html#normalized) and pie charts [with percentage tooltip](https://vega.github.io/vega-lite/docs/arc.html#tooltip)).
\n-`\"center\"` - stacking with center baseline (for [streamgraph](https://vega.github.io/vega-lite/docs/stack.html#streamgraph)).\n- `null` or `false` - No-stacking. This will produce layered [bar](https://vega.github.io/vega-lite/docs/stack.html#layered-bar-chart) and area chart.\n\n__Default value:__ `zero` for plots with all of the following conditions are true: (1) the mark is `bar`, `area`, or `arc`; (2) the stacked measure channel (x or y) has a linear scale; (3) At least one of non-position channels mapped to an unaggregated field that is different from x and y. Otherwise, `null` by default.\n\n__See also:__ [`stack`](https://vega.github.io/vega-lite/docs/stack.html) documentation." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field 
contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain 
(`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "type": "object" }, "y2": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." 
}, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). 
Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom 
`sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. 
The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "yError": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. 
You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. 
If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "type": "object" }, "yError2": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation.", "type": "null" }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." 
}, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "type": "object" }, "yOffset": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." 
}, "datum": { "anyOf": [ { "$ref": "#/definitions/PrimitiveValue" }, { "$ref": "#/definitions/DateTime" }, { "$ref": "#/definitions/ExprRef" }, { "$ref": "#/definitions/RepeatRef" } ], "description": "A constant value in data domain." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "scale": { "anyOf": [ { "$ref": "#/definitions/Scale" }, { "type": "null" } ], "description": "An object defining properties of the channel's scale, which is the function that transforms values in the data domain (numbers, dates, strings, etc) to visual values (pixels, colors, sizes) of the encoding channels.\n\nIf `null`, the scale will be [disabled and the data value will be directly encoded](https://vega.github.io/vega-lite/docs/scale.html#disable).\n\n__Default value:__ If undefined, default [scale properties](https://vega.github.io/vega-lite/docs/scale.html) are applied.\n\n__See also:__ [`scale`](https://vega.github.io/vega-lite/docs/scale.html) documentation." 
}, "sort": { "$ref": "#/definitions/Sort", "description": "Sort order for the encoded field.\n\nFor continuous fields (quantitative or temporal), `sort` can be either `\"ascending\"` or `\"descending\"`.\n\nFor discrete fields, `sort` can be one of the following:\n- `\"ascending\"` or `\"descending\"` -- for sorting by the values' natural order in JavaScript.\n- [A string indicating an encoding channel name to sort by](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding) (e.g., `\"x\"` or `\"y\"`) with an optional minus prefix for descending sort (e.g., `\"-x\"` to sort by x-field, descending). This channel string is short-form of [a sort-by-encoding definition](https://vega.github.io/vega-lite/docs/sort.html#sort-by-encoding). For example, `\"sort\": \"-x\"` is equivalent to `\"sort\": {\"encoding\": \"x\", \"order\": \"descending\"}`.\n- [A sort field definition](https://vega.github.io/vega-lite/docs/sort.html#sort-field) for sorting by another field.\n- [An array specifying the field values in preferred order](https://vega.github.io/vega-lite/docs/sort.html#sort-array). In this case, the sort order will obey the values in the array, followed by any unspecified values in their original order. For discrete time field, values in the sort array can be [date-time definition objects](types#datetime). In addition, for time units `\"month\"` and `\"day\"`, the values can be the month or day names (case insensitive) or their 3-letter initials (e.g., `\"Mon\"`, `\"Tue\"`).\n- `null` indicating no sort.\n\n__Default value:__ `\"ascending\"`\n\n__Note:__ `null` and sorting by another channel is not supported for `row` and `column`.\n\n__See also:__ [`sort`](https://vega.github.io/vega-lite/docs/sort.html) documentation." 
}, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "anyOf": [ { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." }, { "$ref": "#/definitions/Type", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } ], "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." 
}, "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "type": "object" } }, "type": "object" }, "SingleDefUnitChannel": { "enum": [ "x", "y", "xOffset", "yOffset", "x2", "y2", "longitude", "latitude", "longitude2", "latitude2", "theta", "theta2", "radius", "radius2", "color", "fill", "stroke", "opacity", "fillOpacity", "strokeOpacity", "strokeWidth", "strokeDash", "size", "angle", "shape", "key", "text", "href", "url", "description" ], "type": "string" }, "SingleTimeUnit": { "anyOf": [ { "$ref": "#/definitions/LocalSingleTimeUnit" }, { "$ref": "#/definitions/UtcSingleTimeUnit" } ] }, "Sort": { "anyOf": [ { "$ref": "#/definitions/SortArray" }, { "$ref": "#/definitions/AllSortString" }, { "$ref": "#/definitions/EncodingSortField" }, { "$ref": "#/definitions/SortByEncoding" }, { "type": "null" } ] }, "SortArray": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "items": { "type": "string" }, "type": "array" }, { "items": { "type": "boolean" }, "type": "array" }, { "items": { "$ref": "#/definitions/DateTime" }, "type": "array" } ] }, "SortByChannel": { "enum": [ "x", "y", "color", "fill", "stroke", "strokeWidth", "size", "shape", "fillOpacity", "strokeOpacity", "opacity", "text" ], "type": "string" }, "SortByChannelDesc": { "enum": [ "-x", "-y", "-color", "-fill", "-stroke", "-strokeWidth", "-size", "-shape", "-fillOpacity", "-strokeOpacity", "-opacity", "-text" ], "type": "string" }, "SortByEncoding": { "additionalProperties": false, "properties": { "encoding": { "$ref": "#/definitions/SortByChannel", "description": "The [encoding channel](https://vega.github.io/vega-lite/docs/encoding.html#channels) to sort by (e.g., `\"x\"`, `\"y\"`)" }, "order": { "anyOf": [ { "$ref": "#/definitions/SortOrder" }, { "type": "null" } ], "description": "The sort order. 
One of `\"ascending\"` (default), `\"descending\"`, or `null` (no not sort)." } }, "required": [ "encoding" ], "type": "object" }, "SortField": { "additionalProperties": false, "description": "A sort definition for transform", "properties": { "field": { "$ref": "#/definitions/FieldName", "description": "The name of the field to sort." }, "order": { "anyOf": [ { "$ref": "#/definitions/SortOrder" }, { "type": "null" } ], "description": "Whether to sort the field in ascending or descending order. One of `\"ascending\"` (default), `\"descending\"`, or `null` (no not sort)." } }, "required": [ "field" ], "type": "object" }, "SortOrder": { "enum": [ "ascending", "descending" ], "type": "string" }, "SphereGenerator": { "additionalProperties": false, "properties": { "name": { "description": "Provide a placeholder name and bind data at runtime.", "type": "string" }, "sphere": { "anyOf": [ { "const": true, "type": "boolean" }, { "additionalProperties": false, "type": "object" } ], "description": "Generate sphere GeoJSON data for the full globe." } }, "required": [ "sphere" ], "type": "object" }, "StackOffset": { "enum": [ "zero", "center", "normalize" ], "type": "string" }, "StackTransform": { "additionalProperties": false, "properties": { "as": { "anyOf": [ { "$ref": "#/definitions/FieldName" }, { "items": { "$ref": "#/definitions/FieldName" }, "maxItems": 2, "minItems": 2, "type": "array" } ], "description": "Output field names. This can be either a string or an array of strings with two elements denoting the name for the fields for stack start and stack end respectively. If a single string(e.g., `\"val\"`) is provided, the end field will be `\"val_end\"`." }, "groupby": { "description": "The data fields to group by.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "offset": { "description": "Mode for stacking marks. One of `\"zero\"` (default), `\"center\"`, or `\"normalize\"`. The `\"zero\"` offset will stack starting at `0`. 
The `\"center\"` offset will center the stacks. The `\"normalize\"` offset will compute percentage values for each stack point, with output values in the range `[0,1]`.\n\n__Default value:__ `\"zero\"`", "enum": [ "zero", "center", "normalize" ], "type": "string" }, "sort": { "description": "Field that determines the order of leaves in the stacked charts.", "items": { "$ref": "#/definitions/SortField" }, "type": "array" }, "stack": { "$ref": "#/definitions/FieldName", "description": "The field which is stacked." } }, "required": [ "stack", "groupby", "as" ], "type": "object" }, "StandardType": { "enum": [ "quantitative", "ordinal", "temporal", "nominal" ], "type": "string" }, "Step": { "additionalProperties": false, "properties": { "for": { "$ref": "#/definitions/StepFor", "description": "Whether to apply the step to position scale or offset scale when there are both `x` and `xOffset` or both `y` and `yOffset` encodings." }, "step": { "description": "The size (width/height) per discrete step.", "type": "number" } }, "required": [ "step" ], "type": "object" }, "StepFor": { "enum": [ "position", "offset" ], "type": "string" }, "Stream": { "anyOf": [ { "$ref": "#/definitions/EventStream" }, { "$ref": "#/definitions/DerivedStream" }, { "$ref": "#/definitions/MergedStream" } ] }, "StringFieldDef": { "additionalProperties": false, "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). 
See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "format": { "anyOf": [ { "type": "string" }, { "$ref": "#/definitions/Dict" } ], "description": "When used with the default `\"number\"` and `\"time\"` format type, the text formatting pattern for labels of guides (axes, legends, headers) and text marks.\n\n- If the format type is `\"number\"` (e.g., for quantitative fields), this is D3's [number format pattern](https://github.com/d3/d3-format#locale_format).\n- If the format type is `\"time\"` (e.g., for temporal fields), this is D3's [time format pattern](https://github.com/d3/d3-time-format#locale_format).\n\nSee the [format documentation](https://vega.github.io/vega-lite/docs/format.html) for more examples.\n\nWhen used with a [custom `formatType`](https://vega.github.io/vega-lite/docs/config.html#custom-format-type), this value will be passed as `format` alongside `datum.value` to the registered function.\n\n__Default value:__ Derived from [numberFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for number format and from [timeFormat](https://vega.github.io/vega-lite/docs/config.html#format) config for time format." }, "formatType": { "description": "The format type for labels. One of `\"number\"`, `\"time\"`, or a [registered custom format type](https://vega.github.io/vega-lite/docs/config.html#custom-format-type).\n\n__Default value:__\n- `\"time\"` for temporal fields and ordinal and nominal fields with `timeUnit`.\n- `\"number\"` for quantitative fields as well as ordinal and nominal fields without `timeUnit`.", "type": "string" }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. 
or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. 
However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. 
`\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." 
} }, "type": "object" }, "StringFieldDefWithCondition": { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, "StringValueDefWithCondition": { "$ref": "#/definitions/ValueDefWithCondition" }, "StrokeCap": { "enum": [ "butt", "round", "square" ], "type": "string" }, "StrokeJoin": { "enum": [ "miter", "round", "bevel" ], "type": "string" }, "StyleConfigIndex": { "additionalProperties": { "anyOf": [ { "$ref": "#/definitions/AnyMarkConfig" }, { "$ref": "#/definitions/Axis" } ] }, "properties": { "arc": { "$ref": "#/definitions/RectConfig", "description": "Arc-specific Config" }, "area": { "$ref": "#/definitions/AreaConfig", "description": "Area-Specific Config" }, "bar": { "$ref": "#/definitions/BarConfig", "description": "Bar-Specific Config" }, "circle": { "$ref": "#/definitions/MarkConfig", "description": "Circle-Specific Config" }, "geoshape": { "$ref": "#/definitions/MarkConfig", "description": "Geoshape-Specific Config" }, "group-subtitle": { "$ref": "#/definitions/MarkConfig", "description": "Default style for chart subtitles" }, "group-title": { "$ref": "#/definitions/MarkConfig", "description": "Default style for chart titles" }, "guide-label": { "$ref": "#/definitions/MarkConfig", "description": "Default style for axis, legend, and header labels." }, "guide-title": { "$ref": "#/definitions/MarkConfig", "description": "Default style for axis, legend, and header titles." 
}, "image": { "$ref": "#/definitions/RectConfig", "description": "Image-specific Config" }, "line": { "$ref": "#/definitions/LineConfig", "description": "Line-Specific Config" }, "mark": { "$ref": "#/definitions/MarkConfig", "description": "Mark Config" }, "point": { "$ref": "#/definitions/MarkConfig", "description": "Point-Specific Config" }, "rect": { "$ref": "#/definitions/RectConfig", "description": "Rect-Specific Config" }, "rule": { "$ref": "#/definitions/MarkConfig", "description": "Rule-Specific Config" }, "square": { "$ref": "#/definitions/MarkConfig", "description": "Square-Specific Config" }, "text": { "$ref": "#/definitions/MarkConfig", "description": "Text-Specific Config" }, "tick": { "$ref": "#/definitions/TickConfig", "description": "Tick-Specific Config" }, "trail": { "$ref": "#/definitions/LineConfig", "description": "Trail-Specific Config" } }, "type": "object" }, "SymbolShape": { "type": "string" }, "Text": { "anyOf": [ { "type": "string" }, { "items": { "type": "string" }, "type": "array" } ] }, "TextBaseline": { "anyOf": [ { "const": "alphabetic", "type": "string" }, { "$ref": "#/definitions/Baseline" }, { "const": "line-top", "type": "string" }, { "const": "line-bottom", "type": "string" } ] }, "TextDef": { "anyOf": [ { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/FieldOrDatumDefWithCondition" }, { "$ref": "#/definitions/ValueDefWithCondition" } ] }, "TextDirection": { "enum": [ "ltr", "rtl" ], "type": "string" }, "TickConfig": { "additionalProperties": false, "properties": { "align": { "anyOf": [ { "$ref": "#/definitions/Align" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The horizontal alignment of the text or ranged marks (area, bar, image, rect, rule). One of `\"left\"`, `\"right\"`, `\"center\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." 
}, "angle": { "anyOf": [ { "description": "The rotation angle of the text, in degrees.", "maximum": 360, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG element, removing the mark item from the ARIA accessibility tree.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRole": { "anyOf": [ { "description": "Sets the type of user interface element of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"role\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "ariaRoleDescription": { "anyOf": [ { "description": "A human-readable, author-localized description for the role of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the \"aria-roledescription\" attribute. Warning: this property is experimental and may be changed in the future.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "aspect": { "anyOf": [ { "description": "Whether to keep aspect ratio of image marks.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "bandSize": { "description": "The width of the ticks.\n\n__Default value:__ 3/4 of step (width step for horizontal ticks and height step for vertical ticks).", "minimum": 0, "type": "number" }, "baseline": { "anyOf": [ { "$ref": "#/definitions/TextBaseline" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For text marks, the vertical text baseline. 
One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, `\"line-bottom\"`, or an expression reference that provides one of the valid values. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the `lineHeight` rather than `fontSize` alone.\n\nFor range marks, the vertical alignment of the marks. One of `\"top\"`, `\"middle\"`, `\"bottom\"`.\n\n__Note:__ Expression reference is *not* supported for range marks." }, "blend": { "anyOf": [ { "$ref": "#/definitions/Blend", "description": "The color blend mode for drawing an item on its current background. Any valid [CSS mix-blend-mode](https://developer.mozilla.org/en-US/docs/Web/CSS/mix-blend-mode) value can be used.\n\n__Default value: `\"source-over\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "color": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default color.\n\n__Default value:__ `\"#4682b4\"`\n\n__Note:__\n- This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).\n- The `fill` and `stroke` properties have higher precedence than `color` and will override `color`." 
}, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusBottomRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' bottom right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopLeft": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top right corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cornerRadiusTopRight": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles' top left corner.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "anyOf": [ { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the mark. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, { "$ref": "#/definitions/ExprRef" } ] }, "description": { "anyOf": [ { "description": "A text description of the mark item for [ARIA accessibility](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) (SVG output only). If specified, this property determines the [\"aria-label\" attribute](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA/ARIA_Techniques/Using_the_aria-label_attribute).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "dir": { "anyOf": [ { "$ref": "#/definitions/TextDirection", "description": "The direction of the text. One of `\"ltr\"` (left-to-right) or `\"rtl\"` (right-to-left). 
This property determines on which side is truncated in response to the limit parameter.\n\n__Default value:__ `\"ltr\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "dx": { "anyOf": [ { "description": "The horizontal offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "The vertical offset, in pixels, between the text label and its anchor point. The offset is applied after rotation by the _angle_ property.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "ellipsis": { "anyOf": [ { "description": "The ellipsis string for text truncated in response to the limit parameter.\n\n__Default value:__ `\"…\"`", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "endAngle": { "anyOf": [ { "description": "The end angle in radians for arc marks. A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default fill color. This property has higher precedence than `config.color`. 
Set to `null` to remove fill.\n\n__Default value:__ (None)" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "filled": { "description": "Whether the mark's color should be used as fill color instead of stroke color.\n\n__Default value:__ `false` for all `point`, `line`, and `rule` marks as well as `geoshape` marks for [`graticule`](https://vega.github.io/vega-lite/docs/data.html#graticule) data sources; otherwise, `true`.\n\n__Note:__ This property cannot be used in a [style config](https://vega.github.io/vega-lite/docs/mark.html#style-config).", "type": "boolean" }, "font": { "anyOf": [ { "description": "The typeface to set the text in (e.g., `\"Helvetica Neue\"`).", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "The font size, in pixels.\n\n__Default value:__ `11`", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "The font style (e.g., `\"italic\"`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "The font weight. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "height": { "anyOf": [ { "description": "Height of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "href": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "A URL to load upon mouse click. If defined, the mark acts as a hyperlink." }, { "$ref": "#/definitions/ExprRef" } ] }, "innerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The inner radius in pixels of arc marks. 
`innerRadius` is an alias for `radius2`.\n\n__Default value:__ `0`", "minimum": 0 }, "interpolate": { "anyOf": [ { "$ref": "#/definitions/Interpolate", "description": "The line interpolation method to use for line and area marks. One of the following:\n- `\"linear\"`: piecewise linear segments, as in a polyline.\n- `\"linear-closed\"`: close the linear segments to form a polygon.\n- `\"step\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"step-before\"`: alternate between vertical and horizontal segments, as in a step function.\n- `\"step-after\"`: alternate between horizontal and vertical segments, as in a step function.\n- `\"basis\"`: a B-spline, with control point duplication on the ends.\n- `\"basis-open\"`: an open B-spline; may not intersect the start or end.\n- `\"basis-closed\"`: a closed B-spline, as in a loop.\n- `\"cardinal\"`: a Cardinal spline, with control point duplication on the ends.\n- `\"cardinal-open\"`: an open Cardinal spline; may not intersect the start or end, but will intersect other control points.\n- `\"cardinal-closed\"`: a closed Cardinal spline, as in a loop.\n- `\"bundle\"`: equivalent to basis, except the tension parameter is used to straighten the spline.\n- `\"monotone\"`: cubic interpolation that preserves monotonicity in y." }, { "$ref": "#/definitions/ExprRef" } ] }, "invalid": { "description": "Defines how Vega-Lite should handle marks for invalid values (`null` and `NaN`).\n- If set to `\"filter\"` (default), all data items with null values will be skipped (for line, trail, and area marks) or filtered (for other marks).\n- If `null`, all data items are included. In this case, invalid values will be interpreted as zeroes.", "enum": [ "filter", null ], "type": [ "string", "null" ] }, "limit": { "anyOf": [ { "description": "The maximum length of the text mark in pixels. 
The text value will be automatically truncated if the rendered size exceeds the limit.\n\n__Default value:__ `0` -- indicating no limit", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineBreak": { "anyOf": [ { "description": "A delimiter, such as a newline character, upon which to break text strings into multiple lines. This property is ignored if the text is array-valued.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "The line height in pixels (the spacing between subsequent lines of text) for multi-line text marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "order": { "description": "For line and trail marks, this `order` property can be set to `null` or `false` to make the lines use the original order in the data sources.", "type": [ "null", "boolean" ] }, "orient": { "$ref": "#/definitions/Orientation", "description": "The orientation of a non-stacked bar, tick, area, and line charts. The value is either horizontal (default) or vertical.\n- For bar, rule and tick, this determines whether the size of the bar and tick should be applied to x or y dimension.\n- For area, this property determines the orient property of the Vega output.\n- For line and trail marks, this property determines the sort order of the points in the line if `config.sortLineBy` is not specified. For stacked charts, this is always determined by the orientation of the stack; therefore explicitly specified value will be ignored." }, "outerRadius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The outer radius in pixels of arc marks. 
`outerRadius` is an alias for `radius`.\n\n__Default value:__ `0`", "minimum": 0 }, "padAngle": { "anyOf": [ { "description": "The angular padding applied to sides of the arc, in radians.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "radius": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "For arc mark, the primary (outer) radius in pixels.\n\nFor text marks, polar coordinate radial offset, in pixels, of the text from the origin determined by the `x` and `y` properties.\n\n__Default value:__ `min(plot_width, plot_height)/2`", "minimum": 0 }, "radius2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The secondary (inner) radius in pixels of arc marks.\n\n__Default value:__ `0`", "minimum": 0 }, "shape": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/SymbolShape" }, { "type": "string" } ], "description": "Shape of the point marks. Supported values include:\n- plotting shapes: `\"circle\"`, `\"square\"`, `\"cross\"`, `\"diamond\"`, `\"triangle-up\"`, `\"triangle-down\"`, `\"triangle-right\"`, or `\"triangle-left\"`.\n- the line symbol `\"stroke\"`\n- centered directional shapes `\"arrow\"`, `\"wedge\"`, or `\"triangle\"`\n- a custom [SVG path string](https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Paths) (For correct sizing, custom shape paths should be defined within a square bounding box with coordinates ranging from -1 to 1 along both the x and y dimensions.)\n\n__Default value:__ `\"circle\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "size": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default size for marks.\n- For `point`/`circle`/`square`, this represents the pixel area of the marks. 
Note that this value sets the area of the symbol; the side lengths will increase with the square root of this value.\n- For `bar`, this represents the band size of the bar, in pixels.\n- For `text`, this represents the font size, in pixels.\n\n__Default value:__\n- `30` for point, circle, square marks; width/height's `step`\n- `2` for bar marks with discrete dimensions;\n- `5` for bar marks with continuous dimensions;\n- `11` for text marks.", "minimum": 0 }, "smooth": { "anyOf": [ { "description": "A boolean flag (default true) indicating if the image should be smoothed when resized. If false, individual pixels should be scaled directly rather than interpolated with smoothing. For SVG rendering, this option may not work in some browsers due to lack of standardization.", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "startAngle": { "anyOf": [ { "description": "The start angle in radians for arc marks. A value of `0` indicates up (north), increasing values proceed clockwise.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/Gradient" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Default stroke color. This property has higher precedence than `config.color`. Set to `null` to remove stroke.\n\n__Default value:__ (None)" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. 
One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOffset": { "anyOf": [ { "description": "The offset in pixels at which to draw the group stroke and fill. 
If unspecified, the default behavior is to dynamically offset stroked groups such that 1 pixel stroke widths align with the pixel grid.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "tension": { "anyOf": [ { "description": "Depending on the interpolation type, sets the tension parameter (for line and area marks).", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text", "description": "Placeholder text if the `text` channel is not specified" }, { "$ref": "#/definitions/ExprRef" } ] }, "theta": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "- For arc marks, the arc length in radians if theta2 is not specified, otherwise the start arc angle. (A value of 0 indicates up or “north”, increasing values proceed clockwise.)\n\n- For text marks, polar coordinate angle in radians.", "maximum": 360, "minimum": 0 }, "theta2": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The end angle of arc marks in radians. A value of 0 indicates up or “north”, increasing values proceed clockwise." }, "thickness": { "description": "Thickness of the tick mark.\n\n__Default value:__ `1`", "minimum": 0, "type": "number" }, "timeUnitBandPosition": { "description": "Default relative band position for a time unit. If set to `0`, the marks will be positioned at the beginning of the time unit band step. 
If set to `0.5`, the marks will be positioned in the middle of the time unit band step.", "type": "number" }, "timeUnitBandSize": { "description": "Default relative band size for a time unit. If set to `1`, the bandwidth of the marks will be equal to the time unit band step. If set to `0.5`, bandwidth of the marks will be half of the time unit band step.", "type": "number" }, "tooltip": { "anyOf": [ { "type": "number" }, { "type": "string" }, { "type": "boolean" }, { "$ref": "#/definitions/TooltipContent" }, { "$ref": "#/definitions/ExprRef" }, { "type": "null" } ], "description": "The tooltip text string to show upon mouse hover or an object defining which fields should the tooltip be derived from.\n\n- If `tooltip` is `true` or `{\"content\": \"encoding\"}`, then all fields from `encoding` will be used.\n- If `tooltip` is `{\"content\": \"data\"}`, then all fields that appear in the highlighted data point will be used.\n- If set to `null` or `false`, then no tooltip will be used.\n\nSee the [`tooltip`](https://vega.github.io/vega-lite/docs/tooltip.html) documentation for a detailed discussion about tooltip in Vega-Lite.\n\n__Default value:__ `null`" }, "url": { "anyOf": [ { "$ref": "#/definitions/URI", "description": "The URL of the image file for image marks." }, { "$ref": "#/definitions/ExprRef" } ] }, "width": { "anyOf": [ { "description": "Width of the marks.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "x": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X coordinates of the marks, or width of horizontal `\"bar\"` and `\"area\"` without specified `x2` or `width`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." 
}, "x2": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "X2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"width\"` for the width of the plot." }, "y": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y coordinates of the marks, or height of vertical `\"bar\"` and `\"area\"` without specified `y2` or `height`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." }, "y2": { "anyOf": [ { "type": "number" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "Y2 coordinates for ranged `\"area\"`, `\"bar\"`, `\"rect\"`, and `\"rule\"`.\n\nThe `value` of this channel can be a number or a string `\"height\"` for the height of the plot." } }, "type": "object" }, "TickCount": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/TimeInterval" }, { "$ref": "#/definitions/TimeIntervalStep" } ] }, "TimeInterval": { "enum": [ "millisecond", "second", "minute", "hour", "day", "week", "month", "year" ], "type": "string" }, "TimeIntervalStep": { "additionalProperties": false, "properties": { "interval": { "$ref": "#/definitions/TimeInterval" }, "step": { "type": "number" } }, "required": [ "interval", "step" ], "type": "object" }, "TimeLocale": { "additionalProperties": false, "description": "Locale definition for formatting dates and times.", "properties": { "date": { "description": "The date (%x) format specifier (e.g., \"%m/%d/%Y\").", "type": "string" }, "dateTime": { "description": "The date and time (%c) format specifier (e.g., \"%a %b %e %X %Y\").", "type": "string" }, "days": { "$ref": "#/definitions/Vector7", "description": "The full names of the weekdays, starting with Sunday." 
}, "months": { "$ref": "#/definitions/Vector12", "description": "The full names of the months (starting with January)." }, "periods": { "$ref": "#/definitions/Vector2", "description": "The A.M. and P.M. equivalents (e.g., [\"AM\", \"PM\"])." }, "shortDays": { "$ref": "#/definitions/Vector7", "description": "The abbreviated names of the weekdays, starting with Sunday." }, "shortMonths": { "$ref": "#/definitions/Vector12", "description": "The abbreviated names of the months (starting with January)." }, "time": { "description": "The time (%X) format specifier (e.g., \"%H:%M:%S\").", "type": "string" } }, "required": [ "dateTime", "date", "time", "periods", "days", "shortDays", "months", "shortMonths" ], "type": "object" }, "TimeUnit": { "anyOf": [ { "$ref": "#/definitions/SingleTimeUnit" }, { "$ref": "#/definitions/MultiTimeUnit" } ] }, "TimeUnitParams": { "additionalProperties": false, "description": "Time Unit Params for encoding predicate, which can be specified if the data is already \"binned\".", "properties": { "binned": { "description": "Whether the data has already been binned to this time unit. If true, Vega-Lite will only format the data, marks, and guides, without applying the timeUnit transform to re-bin the data again.", "type": "boolean" }, "maxbins": { "description": "If no `unit` is specified, maxbins is used to infer time units.", "type": "number" }, "step": { "description": "The number of steps between bins, in terms of the least significant unit provided.", "type": "number" }, "unit": { "$ref": "#/definitions/TimeUnit", "description": "Defines how date-time values should be binned." }, "utc": { "description": "True to use UTC timezone. Equivalent to using a `utc` prefixed `TimeUnit`.", "type": "boolean" } }, "type": "object" }, "TimeUnitTransform": { "additionalProperties": false, "properties": { "as": { "$ref": "#/definitions/FieldName", "description": "The output field to write the timeUnit value."
}, "field": { "$ref": "#/definitions/FieldName", "description": "The data field to apply time unit." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/TimeUnitTransformParams" } ], "description": "The timeUnit." } }, "required": [ "timeUnit", "field", "as" ], "type": "object" }, "TimeUnitTransformParams": { "additionalProperties": false, "properties": { "maxbins": { "description": "If no `unit` is specified, maxbins is used to infer time units.", "type": "number" }, "step": { "description": "The number of steps between bins, in terms of the least significant unit provided.", "type": "number" }, "unit": { "$ref": "#/definitions/TimeUnit", "description": "Defines how date-time values should be binned." }, "utc": { "description": "True to use UTC timezone. Equivalent to using a `utc` prefixed `TimeUnit`.", "type": "boolean" } }, "type": "object" }, "TitleAnchor": { "enum": [ null, "start", "middle", "end" ], "type": [ "null", "string" ] }, "TitleConfig": { "$ref": "#/definitions/BaseTitleNoValueRefs" }, "TitleFrame": { "enum": [ "bounds", "group" ], "type": "string" }, "TitleOrient": { "enum": [ "none", "left", "right", "top", "bottom" ], "type": "string" }, "TitleParams": { "additionalProperties": false, "properties": { "align": { "$ref": "#/definitions/Align", "description": "Horizontal text alignment for title text. One of `\"left\"`, `\"center\"`, or `\"right\"`." }, "anchor": { "$ref": "#/definitions/TitleAnchor", "description": "The anchor position for placing the title. One of `\"start\"`, `\"middle\"`, or `\"end\"`. For example, with an orientation of top these anchor positions map to a left-, center-, or right-aligned title.\n\n__Default value:__ `\"middle\"` for [single](https://vega.github.io/vega-lite/docs/spec.html) and [layered](https://vega.github.io/vega-lite/docs/layer.html) views. 
`\"start\"` for other composite views.\n\n__Note:__ [For now](https://github.com/vega/vega-lite/issues/2875), `anchor` is customizable only for [single](https://vega.github.io/vega-lite/docs/spec.html) and [layered](https://vega.github.io/vega-lite/docs/layer.html) views. For other composite views, `anchor` is always `\"start\"`." }, "angle": { "anyOf": [ { "description": "Angle in degrees of title and subtitle text.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "aria": { "anyOf": [ { "description": "A boolean flag indicating if [ARIA attributes](https://developer.mozilla.org/en-US/docs/Web/Accessibility/ARIA) should be included (SVG output only). If `false`, the \"aria-hidden\" attribute will be set on the output SVG group, removing the title from the ARIA accessibility tree.\n\n__Default value:__ `true`", "type": "boolean" }, { "$ref": "#/definitions/ExprRef" } ] }, "baseline": { "$ref": "#/definitions/TextBaseline", "description": "Vertical text baseline for title and subtitle text. One of `\"alphabetic\"` (default), `\"top\"`, `\"middle\"`, `\"bottom\"`, `\"line-top\"`, or `\"line-bottom\"`. The `\"line-top\"` and `\"line-bottom\"` values operate similarly to `\"top\"` and `\"bottom\"`, but are calculated relative to the *lineHeight* rather than *fontSize* alone." }, "color": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Text color for title text."
}, { "$ref": "#/definitions/ExprRef" } ] }, "dx": { "anyOf": [ { "description": "Delta offset for title and subtitle text x-coordinate.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "dy": { "anyOf": [ { "description": "Delta offset for title and subtitle text y-coordinate.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "font": { "anyOf": [ { "description": "Font name for title text.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontSize": { "anyOf": [ { "description": "Font size in pixels for title text.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "fontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "Font style for title text." }, { "$ref": "#/definitions/ExprRef" } ] }, "fontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "Font weight for title text. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "frame": { "anyOf": [ { "anyOf": [ { "$ref": "#/definitions/TitleFrame" }, { "type": "string" } ], "description": "The reference frame for the anchor position, one of `\"bounds\"` (to anchor relative to the full bounding box) or `\"group\"` (to anchor relative to the group width or height)." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "limit": { "anyOf": [ { "description": "The maximum allowed length in pixels of title and subtitle text.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "lineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line title text or title text with `\"line-top\"` or `\"line-bottom\"` baseline.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "offset": { "anyOf": [ { "description": "The orthogonal offset in pixels by which to displace the title group from its position along the edge of the chart.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "orient": { "anyOf": [ { "$ref": "#/definitions/TitleOrient", "description": "Default title orientation (`\"top\"`, `\"bottom\"`, `\"left\"`, or `\"right\"`)" }, { "$ref": "#/definitions/ExprRef" } ] }, "style": { "anyOf": [ { "type": "string" }, { "items": { "type": "string" }, "type": "array" } ], "description": "A [mark style property](https://vega.github.io/vega-lite/docs/config.html#style) to apply to the title text mark.\n\n__Default value:__ `\"group-title\"`." }, "subtitle": { "$ref": "#/definitions/Text", "description": "The subtitle Text." }, "subtitleColor": { "anyOf": [ { "anyOf": [ { "type": "null" }, { "$ref": "#/definitions/Color" } ], "description": "Text color for subtitle text." }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleFont": { "anyOf": [ { "description": "Font name for subtitle text.", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleFontSize": { "anyOf": [ { "description": "Font size in pixels for subtitle text.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleFontStyle": { "anyOf": [ { "$ref": "#/definitions/FontStyle", "description": "Font style for subtitle text." 
}, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleFontWeight": { "anyOf": [ { "$ref": "#/definitions/FontWeight", "description": "Font weight for subtitle text. This can be either a string (e.g `\"bold\"`, `\"normal\"`) or a number (`100`, `200`, `300`, ..., `900` where `\"normal\"` = `400` and `\"bold\"` = `700`)." }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitleLineHeight": { "anyOf": [ { "description": "Line height in pixels for multi-line subtitle text.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "subtitlePadding": { "anyOf": [ { "description": "The padding in pixels between title and subtitle text.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "text": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The title text." }, "zindex": { "description": "The integer z-index indicating the layering of the title group relative to other axis, mark and legend groups.\n\n__Default value:__ `0`.", "minimum": 0, "type": "number" } }, "required": [ "text" ], "type": "object" }, "TooltipContent": { "additionalProperties": false, "properties": { "content": { "enum": [ "encoding", "data" ], "type": "string" } }, "required": [ "content" ], "type": "object" }, "TopLevelConcatSpec": { "additionalProperties": false, "properties": { "$schema": { "description": "URL to [JSON schema](http://json-schema.org/) for a Vega-Lite specification. Unless you have a reason to change this, use `https://vega.github.io/schema/vega-lite/v5.json`. Setting the `$schema` property allows automatic validation and autocomplete in editors that support JSON schema.", "format": "uri", "type": "string" }, "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. 
The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. 
The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "concat": { "description": "A list of views to be concatenated.", "items": { "$ref": "#/definitions/NonNormalizedSpec" }, "type": "array" }, "config": { "$ref": "#/definitions/Config", "description": "Vega-Lite configuration object. This property can only be defined at the top-level of a specification." }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "datasets": { "$ref": "#/definitions/Datasets", "description": "A global data store for named datasets. This is a mapping from names to inline datasets. 
This can be an array of objects or primitive values or a string. Arrays of primitive values are ingested as objects with a `data` property." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "Dynamic variables or selections that parameterize a visualization.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." 
}, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "usermeta": { "$ref": "#/definitions/Dict", "description": "Optional metadata that will be passed to Vega. This object is completely ignored by Vega and Vega-Lite and can be used for custom metadata." } }, "required": [ "concat" ], "type": "object" }, "TopLevelHConcatSpec": { "additionalProperties": false, "properties": { "$schema": { "description": "URL to [JSON schema](http://json-schema.org/) for a Vega-Lite specification. Unless you have a reason to change this, use `https://vega.github.io/schema/vega-lite/v5.json`. Setting the `$schema` property allows automatic validation and autocomplete in editors that support JSON schema.", "format": "uri", "type": "string" }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. 
The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\n__Default value:__ `false`", "type": "boolean" }, "config": { "$ref": "#/definitions/Config", "description": "Vega-Lite configuration object. This property can only be defined at the top-level of a specification." }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "datasets": { "$ref": "#/definitions/Datasets", "description": "A global data store for named datasets. This is a mapping from names to inline datasets. This can be an array of objects or primitive values or a string. Arrays of primitive values are ingested as objects with a `data` property." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "hconcat": { "description": "A list of views to be concatenated and put into a row.", "items": { "$ref": "#/definitions/NonNormalizedSpec" }, "type": "array" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. 
If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "Dynamic variables or selections that parameterize a visualization.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "description": "The spacing in pixels between sub-views of the concat operator.\n\n__Default value__: `10`", "type": "number" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "usermeta": { "$ref": "#/definitions/Dict", "description": "Optional metadata that will be passed to Vega. This object is completely ignored by Vega and Vega-Lite and can be used for custom metadata." } }, "required": [ "hconcat" ], "type": "object" }, "TopLevelVConcatSpec": { "additionalProperties": false, "properties": { "$schema": { "description": "URL to [JSON schema](http://json-schema.org/) for a Vega-Lite specification. Unless you have a reason to change this, use `https://vega.github.io/schema/vega-lite/v5.json`. Setting the `$schema` property allows automatic validation and autocomplete in editors that support JSON schema.", "format": "uri", "type": "string" }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. 
Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\n__Default value:__ `false`", "type": "boolean" }, "config": { "$ref": "#/definitions/Config", "description": "Vega-Lite configuration object. This property can only be defined at the top-level of a specification." }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "datasets": { "$ref": "#/definitions/Datasets", "description": "A global data store for named datasets. This is a mapping from names to inline datasets. This can be an array of objects or primitive values or a string. Arrays of primitive values are ingested as objects with a `data` property." 
}, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "Dynamic variables or selections that parameterize a visualization.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "description": "The spacing in pixels between sub-views of the concat operator.\n\n__Default value__: `10`", "type": "number" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "usermeta": { "$ref": "#/definitions/Dict", "description": "Optional metadata that will be passed to Vega. This object is completely ignored by Vega and Vega-Lite and can be used for custom metadata." 
}, "vconcat": { "description": "A list of views to be concatenated and put into a column.", "items": { "$ref": "#/definitions/NonNormalizedSpec" }, "type": "array" } }, "required": [ "vconcat" ], "type": "object" }, "TopLevelLayerSpec": { "additionalProperties": false, "properties": { "$schema": { "description": "URL to [JSON schema](http://json-schema.org/) for a Vega-Lite specification. Unless you have a reason to change this, use `https://vega.github.io/schema/vega-lite/v5.json`. Setting the `$schema` property allows automatic validation and autocomplete in editors that support JSON schema.", "format": "uri", "type": "string" }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "config": { "$ref": "#/definitions/Config", "description": "Vega-Lite configuration object. This property can only be defined at the top-level of a specification." }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "datasets": { "$ref": "#/definitions/Datasets", "description": "A global data store for named datasets. This is a mapping from names to inline datasets. This can be an array of objects or primitive values or a string. Arrays of primitive values are ingested as objects with a `data` property." 
}, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "encoding": { "$ref": "#/definitions/SharedEncoding", "description": "A shared key-value mapping between encoding channels and definition of fields in the underlying layers." }, "height": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The height of a visualization.\n\n- For a plot with a continuous y-field, height should be a number.\n- For a plot with either a discrete y-field or no y-field, height can be either a number indicating a fixed height or an object in the form of `{step: number}` defining the height per discrete step. (No y-field is equivalent to having one discrete step.)\n- To enable responsive sizing on height, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousHeight` for a plot with a continuous y-field and `config.view.discreteHeight` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the height of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`height`](https://vega.github.io/vega-lite/docs/size.html) documentation." }, "layer": { "description": "Layer or single view specifications to be layered.\n\n__Note__: Specifications inside `layer` cannot use `row` and `column` channels as layering facet specifications is not allowed. 
Instead, use the [facet operator](https://vega.github.io/vega-lite/docs/facet.html) and place a layer inside a facet.", "items": { "anyOf": [ { "$ref": "#/definitions/LayerSpec" }, { "$ref": "#/definitions/UnitSpec" } ] }, "type": "array" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "Dynamic variables or selections that parameterize a visualization.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "projection": { "$ref": "#/definitions/Projection", "description": "An object defining properties of the geographic projection shared by underlying layers." }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "usermeta": { "$ref": "#/definitions/Dict", "description": "Optional metadata that will be passed to Vega. This object is completely ignored by Vega and Vega-Lite and can be used for custom metadata." 
}, "view": { "$ref": "#/definitions/ViewBackground", "description": "An object defining the view background's fill and stroke.\n\n__Default value:__ none (transparent)" }, "width": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The width of a visualization.\n\n- For a plot with a continuous x-field, width should be a number.\n- For a plot with either a discrete x-field or no x-field, width can be either a number indicating a fixed width or an object in the form of `{step: number}` defining the width per discrete step. (No x-field is equivalent to having one discrete step.)\n- To enable responsive sizing on width, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousWidth` for a plot with a continuous x-field and `config.view.discreteWidth` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the width of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`width`](https://vega.github.io/vega-lite/docs/size.html) documentation." } }, "required": [ "layer" ], "type": "object" }, "TopLevelRepeatSpec": { "anyOf": [ { "additionalProperties": false, "properties": { "$schema": { "description": "URL to [JSON schema](http://json-schema.org/) for a Vega-Lite specification. Unless you have a reason to change this, use `https://vega.github.io/schema/vega-lite/v5.json`. Setting the `$schema` property allows automatic validation and autocomplete in editors that support JSON schema.", "format": "uri", "type": "string" }, "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. 
The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. 
The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "config": { "$ref": "#/definitions/Config", "description": "Vega-Lite configuration object. This property can only be defined at the top-level of a specification." }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "datasets": { "$ref": "#/definitions/Datasets", "description": "A global data store for named datasets. This is a mapping from names to inline datasets. This can be an array of objects or primitive values or a string. Arrays of primitive values are ingested as objects with a `data` property." 
}, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "Dynamic variables or selections that parameterize a visualization.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "repeat": { "anyOf": [ { "items": { "type": "string" }, "type": "array" }, { "$ref": "#/definitions/RepeatMapping" } ], "description": "Definition for fields to be repeated. One of: 1) An array of fields to be repeated. If `\"repeat\"` is an array, the field can be referred to as `{\"repeat\": \"repeat\"}`. The repeated views are laid out in a wrapped row. You can set the number of columns to control the wrapping. 2) An object that maps `\"row\"` and/or `\"column\"` to the listed fields to be repeated along the particular orientations. The objects `{\"repeat\": \"row\"}` and `{\"repeat\": \"column\"}` can be used to refer to the repeated field respectively." }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. 
An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "spec": { "$ref": "#/definitions/NonNormalizedSpec", "description": "A specification of the view that gets repeated." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "usermeta": { "$ref": "#/definitions/Dict", "description": "Optional metadata that will be passed to Vega. This object is completely ignored by Vega and Vega-Lite and can be used for custom metadata." } }, "required": [ "repeat", "spec" ], "type": "object" }, { "additionalProperties": false, "properties": { "$schema": { "description": "URL to [JSON schema](http://json-schema.org/) for a Vega-Lite specification. Unless you have a reason to change this, use `https://vega.github.io/schema/vega-lite/v5.json`. Setting the `$schema` property allows automatic validation and autocomplete in editors that support JSON schema.", "format": "uri", "type": "string" }, "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. 
The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. 
The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "config": { "$ref": "#/definitions/Config", "description": "Vega-Lite configuration object. This property can only be defined at the top-level of a specification." }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "datasets": { "$ref": "#/definitions/Datasets", "description": "A global data store for named datasets. This is a mapping from names to inline datasets. This can be an array of objects or primitive values or a string. Arrays of primitive values are ingested as objects with a `data` property." 
}, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "Dynamic variables or selections that parameterize a visualization.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "repeat": { "$ref": "#/definitions/LayerRepeatMapping", "description": "Definition for fields to be repeated. One of: 1) An array of fields to be repeated. If `\"repeat\"` is an array, the field can be referred to as `{\"repeat\": \"repeat\"}`. The repeated views are laid out in a wrapped row. You can set the number of columns to control the wrapping. 2) An object that maps `\"row\"` and/or `\"column\"` to the listed fields to be repeated along the particular orientations. The objects `{\"repeat\": \"row\"}` and `{\"repeat\": \"column\"}` can be used to refer to the repeated field respectively." }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. 
An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "spec": { "anyOf": [ { "$ref": "#/definitions/LayerSpec" }, { "$ref": "#/definitions/UnitSpecWithFrame" } ], "description": "A specification of the view that gets repeated." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "usermeta": { "$ref": "#/definitions/Dict", "description": "Optional metadata that will be passed to Vega. This object is completely ignored by Vega and Vega-Lite and can be used for custom metadata." } }, "required": [ "repeat", "spec" ], "type": "object" } ] }, "TopLevelFacetSpec": { "additionalProperties": false, "properties": { "$schema": { "description": "URL to [JSON schema](http://json-schema.org/) for a Vega-Lite specification. Unless you have a reason to change this, use `https://vega.github.io/schema/vega-lite/v5.json`. Setting the `$schema` property allows automatic validation and autocomplete in editors that support JSON schema.", "format": "uri", "type": "string" }, "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. 
The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. 
The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "columns": { "description": "The number of columns to include in the view composition layout.\n\n__Default value__: `undefined` -- An infinite number of columns (a single row) will be assumed. This is equivalent to `hconcat` (for `concat`) and to using the `column` channel (for `facet` and `repeat`).\n\n__Note__:\n\n1) This property is only for:\n- the general (wrappable) `concat` operator (not `hconcat`/`vconcat`)\n- the `facet` and `repeat` operator with one field/repetition definition (without row/column nesting)\n\n2) Setting the `columns` to `1` is equivalent to `vconcat` (for `concat`) and to using the `row` channel (for `facet` and `repeat`).", "type": "number" }, "config": { "$ref": "#/definitions/Config", "description": "Vega-Lite configuration object. This property can only be defined at the top-level of a specification." }, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "datasets": { "$ref": "#/definitions/Datasets", "description": "A global data store for named datasets. This is a mapping from names to inline datasets. This can be an array of objects or primitive values or a string. Arrays of primitive values are ingested as objects with a `data` property." 
}, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "facet": { "anyOf": [ { "$ref": "#/definitions/FacetFieldDef" }, { "$ref": "#/definitions/FacetMapping" } ], "description": "Definition for how to facet the data. One of: 1) [a field definition for faceting the plot by one field](https://vega.github.io/vega-lite/docs/facet.html#field-def) 2) [An object that maps `row` and `column` channels to their field definitions](https://vega.github.io/vega-lite/docs/facet.html#mapping)" }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "Dynamic variables or selections that parameterize a visualization.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. 
An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "spec": { "anyOf": [ { "$ref": "#/definitions/LayerSpec" }, { "$ref": "#/definitions/UnitSpecWithFrame" } ], "description": "A specification of the view that gets faceted." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "usermeta": { "$ref": "#/definitions/Dict", "description": "Optional metadata that will be passed to Vega. This object is completely ignored by Vega and Vega-Lite and can be used for custom metadata." } }, "required": [ "data", "facet", "spec" ], "type": "object" }, "TopLevelParameter": { "anyOf": [ { "$ref": "#/definitions/VariableParameter" }, { "$ref": "#/definitions/TopLevelSelectionParameter" } ] }, "TopLevelSelectionParameter": { "additionalProperties": false, "properties": { "bind": { "anyOf": [ { "$ref": "#/definitions/Binding" }, { "additionalProperties": { "$ref": "#/definitions/Binding" }, "type": "object" }, { "$ref": "#/definitions/LegendBinding" }, { "const": "scales", "type": "string" } ], "description": "When set, a selection is populated by input elements (also known as dynamic query widgets) or by interacting with the corresponding legend. 
Direct manipulation interaction is disabled by default; to re-enable it, set the selection's [`on`](https://vega.github.io/vega-lite/docs/selection.html#common-selection-properties) property.\n\nLegend bindings are restricted to selections that only specify a single field or encoding.\n\nQuery widget binding takes the form of Vega's [input element binding definition](https://vega.github.io/vega/docs/signals/#bind) or can be a mapping between projected field/encodings and binding definitions.\n\n__See also:__ [`bind`](https://vega.github.io/vega-lite/docs/bind.html) documentation." }, "name": { "$ref": "#/definitions/ParameterName", "description": "Required. A unique name for the selection parameter. Selection names should be valid JavaScript identifiers: they should contain only alphanumeric characters (or \"$\", or \"_\") and may not start with a digit. Reserved keywords that may not be used as parameter names are \"datum\", \"event\", \"item\", and \"parent\"." }, "select": { "anyOf": [ { "$ref": "#/definitions/SelectionType" }, { "$ref": "#/definitions/PointSelectionConfig" }, { "$ref": "#/definitions/IntervalSelectionConfig" } ], "description": "Determines the default event processing and data query for the selection. Vega-Lite currently supports two selection types:\n\n- `\"point\"` -- to select multiple discrete data values; the first value is selected on `click` and additional values toggled on shift-click.\n- `\"interval\"` -- to select a continuous range of data values on `drag`." }, "value": { "anyOf": [ { "$ref": "#/definitions/SelectionInit" }, { "items": { "$ref": "#/definitions/SelectionInitMapping" }, "type": "array" }, { "$ref": "#/definitions/SelectionInitIntervalMapping" } ], "description": "Initialize the selection with a mapping between [projected channels or field names](https://vega.github.io/vega-lite/docs/selection.html#project) and initial values.\n\n__See also:__ [`init`](https://vega.github.io/vega-lite/docs/value.html) documentation." 
}, "views": { "description": "By default, top-level selections are applied to every view in the visualization. If this property is specified, selections will only be applied to views with the given names.", "items": { "type": "string" }, "type": "array" } }, "required": [ "name", "select" ], "type": "object" }, "TopLevelSpec": { "anyOf": [ { "$ref": "#/definitions/TopLevelUnitSpec" }, { "$ref": "#/definitions/TopLevelFacetSpec" }, { "$ref": "#/definitions/TopLevelLayerSpec" }, { "$ref": "#/definitions/TopLevelRepeatSpec" }, { "$ref": "#/definitions/TopLevelConcatSpec" }, { "$ref": "#/definitions/TopLevelVConcatSpec" }, { "$ref": "#/definitions/TopLevelHConcatSpec" } ], "description": "A Vega-Lite top-level specification. This is the root class for all Vega-Lite specifications. (The json schema is generated from this type.)" }, "TopLevelUnitSpec": { "additionalProperties": false, "properties": { "$schema": { "description": "URL to [JSON schema](http://json-schema.org/) for a Vega-Lite specification. Unless you have a reason to change this, use `https://vega.github.io/schema/vega-lite/v5.json`. Setting the `$schema` property allows automatic validation and autocomplete in editors that support JSON schema.", "format": "uri", "type": "string" }, "align": { "anyOf": [ { "$ref": "#/definitions/LayoutAlign" }, { "$ref": "#/definitions/RowCol" } ], "description": "The alignment to apply to grid rows and columns. The supported string values are `\"all\"`, `\"each\"`, and `\"none\"`.\n\n- For `\"none\"`, a flow layout will be used, in which adjacent subviews are simply placed one after the other.\n- For `\"each\"`, subviews will be aligned into a clean grid structure, but each row or column may be of variable size.\n- For `\"all\"`, subviews will be aligned and each row or column will be sized identically based on the maximum observed size. 
String values for this property will be applied to both grid rows and columns.\n\nAlternatively, an object value of the form `{\"row\": string, \"column\": string}` can be used to supply different alignments for rows and columns.\n\n__Default value:__ `\"all\"`." }, "autosize": { "anyOf": [ { "$ref": "#/definitions/AutosizeType" }, { "$ref": "#/definitions/AutoSizeParams" } ], "description": "How the visualization size should be determined. If a string, should be one of `\"pad\"`, `\"fit\"` or `\"none\"`. Object values can additionally specify parameters for content sizing and automatic resizing.\n\n__Default value__: `pad`" }, "background": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "$ref": "#/definitions/ExprRef" } ], "description": "CSS color property to use as the background of the entire view.\n\n__Default value:__ `\"white\"`" }, "bounds": { "description": "The bounds calculation method to use for determining the extent of a sub-plot. One of `full` (the default) or `flush`.\n\n- If set to `full`, the entire calculated bounds (including axes, title, and legend) will be used.\n- If set to `flush`, only the specified width and height values for the sub-view will be used. The `flush` setting can be useful when attempting to place sub-plots without axes or legends into a uniform grid structure.\n\n__Default value:__ `\"full\"`", "enum": [ "full", "flush" ], "type": "string" }, "center": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/RowCol" } ], "description": "Boolean flag indicating if subviews should be centered relative to their respective rows or columns.\n\nAn object value of the form `{\"row\": boolean, \"column\": boolean}` can be used to supply different centering values for rows and columns.\n\n__Default value:__ `false`" }, "config": { "$ref": "#/definitions/Config", "description": "Vega-Lite configuration object. This property can only be defined at the top-level of a specification." 
}, "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. If no data is set, it is derived from the parent." }, "datasets": { "$ref": "#/definitions/Datasets", "description": "A global data store for named datasets. This is a mapping from names to inline datasets. This can be an array of objects or primitive values or a string. Arrays of primitive values are ingested as objects with a `data` property." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "encoding": { "$ref": "#/definitions/FacetedEncoding", "description": "A key-value mapping between encoding channels and definition of fields." }, "height": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The height of a visualization.\n\n- For a plot with a continuous y-field, height should be a number.\n- For a plot with either a discrete y-field or no y-field, height can be either a number indicating a fixed height or an object in the form of `{step: number}` defining the height per discrete step. (No y-field is equivalent to having one discrete step.)\n- To enable responsive sizing on height, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousHeight` for a plot with a continuous y-field and `config.view.discreteHeight` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the height of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`height`](https://vega.github.io/vega-lite/docs/size.html) documentation." 
}, "mark": { "$ref": "#/definitions/AnyMark", "description": "A string describing the mark type (one of `\"bar\"`, `\"circle\"`, `\"square\"`, `\"tick\"`, `\"line\"`, `\"area\"`, `\"point\"`, `\"rule\"`, `\"geoshape\"`, and `\"text\"`) or a [mark definition object](https://vega.github.io/vega-lite/docs/mark.html#mark-def)." }, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "padding": { "anyOf": [ { "$ref": "#/definitions/Padding" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The default visualization padding, in pixels, from the edge of the visualization canvas to the data rectangle. If a number, specifies padding for all sides. If an object, the value should have the format `{\"left\": 5, \"top\": 5, \"right\": 5, \"bottom\": 5}` to specify padding for each side of the visualization.\n\n__Default value__: `5`" }, "params": { "description": "An array of parameters that may either be simple variables, or more complex selections that map user input to data queries.", "items": { "$ref": "#/definitions/TopLevelParameter" }, "type": "array" }, "projection": { "$ref": "#/definitions/Projection", "description": "An object defining properties of geographic projection, which will be applied to `shape` path for `\"geoshape\"` marks and to `latitude` and `\"longitude\"` channels for other marks." }, "resolve": { "$ref": "#/definitions/Resolve", "description": "Scale, axis, and legend resolutions for view composition specifications." }, "spacing": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/RowCol" } ], "description": "The spacing in pixels between sub-views of the composition operator. 
An object of the form `{\"row\": number, \"column\": number}` can be used to set different spacing values for rows and columns.\n\n__Default value__: Depends on `\"spacing\"` property of [the view composition configuration](https://vega.github.io/vega-lite/docs/config.html#view-config) (`20` by default)" }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "usermeta": { "$ref": "#/definitions/Dict", "description": "Optional metadata that will be passed to Vega. This object is completely ignored by Vega and Vega-Lite and can be used for custom metadata." }, "view": { "$ref": "#/definitions/ViewBackground", "description": "An object defining the view background's fill and stroke.\n\n__Default value:__ none (transparent)" }, "width": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The width of a visualization.\n\n- For a plot with a continuous x-field, width should be a number.\n- For a plot with either a discrete x-field or no x-field, width can be either a number indicating a fixed width or an object in the form of `{step: number}` defining the width per discrete step. 
(No x-field is equivalent to having one discrete step.)\n- To enable responsive sizing on width, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousWidth` for a plot with a continuous x-field and `config.view.discreteWidth` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the width of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`width`](https://vega.github.io/vega-lite/docs/size.html) documentation." } }, "required": [ "data", "mark" ], "type": "object" }, "TopoDataFormat": { "additionalProperties": false, "properties": { "feature": { "description": "The name of the TopoJSON object set to convert to a GeoJSON feature collection. For example, in a map of the world, there may be an object set named `\"countries\"`. Using the feature property, we can extract this set and generate a GeoJSON feature object for each country.", "type": "string" }, "mesh": { "description": "The name of the TopoJSON object set to convert to mesh. Similar to the `feature` option, `mesh` extracts a named TopoJSON object set. Unlike the `feature` option, the corresponding geo data is returned as a single, unified mesh instance, not as individual GeoJSON features. Extracting a mesh is useful for more efficiently drawing borders or other geographic elements that you do not need to associate with specific regions such as individual countries, states or counties.", "type": "string" }, "parse": { "anyOf": [ { "$ref": "#/definitions/Parse" }, { "type": "null" } ], "description": "If set to `null`, disable type inference based on the spec and only use type inference based on the data. Alternatively, a parsing directive object can be provided for explicit data types. 
Each property of the object corresponds to a field name, and the value to the desired data type (one of `\"number\"`, `\"boolean\"`, `\"date\"`, or null (do not parse the field)). For example, `\"parse\": {\"modified_on\": \"date\"}` parses the `modified_on` field in each input record a Date value.\n\nFor `\"date\"`, we parse data based using JavaScript's [`Date.parse()`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date/parse). For Specific date formats can be provided (e.g., `{foo: \"date:'%m%d%Y'\"}`), using the [d3-time-format syntax](https://github.com/d3/d3-time-format#locale_format). UTC date format parsing is supported similarly (e.g., `{foo: \"utc:'%m%d%Y'\"}`). See more about [UTC time](https://vega.github.io/vega-lite/docs/timeunit.html#utc)" }, "type": { "const": "topojson", "description": "Type of input data: `\"json\"`, `\"csv\"`, `\"tsv\"`, `\"dsv\"`.\n\n__Default value:__ The default format type is determined by the extension of the file URL. 
If no extension is detected, `\"json\"` will be used by default.", "type": "string" } }, "type": "object" }, "Transform": { "anyOf": [ { "$ref": "#/definitions/AggregateTransform" }, { "$ref": "#/definitions/BinTransform" }, { "$ref": "#/definitions/CalculateTransform" }, { "$ref": "#/definitions/DensityTransform" }, { "$ref": "#/definitions/ExtentTransform" }, { "$ref": "#/definitions/FilterTransform" }, { "$ref": "#/definitions/FlattenTransform" }, { "$ref": "#/definitions/FoldTransform" }, { "$ref": "#/definitions/ImputeTransform" }, { "$ref": "#/definitions/JoinAggregateTransform" }, { "$ref": "#/definitions/LoessTransform" }, { "$ref": "#/definitions/LookupTransform" }, { "$ref": "#/definitions/QuantileTransform" }, { "$ref": "#/definitions/RegressionTransform" }, { "$ref": "#/definitions/TimeUnitTransform" }, { "$ref": "#/definitions/SampleTransform" }, { "$ref": "#/definitions/StackTransform" }, { "$ref": "#/definitions/WindowTransform" }, { "$ref": "#/definitions/PivotTransform" } ] }, "Type": { "description": "Data type based on level of measurement", "enum": [ "quantitative", "ordinal", "temporal", "nominal", "geojson" ], "type": "string" }, "TypeForShape": { "enum": [ "nominal", "ordinal", "geojson" ], "type": "string" }, "TypedFieldDef": { "additionalProperties": false, "description": "Definition object for a data field, its type and transformation of an encoding channel.", "properties": { "aggregate": { "$ref": "#/definitions/Aggregate", "description": "Aggregation function for the field (e.g., `\"mean\"`, `\"sum\"`, `\"median\"`, `\"min\"`, `\"max\"`, `\"count\"`).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html) documentation." }, "bandPosition": { "description": "Relative position on a band of a stacked, binned, time unit, or band scale. 
For example, the marks will be positioned at the beginning of the band if set to `0`, and at the middle of the band if set to `0.5`.", "maximum": 1, "minimum": 0, "type": "number" }, "bin": { "anyOf": [ { "type": "boolean" }, { "$ref": "#/definitions/BinParams" }, { "const": "binned", "type": "string" }, { "type": "null" } ], "description": "A flag for binning a `quantitative` field, [an object defining binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters), or indicating that the data for `x` or `y` channel are binned before they are imported into Vega-Lite (`\"binned\"`).\n\n- If `true`, default [binning parameters](https://vega.github.io/vega-lite/docs/bin.html#bin-parameters) will be applied.\n\n- If `\"binned\"`, this indicates that the data for the `x` (or `y`) channel are already binned. You can map the bin-start field to `x` (or `y`) and the bin-end field to `x2` (or `y2`). The scale and axis will be formatted similar to binning in Vega-Lite. To adjust the axis ticks based on the bin step, you can also set the axis's [`tickMinStep`](https://vega.github.io/vega-lite/docs/axis.html#ticks) property.\n\n__Default value:__ `false`\n\n__See also:__ [`bin`](https://vega.github.io/vega-lite/docs/bin.html) documentation." }, "field": { "$ref": "#/definitions/Field", "description": "__Required.__ A string defining the name of the field from which to pull a data value or an object defining iterated values from the [`repeat`](https://vega.github.io/vega-lite/docs/repeat.html) operator.\n\n__See also:__ [`field`](https://vega.github.io/vega-lite/docs/field.html) documentation.\n\n__Notes:__ 1) Dots (`.`) and brackets (`[` and `]`) can be used to access nested objects (e.g., `\"field\": \"foo.bar\"` and `\"field\": \"foo['bar']\"`). If field names contain dots or brackets but are not nested, you can use `\\\\` to escape dots and brackets (e.g., `\"a\\\\.b\"` and `\"a\\\\[0\\\\]\"`). 
See more details about escaping in the [field documentation](https://vega.github.io/vega-lite/docs/field.html). 2) `field` is not required if `aggregate` is `count`." }, "timeUnit": { "anyOf": [ { "$ref": "#/definitions/TimeUnit" }, { "$ref": "#/definitions/BinnedTimeUnit" }, { "$ref": "#/definitions/TimeUnitParams" } ], "description": "Time unit (e.g., `year`, `yearmonth`, `month`, `hours`) for a temporal field. or [a temporal field that gets casted as ordinal](https://vega.github.io/vega-lite/docs/type.html#cast).\n\n__Default value:__ `undefined` (None)\n\n__See also:__ [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html) documentation." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "type": "null" } ], "description": "A title for the field. If `null`, the title will be removed.\n\n__Default value:__ derived from the field's name and transformation function (`aggregate`, `bin` and `timeUnit`). If the field has an aggregate function, the function is displayed as part of the title (e.g., `\"Sum of Profit\"`). If the field is binned or has a time unit applied, the applied function is shown in parentheses (e.g., `\"Profit (binned)\"`, `\"Transaction Date (year-month)\"`). Otherwise, the title is simply the field name.\n\n__Notes__:\n\n1) You can customize the default field title format by providing the [`fieldTitle`](https://vega.github.io/vega-lite/docs/config.html#top-level-config) property in the [config](https://vega.github.io/vega-lite/docs/config.html) or [`fieldTitle` function via the `compile` function's options](https://vega.github.io/vega-lite/usage/compile.html#field-title).\n\n2) If both field definition's `title` and axis, header, or legend `title` are defined, axis/header/legend title will be used." }, "type": { "$ref": "#/definitions/StandardType", "description": "The type of measurement (`\"quantitative\"`, `\"temporal\"`, `\"ordinal\"`, or `\"nominal\"`) for the encoded field or constant value (`datum`). 
It can also be a `\"geojson\"` type for encoding ['geoshape'](https://vega.github.io/vega-lite/docs/geoshape.html).\n\nVega-Lite automatically infers data types in many cases as discussed below. However, type is required for a field if: (1) the field is not nominal and the field encoding has no specified `aggregate` (except `argmin` and `argmax`), `bin`, scale type, custom `sort` order, nor `timeUnit` or (2) if you wish to use an ordinal scale for a field with `bin` or `timeUnit`.\n\n__Default value:__\n\n1) For a data `field`, `\"nominal\"` is the default data type unless the field encoding has `aggregate`, `channel`, `bin`, scale type, `sort`, or `timeUnit` that satisfies the following criteria:\n- `\"quantitative\"` is the default type if (1) the encoded field contains `bin` or `aggregate` except `\"argmin\"` and `\"argmax\"`, (2) the encoding channel is `latitude` or `longitude` channel or (3) if the specified scale type is [a quantitative scale](https://vega.github.io/vega-lite/docs/scale.html#type).\n- `\"temporal\"` is the default type if (1) the encoded field contains `timeUnit` or (2) the specified scale type is a time or utc scale\n- `\"ordinal\"` is the default type if (1) the encoded field contains a [custom `sort` order](https://vega.github.io/vega-lite/docs/sort.html#specifying-custom-sort-order), (2) the specified scale type is an ordinal/point/band scale, or (3) the encoding channel is `order`.\n\n2) For a constant value in data domain (`datum`):\n- `\"quantitative\"` if the datum is a number\n- `\"nominal\"` if the datum is a string\n- `\"temporal\"` if the datum is [a date time object](https://vega.github.io/vega-lite/docs/datetime.html)\n\n__Note:__\n- Data `type` describes the semantics of the data rather than the primitive data types (number, string, etc.). The same primitive data type can have different types of measurement. 
For example, numeric data can represent quantitative, ordinal, or nominal data.\n- Data values for a temporal field can be either a date-time string (e.g., `\"2015-03-07 12:32:17\"`, `\"17:01\"`, `\"2015-03-16\"`. `\"2015\"`) or a timestamp number (e.g., `1552199579097`).\n- When using with [`bin`](https://vega.github.io/vega-lite/docs/bin.html), the `type` property can be either `\"quantitative\"` (for using a linear bin scale) or [`\"ordinal\"` (for using an ordinal bin scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`timeUnit`](https://vega.github.io/vega-lite/docs/timeunit.html), the `type` property can be either `\"temporal\"` (default, for using a temporal scale) or [`\"ordinal\"` (for using an ordinal scale)](https://vega.github.io/vega-lite/docs/type.html#cast-bin).\n- When using with [`aggregate`](https://vega.github.io/vega-lite/docs/aggregate.html), the `type` property refers to the post-aggregation data type. For example, we can calculate count `distinct` of a categorical field `\"cat\"` using `{\"aggregate\": \"distinct\", \"field\": \"cat\"}`. The `\"type\"` of the aggregate output is `\"quantitative\"`.\n- Secondary channels (e.g., `x2`, `y2`, `xError`, `yError`) do not have `type` as they must have exactly the same type as their primary channels (e.g., `x`, `y`).\n\n__See also:__ [`type`](https://vega.github.io/vega-lite/docs/type.html) documentation." } }, "type": "object" }, "URI": { "format": "uri-reference", "type": "string" }, "UnitSpec": { "$ref": "#/definitions/GenericUnitSpec", "description": "A unit specification, which can contain either [primitive marks or composite marks](https://vega.github.io/vega-lite/docs/mark.html#types)." }, "UnitSpecWithFrame": { "additionalProperties": false, "properties": { "data": { "anyOf": [ { "$ref": "#/definitions/Data" }, { "type": "null" } ], "description": "An object describing the data source. Set to `null` to ignore the parent's data source. 
If no data is set, it is derived from the parent." }, "description": { "description": "Description of this mark for commenting purpose.", "type": "string" }, "encoding": { "$ref": "#/definitions/Encoding", "description": "A key-value mapping between encoding channels and definition of fields." }, "height": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The height of a visualization.\n\n- For a plot with a continuous y-field, height should be a number.\n- For a plot with either a discrete y-field or no y-field, height can be either a number indicating a fixed height or an object in the form of `{step: number}` defining the height per discrete step. (No y-field is equivalent to having one discrete step.)\n- To enable responsive sizing on height, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousHeight` for a plot with a continuous y-field and `config.view.discreteHeight` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the height of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`height`](https://vega.github.io/vega-lite/docs/size.html) documentation." }, "mark": { "$ref": "#/definitions/AnyMark", "description": "A string describing the mark type (one of `\"bar\"`, `\"circle\"`, `\"square\"`, `\"tick\"`, `\"line\"`, `\"area\"`, `\"point\"`, `\"rule\"`, `\"geoshape\"`, and `\"text\"`) or a [mark definition object](https://vega.github.io/vega-lite/docs/mark.html#mark-def)." 
}, "name": { "description": "Name of the visualization for later reference.", "type": "string" }, "params": { "description": "An array of parameters that may either be simple variables, or more complex selections that map user input to data queries.", "items": { "$ref": "#/definitions/SelectionParameter" }, "type": "array" }, "projection": { "$ref": "#/definitions/Projection", "description": "An object defining properties of geographic projection, which will be applied to `shape` path for `\"geoshape\"` marks and to `latitude` and `\"longitude\"` channels for other marks." }, "title": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/TitleParams" } ], "description": "Title for the plot." }, "transform": { "description": "An array of data transformations such as filter and new field calculation.", "items": { "$ref": "#/definitions/Transform" }, "type": "array" }, "view": { "$ref": "#/definitions/ViewBackground", "description": "An object defining the view background's fill and stroke.\n\n__Default value:__ none (transparent)" }, "width": { "anyOf": [ { "type": "number" }, { "const": "container", "type": "string" }, { "$ref": "#/definitions/Step" } ], "description": "The width of a visualization.\n\n- For a plot with a continuous x-field, width should be a number.\n- For a plot with either a discrete x-field or no x-field, width can be either a number indicating a fixed width or an object in the form of `{step: number}` defining the width per discrete step. 
(No x-field is equivalent to having one discrete step.)\n- To enable responsive sizing on width, it should be set to `\"container\"`.\n\n__Default value:__ Based on `config.view.continuousWidth` for a plot with a continuous x-field and `config.view.discreteWidth` otherwise.\n\n__Note:__ For plots with [`row` and `column` channels](https://vega.github.io/vega-lite/docs/encoding.html#facet), this represents the width of a single view and the `\"container\"` option cannot be used.\n\n__See also:__ [`width`](https://vega.github.io/vega-lite/docs/size.html) documentation." } }, "required": [ "mark" ], "type": "object" }, "UrlData": { "additionalProperties": false, "properties": { "format": { "$ref": "#/definitions/DataFormat", "description": "An object that specifies the format for parsing the data." }, "name": { "description": "Provide a placeholder name and bind data at runtime.", "type": "string" }, "url": { "description": "An URL from which to load the data set. Use the `format.type` property to ensure the loaded data is correctly parsed.", "type": "string" } }, "required": [ "url" ], "type": "object" }, "UtcMultiTimeUnit": { "enum": [ "utcyearquarter", "utcyearquartermonth", "utcyearmonth", "utcyearmonthdate", "utcyearmonthdatehours", "utcyearmonthdatehoursminutes", "utcyearmonthdatehoursminutesseconds", "utcyearweek", "utcyearweekday", "utcyearweekdayhours", "utcyearweekdayhoursminutes", "utcyearweekdayhoursminutesseconds", "utcyeardayofyear", "utcquartermonth", "utcmonthdate", "utcmonthdatehours", "utcmonthdatehoursminutes", "utcmonthdatehoursminutesseconds", "utcweekday", "utcweeksdayhours", "utcweekdayhoursminutes", "utcweekdayhoursminutesseconds", "utcdayhours", "utcdayhoursminutes", "utcdayhoursminutesseconds", "utchoursminutes", "utchoursminutesseconds", "utcminutesseconds", "utcsecondsmilliseconds" ], "type": "string" }, "UtcSingleTimeUnit": { "enum": [ "utcyear", "utcquarter", "utcmonth", "utcweek", "utcday", "utcdayofyear", "utcdate", "utchours", 
"utcminutes", "utcseconds", "utcmilliseconds" ], "type": "string" }, "ValueDef<(number|\"width\"|\"height\"|ExprRef)>": { "additionalProperties": false, "description": "Definition object for a constant value (primitive value or gradient definition) of an encoding channel.", "properties": { "value": { "anyOf": [ { "type": "number" }, { "const": "width", "type": "string" }, { "const": "height", "type": "string" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "required": [ "value" ], "type": "object" }, "ValueDef": { "additionalProperties": false, "description": "Definition object for a constant value (primitive value or gradient definition) of an encoding channel.", "properties": { "value": { "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity).", "type": "number" } }, "required": [ "value" ], "type": "object" }, "ValueDefWithCondition": { "additionalProperties": false, "minProperties": 1, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Gradient|string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." 
}, "value": { "anyOf": [ { "$ref": "#/definitions/Gradient" }, { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "ValueDefWithCondition": { "additionalProperties": false, "minProperties": 1, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." }, "value": { "anyOf": [ { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "ValueDefWithCondition": { "additionalProperties": false, "minProperties": 1, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." }, "value": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "type": "object" }, "ValueDefWithCondition": { "additionalProperties": false, "minProperties": 1, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(number[]|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." }, "value": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "ValueDefWithCondition,(string|null)>": { "additionalProperties": false, "minProperties": 1, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalMarkPropFieldOrDatumDef" }, { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(string|null|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." }, "value": { "anyOf": [ { "type": "string" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." 
} }, "type": "object" }, "ValueDefWithCondition": { "additionalProperties": false, "minProperties": 1, "properties": { "condition": { "anyOf": [ { "$ref": "#/definitions/ConditionalStringFieldDef" }, { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, { "items": { "$ref": "#/definitions/ConditionalValueDef<(Text|ExprRef)>" }, "type": "array" } ], "description": "A field definition or one or more value definition(s) with a parameter predicate." }, "value": { "anyOf": [ { "$ref": "#/definitions/Text" }, { "$ref": "#/definitions/ExprRef" } ], "description": "A constant value in visual domain (e.g., `\"red\"` / `\"#0099ff\"` / [gradient definition](https://vega.github.io/vega-lite/docs/types.html#gradient) for color, values between `0` to `1` for opacity)." } }, "type": "object" }, "VariableParameter": { "additionalProperties": false, "properties": { "bind": { "$ref": "#/definitions/Binding", "description": "Binds the parameter to an external input element such as a slider, selection list or radio button group." }, "expr": { "$ref": "#/definitions/Expr", "description": "An expression for the value of the parameter. This expression may include other parameters, in which case the parameter will automatically update in response to upstream parameter changes." }, "name": { "$ref": "#/definitions/ParameterName", "description": "A unique name for the variable parameter. Parameter names should be valid JavaScript identifiers: they should contain only alphanumeric characters (or \"$\", or \"_\") and may not start with a digit. Reserved keywords that may not be used as parameter names are \"datum\", \"event\", \"item\", and \"parent\"." 
}, "value": { "description": "The [initial value](http://vega.github.io/vega-lite/docs/value.html) of the parameter.\n\n__Default value:__ `undefined`" } }, "required": [ "name" ], "type": "object" }, "Vector10": { "items": { "type": "string" }, "maxItems": 10, "minItems": 10, "type": "array" }, "Vector12": { "items": { "type": "string" }, "maxItems": 12, "minItems": 12, "type": "array" }, "Vector2": { "items": { "$ref": "#/definitions/DateTime" }, "maxItems": 2, "minItems": 2, "type": "array" }, "Vector2>": { "items": { "$ref": "#/definitions/Vector2" }, "maxItems": 2, "minItems": 2, "type": "array" }, "Vector2": { "items": { "type": "boolean" }, "maxItems": 2, "minItems": 2, "type": "array" }, "Vector2": { "items": { "type": "number" }, "maxItems": 2, "minItems": 2, "type": "array" }, "Vector2": { "items": { "type": "string" }, "maxItems": 2, "minItems": 2, "type": "array" }, "Vector3": { "items": { "type": "number" }, "maxItems": 3, "minItems": 3, "type": "array" }, "Vector7": { "items": { "type": "string" }, "maxItems": 7, "minItems": 7, "type": "array" }, "ViewBackground": { "additionalProperties": false, "properties": { "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the view. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." 
}, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The fill color.\n\n__Default value:__ `undefined`" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The stroke color.\n\n__Default value:__ `\"#ddd\"`" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. 
One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "style": { "anyOf": [ { "type": "string" }, { "items": { "type": "string" }, "type": "array" } ], "description": "A string or array of strings indicating the name of custom styles to apply to the view background. A style is a named collection of mark property defaults defined within the [style configuration](https://vega.github.io/vega-lite/docs/mark.html#style-config). If style is an array, later styles will override earlier styles.\n\n__Default value:__ `\"cell\"` __Note:__ Any specified view background properties will augment the default style." 
} }, "type": "object" }, "ViewConfig": { "additionalProperties": false, "properties": { "clip": { "description": "Whether the view should be clipped.", "type": "boolean" }, "continuousHeight": { "description": "The default height when the plot has a continuous y-field for x or latitude, or has arc marks.\n\n__Default value:__ `200`", "type": "number" }, "continuousWidth": { "description": "The default width when the plot has a continuous field for x or longitude, or has arc marks.\n\n__Default value:__ `200`", "type": "number" }, "cornerRadius": { "anyOf": [ { "description": "The radius in pixels of rounded rectangles or arcs' corners.\n\n__Default value:__ `0`", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "cursor": { "$ref": "#/definitions/Cursor", "description": "The mouse cursor used over the view. Any valid [CSS cursor type](https://developer.mozilla.org/en-US/docs/Web/CSS/cursor#Values) can be used." }, "discreteHeight": { "anyOf": [ { "type": "number" }, { "additionalProperties": false, "properties": { "step": { "type": "number" } }, "required": [ "step" ], "type": "object" } ], "description": "The default height when the plot has non arc marks and either a discrete y-field or no y-field. The height can be either a number indicating a fixed height or an object in the form of `{step: number}` defining the height per discrete step.\n\n__Default value:__ a step size based on `config.view.step`." }, "discreteWidth": { "anyOf": [ { "type": "number" }, { "additionalProperties": false, "properties": { "step": { "type": "number" } }, "required": [ "step" ], "type": "object" } ], "description": "The default width when the plot has non-arc marks and either a discrete x-field or no x-field. The width can be either a number indicating a fixed width or an object in the form of `{step: number}` defining the width per discrete step.\n\n__Default value:__ a step size based on `config.view.step`." 
}, "fill": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The fill color.\n\n__Default value:__ `undefined`" }, "fillOpacity": { "anyOf": [ { "description": "The fill opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "opacity": { "anyOf": [ { "type": "number" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The overall opacity (value between [0,1]).\n\n__Default value:__ `0.7` for non-aggregate plots with `point`, `tick`, `circle`, or `square` marks or layered `bar` charts and `1` otherwise.", "maximum": 1, "minimum": 0 }, "step": { "description": "Default step size for x-/y- discrete fields.", "type": "number" }, "stroke": { "anyOf": [ { "$ref": "#/definitions/Color" }, { "type": "null" }, { "$ref": "#/definitions/ExprRef" } ], "description": "The stroke color.\n\n__Default value:__ `\"#ddd\"`" }, "strokeCap": { "anyOf": [ { "$ref": "#/definitions/StrokeCap", "description": "The stroke cap for line ending style. One of `\"butt\"`, `\"round\"`, or `\"square\"`.\n\n__Default value:__ `\"butt\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDash": { "anyOf": [ { "description": "An array of alternating stroke, space lengths for creating dashed or dotted lines.", "items": { "type": "number" }, "type": "array" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeDashOffset": { "anyOf": [ { "description": "The offset (in pixels) into which to begin drawing with the stroke dash array.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeJoin": { "anyOf": [ { "$ref": "#/definitions/StrokeJoin", "description": "The stroke line join method. 
One of `\"miter\"`, `\"round\"` or `\"bevel\"`.\n\n__Default value:__ `\"miter\"`" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeMiterLimit": { "anyOf": [ { "description": "The miter limit at which to bevel a line join.", "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeOpacity": { "anyOf": [ { "description": "The stroke opacity (value between [0,1]).\n\n__Default value:__ `1`", "maximum": 1, "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] }, "strokeWidth": { "anyOf": [ { "description": "The stroke width, in pixels.", "minimum": 0, "type": "number" }, { "$ref": "#/definitions/ExprRef" } ] } }, "type": "object" }, "WindowEventType": { "anyOf": [ { "$ref": "#/definitions/EventType" }, { "type": "string" } ] }, "WindowFieldDef": { "additionalProperties": false, "properties": { "as": { "$ref": "#/definitions/FieldName", "description": "The output name for the window operation." }, "field": { "$ref": "#/definitions/FieldName", "description": "The data field for which to compute the aggregate or window function. This can be omitted for window functions that do not operate over a field such as `\"count\"`, `\"rank\"`, `\"dense_rank\"`." }, "op": { "anyOf": [ { "$ref": "#/definitions/AggregateOp" }, { "$ref": "#/definitions/WindowOnlyOp" } ], "description": "The window or aggregation operation to apply within a window (e.g., `\"rank\"`, `\"lead\"`, `\"sum\"`, `\"average\"` or `\"count\"`). See the list of all supported operations [here](https://vega.github.io/vega-lite/docs/window.html#ops)." }, "param": { "description": "Parameter values for the window functions. 
Parameter values can be omitted for operations that do not accept a parameter.\n\nSee the list of all supported operations and their parameters [here](https://vega.github.io/vega-lite/docs/transforms/window.html).", "type": "number" } }, "required": [ "op", "as" ], "type": "object" }, "WindowOnlyOp": { "enum": [ "row_number", "rank", "dense_rank", "percent_rank", "cume_dist", "ntile", "lag", "lead", "first_value", "last_value", "nth_value" ], "type": "string" }, "WindowTransform": { "additionalProperties": false, "properties": { "frame": { "description": "A frame specification as a two-element array indicating how the sliding window should proceed. The array entries should either be a number indicating the offset from the current data object, or null to indicate unbounded rows preceding or following the current data object. The default value is `[null, 0]`, indicating that the sliding window includes the current object and all preceding objects. The value `[-5, 5]` indicates that the window should include five objects preceding and five objects following the current object. Finally, `[null, null]` indicates that the window frame should always include all data objects. If you this frame and want to assign the same value to add objects, you can use the simpler [join aggregate transform](https://vega.github.io/vega-lite/docs/joinaggregate.html). The only operators affected are the aggregation operations and the `first_value`, `last_value`, and `nth_value` window operations. The other window operations are not affected by this.\n\n__Default value:__: `[null, 0]` (includes the current object and all preceding objects)", "items": { "type": [ "null", "number" ] }, "type": "array" }, "groupby": { "description": "The data fields for partitioning the data objects into separate windows. 
If unspecified, all data points will be in a single window.", "items": { "$ref": "#/definitions/FieldName" }, "type": "array" }, "ignorePeers": { "description": "Indicates if the sliding window frame should ignore peer values (data that are considered identical by the sort criteria). The default is false, causing the window frame to expand to include all peer values. If set to true, the window frame will be defined by offset values only. This setting only affects those operations that depend on the window frame, namely aggregation operations and the first_value, last_value, and nth_value window operations.\n\n__Default value:__ `false`", "type": "boolean" }, "sort": { "description": "A sort field definition for sorting data objects within a window. If two data objects are considered equal by the comparator, they are considered \"peer\" values of equal rank. If sort is not specified, the order is undefined: data objects are processed in the order they are observed and none are considered peers (the ignorePeers parameter is ignored and treated as if set to `true`).", "items": { "$ref": "#/definitions/SortField" }, "type": "array" }, "window": { "description": "The definition of the fields in the window, and what calculations to use.", "items": { "$ref": "#/definitions/WindowFieldDef" }, "type": "array" } }, "required": [ "window" ], "type": "object" } } } ================================================ FILE: docs/_static/js/vega-embed@5.js ================================================ !function(e,t){"object"==typeof exports&&"undefined"!=typeof module?module.exports=t(require("vega"),require("vega-lite")):"function"==typeof define&&define.amd?define(["vega","vega-lite"],t):(e=e||self).vegaEmbed=t(e.vega,e.vegaLite)}(this,(function(e,t){"use strict";var n="http://www.w3.org/1999/xhtml",r={svg:"http://www.w3.org/2000/svg",xhtml:n,xlink:"http://www.w3.org/1999/xlink",xml:"http://www.w3.org/XML/1998/namespace",xmlns:"http://www.w3.org/2000/xmlns/"};function i(e){var 
t=e+="",n=t.indexOf(":");return n>=0&&"xmlns"!==(t=e.slice(0,n))&&(e=e.slice(n+1)),r.hasOwnProperty(t)?{space:r[t],local:e}:e}function o(e){return function(){var t=this.ownerDocument,r=this.namespaceURI;return r===n&&t.documentElement.namespaceURI===n?t.createElement(e):t.createElementNS(r,e)}}function a(e){return function(){return this.ownerDocument.createElementNS(e.space,e.local)}}function s(e){var t=i(e);return(t.local?a:o)(t)}function l(){}function c(e){return null==e?l:function(){return this.querySelector(e)}}function u(){return[]}function f(e){return new Array(e.length)}function p(e,t){this.ownerDocument=e.ownerDocument,this.namespaceURI=e.namespaceURI,this._next=null,this._parent=e,this.__data__=t}p.prototype={constructor:p,appendChild:function(e){return this._parent.insertBefore(e,this._next)},insertBefore:function(e,t){return this._parent.insertBefore(e,t)},querySelector:function(e){return this._parent.querySelector(e)},querySelectorAll:function(e){return this._parent.querySelectorAll(e)}};var h="$";function d(e,t,n,r,i,o){for(var a,s=0,l=t.length,c=o.length;st?1:e>=t?0:NaN}function v(e){return function(){this.removeAttribute(e)}}function E(e){return function(){this.removeAttributeNS(e.space,e.local)}}function y(e,t){return function(){this.setAttribute(e,t)}}function b(e,t){return function(){this.setAttributeNS(e.space,e.local,t)}}function I(e,t){return function(){var n=t.apply(this,arguments);null==n?this.removeAttribute(e):this.setAttribute(e,n)}}function O(e,t){return function(){var n=t.apply(this,arguments);null==n?this.removeAttributeNS(e.space,e.local):this.setAttributeNS(e.space,e.local,n)}}function R(e){return e.ownerDocument&&e.ownerDocument.defaultView||e.document&&e||e.defaultView}function w(e){return function(){this.style.removeProperty(e)}}function N(e,t,n){return function(){this.style.setProperty(e,t,n)}}function A(e,t,n){return function(){var 
r=t.apply(this,arguments);null==r?this.style.removeProperty(e):this.style.setProperty(e,r,n)}}function S(e){return function(){delete this[e]}}function L(e,t){return function(){this[e]=t}}function T(e,t){return function(){var n=t.apply(this,arguments);null==n?delete this[e]:this[e]=n}}function x(e){return e.trim().split(/^|\s+/)}function _(e){return e.classList||new C(e)}function C(e){this._node=e,this._names=x(e.getAttribute("class")||"")}function P(e,t){for(var n=_(e),r=-1,i=t.length;++r=0&&(this._names.splice(t,1),this._node.setAttribute("class",this._names.join(" ")))},contains:function(e){return this._names.indexOf(e)>=0}};var J={},Z=null;"undefined"!=typeof document&&("onmouseenter"in document.documentElement||(J={mouseenter:"mouseover",mouseleave:"mouseout"}));function K(e,t,n){return e=Q(e,t,n),function(t){var n=t.relatedTarget;n&&(n===this||8&n.compareDocumentPosition(this))||e.call(this,t)}}function Q(e,t,n){return function(r){var i=Z;Z=r;try{e.call(this,this.__data__,t,n)}finally{Z=i}}}function ee(e){return function(){var t=this.__on;if(t){for(var n,r=0,i=-1,o=t.length;r=R&&(R=O+1);!(I=y[R])&&++R=0;)(r=i[o])&&(a&&4^r.compareDocumentPosition(a)&&a.parentNode.insertBefore(r,a),a=r);return this},sort:function(e){function t(t,n){return t&&n?e(t.__data__,n.__data__):!t-!n}e||(e=m);for(var n=this._groups,r=n.length,i=new Array(r),o=0;o1?this.each((null==t?w:"function"==typeof t?A:N)(e,t,null==n?"":n)):function(e,t){return e.style.getPropertyValue(t)||R(e).getComputedStyle(e,null).getPropertyValue(t)}(this.node(),e)},property:function(e,t){return arguments.length>1?this.each((null==t?S:"function"==typeof t?T:L)(e,t)):this.node()[e]},classed:function(e,t){var n=x(e+"");if(arguments.length<2){for(var r=_(this.node()),i=-1,o=n.length;++i=0&&(t=e.slice(n+1),e=e.slice(0,n)),{type:e,name:t}}))}(e+""),a=o.length;if(!(arguments.length<2)){for(s=t?te:ee,null==n&&(n=!1),r=0;r0)return[m,n+c.join(",\n"+d),s].join("\n"+o)}return v}(e,"",0)};function be(e,t){return 
e(t={exports:{}},t.exports),t.exports}var Ie,Oe=be((function(e,t){var n;t=e.exports=p,n="object"==typeof process&&process.env&&process.env.NODE_DEBUG&&/\bsemver\b/i.test(process.env.NODE_DEBUG)?function(){var e=Array.prototype.slice.call(arguments,0);e.unshift("SEMVER"),console.log.apply(console,e)}:function(){},t.SEMVER_SPEC_VERSION="2.0.0";var r=256,i=Number.MAX_SAFE_INTEGER||9007199254740991,o=t.re=[],a=t.src=[],s=t.tokens={},l=0;function c(e){s[e]=l++}c("NUMERICIDENTIFIER"),a[s.NUMERICIDENTIFIER]="0|[1-9]\\d*",c("NUMERICIDENTIFIERLOOSE"),a[s.NUMERICIDENTIFIERLOOSE]="[0-9]+",c("NONNUMERICIDENTIFIER"),a[s.NONNUMERICIDENTIFIER]="\\d*[a-zA-Z-][a-zA-Z0-9-]*",c("MAINVERSION"),a[s.MAINVERSION]="("+a[s.NUMERICIDENTIFIER]+")\\.("+a[s.NUMERICIDENTIFIER]+")\\.("+a[s.NUMERICIDENTIFIER]+")",c("MAINVERSIONLOOSE"),a[s.MAINVERSIONLOOSE]="("+a[s.NUMERICIDENTIFIERLOOSE]+")\\.("+a[s.NUMERICIDENTIFIERLOOSE]+")\\.("+a[s.NUMERICIDENTIFIERLOOSE]+")",c("PRERELEASEIDENTIFIER"),a[s.PRERELEASEIDENTIFIER]="(?:"+a[s.NUMERICIDENTIFIER]+"|"+a[s.NONNUMERICIDENTIFIER]+")",c("PRERELEASEIDENTIFIERLOOSE"),a[s.PRERELEASEIDENTIFIERLOOSE]="(?:"+a[s.NUMERICIDENTIFIERLOOSE]+"|"+a[s.NONNUMERICIDENTIFIER]+")",c("PRERELEASE"),a[s.PRERELEASE]="(?:-("+a[s.PRERELEASEIDENTIFIER]+"(?:\\."+a[s.PRERELEASEIDENTIFIER]+")*))",c("PRERELEASELOOSE"),a[s.PRERELEASELOOSE]="(?:-?("+a[s.PRERELEASEIDENTIFIERLOOSE]+"(?:\\."+a[s.PRERELEASEIDENTIFIERLOOSE]+")*))",c("BUILDIDENTIFIER"),a[s.BUILDIDENTIFIER]="[0-9A-Za-z-]+",c("BUILD"),a[s.BUILD]="(?:\\+("+a[s.BUILDIDENTIFIER]+"(?:\\."+a[s.BUILDIDENTIFIER]+")*))",c("FULL"),c("FULLPLAIN"),a[s.FULLPLAIN]="v?"+a[s.MAINVERSION]+a[s.PRERELEASE]+"?"+a[s.BUILD]+"?",a[s.FULL]="^"+a[s.FULLPLAIN]+"$",c("LOOSEPLAIN"),a[s.LOOSEPLAIN]="[v=\\s]*"+a[s.MAINVERSIONLOOSE]+a[s.PRERELEASELOOSE]+"?"+a[s.BUILD]+"?",c("LOOSE"),a[s.LOOSE]="^"+a[s.LOOSEPLAIN]+"$",c("GTLT"),a[s.GTLT]="((?:<|>)?=?)",c("XRANGEIDENTIFIERLOOSE"),a[s.XRANGEIDENTIFIERLOOSE]=a[s.NUMERICIDENTIFIERLOOSE]+"|x|X|\\*",c("XRANGEIDENTIF
IER"),a[s.XRANGEIDENTIFIER]=a[s.NUMERICIDENTIFIER]+"|x|X|\\*",c("XRANGEPLAIN"),a[s.XRANGEPLAIN]="[v=\\s]*("+a[s.XRANGEIDENTIFIER]+")(?:\\.("+a[s.XRANGEIDENTIFIER]+")(?:\\.("+a[s.XRANGEIDENTIFIER]+")(?:"+a[s.PRERELEASE]+")?"+a[s.BUILD]+"?)?)?",c("XRANGEPLAINLOOSE"),a[s.XRANGEPLAINLOOSE]="[v=\\s]*("+a[s.XRANGEIDENTIFIERLOOSE]+")(?:\\.("+a[s.XRANGEIDENTIFIERLOOSE]+")(?:\\.("+a[s.XRANGEIDENTIFIERLOOSE]+")(?:"+a[s.PRERELEASELOOSE]+")?"+a[s.BUILD]+"?)?)?",c("XRANGE"),a[s.XRANGE]="^"+a[s.GTLT]+"\\s*"+a[s.XRANGEPLAIN]+"$",c("XRANGELOOSE"),a[s.XRANGELOOSE]="^"+a[s.GTLT]+"\\s*"+a[s.XRANGEPLAINLOOSE]+"$",c("COERCE"),a[s.COERCE]="(^|[^\\d])(\\d{1,16})(?:\\.(\\d{1,16}))?(?:\\.(\\d{1,16}))?(?:$|[^\\d])",c("COERCERTL"),o[s.COERCERTL]=new RegExp(a[s.COERCE],"g"),c("LONETILDE"),a[s.LONETILDE]="(?:~>?)",c("TILDETRIM"),a[s.TILDETRIM]="(\\s*)"+a[s.LONETILDE]+"\\s+",o[s.TILDETRIM]=new RegExp(a[s.TILDETRIM],"g");c("TILDE"),a[s.TILDE]="^"+a[s.LONETILDE]+a[s.XRANGEPLAIN]+"$",c("TILDELOOSE"),a[s.TILDELOOSE]="^"+a[s.LONETILDE]+a[s.XRANGEPLAINLOOSE]+"$",c("LONECARET"),a[s.LONECARET]="(?:\\^)",c("CARETTRIM"),a[s.CARETTRIM]="(\\s*)"+a[s.LONECARET]+"\\s+",o[s.CARETTRIM]=new RegExp(a[s.CARETTRIM],"g");c("CARET"),a[s.CARET]="^"+a[s.LONECARET]+a[s.XRANGEPLAIN]+"$",c("CARETLOOSE"),a[s.CARETLOOSE]="^"+a[s.LONECARET]+a[s.XRANGEPLAINLOOSE]+"$",c("COMPARATORLOOSE"),a[s.COMPARATORLOOSE]="^"+a[s.GTLT]+"\\s*("+a[s.LOOSEPLAIN]+")$|^$",c("COMPARATOR"),a[s.COMPARATOR]="^"+a[s.GTLT]+"\\s*("+a[s.FULLPLAIN]+")$|^$",c("COMPARATORTRIM"),a[s.COMPARATORTRIM]="(\\s*)"+a[s.GTLT]+"\\s*("+a[s.LOOSEPLAIN]+"|"+a[s.XRANGEPLAIN]+")",o[s.COMPARATORTRIM]=new RegExp(a[s.COMPARATORTRIM],"g");c("HYPHENRANGE"),a[s.HYPHENRANGE]="^\\s*("+a[s.XRANGEPLAIN]+")\\s+-\\s+("+a[s.XRANGEPLAIN]+")\\s*$",c("HYPHENRANGELOOSE"),a[s.HYPHENRANGELOOSE]="^\\s*("+a[s.XRANGEPLAINLOOSE]+")\\s+-\\s+("+a[s.XRANGEPLAINLOOSE]+")\\s*$",c("STAR"),a[s.STAR]="(<|>)?=?\\s*\\*";for(var u=0;ur)return null;if(!(t.loose?o[s.LOOSE]:o[s.FULL]).test(e))return 
null;try{return new p(e,t)}catch(e){return null}}function p(e,t){if(t&&"object"==typeof t||(t={loose:!!t,includePrerelease:!1}),e instanceof p){if(e.loose===t.loose)return e;e=e.version}else if("string"!=typeof e)throw new TypeError("Invalid Version: "+e);if(e.length>r)throw new TypeError("version is longer than "+r+" characters");if(!(this instanceof p))return new p(e,t);n("SemVer",e,t),this.options=t,this.loose=!!t.loose;var a=e.trim().match(t.loose?o[s.LOOSE]:o[s.FULL]);if(!a)throw new TypeError("Invalid Version: "+e);if(this.raw=e,this.major=+a[1],this.minor=+a[2],this.patch=+a[3],this.major>i||this.major<0)throw new TypeError("Invalid major version");if(this.minor>i||this.minor<0)throw new TypeError("Invalid minor version");if(this.patch>i||this.patch<0)throw new TypeError("Invalid patch version");a[4]?this.prerelease=a[4].split(".").map((function(e){if(/^[0-9]+$/.test(e)){var t=+e;if(t>=0&&t=0;)"number"==typeof this.prerelease[n]&&(this.prerelease[n]++,n=-2);-1===n&&this.prerelease.push(0)}t&&(this.prerelease[0]===t?isNaN(this.prerelease[1])&&(this.prerelease=[t,0]):this.prerelease=[t,0]);break;default:throw new Error("invalid increment argument: "+e)}return this.format(),this.raw=this.version,this},t.inc=function(e,t,n,r){"string"==typeof n&&(r=n,n=void 0);try{return new p(e,n).inc(t,r).version}catch(e){return null}},t.diff=function(e,t){if(E(e,t))return null;var n=f(e),r=f(t),i="";if(n.prerelease.length||r.prerelease.length){i="pre";var o="prerelease"}for(var a in n)if(("major"===a||"minor"===a||"patch"===a)&&n[a]!==r[a])return i+a;return o},t.compareIdentifiers=d;var h=/^[0-9]+$/;function d(e,t){var n=h.test(e),r=h.test(t);return n&&r&&(e=+e,t=+t),e===t?0:n&&!r?-1:r&&!n?1:e0}function v(e,t,n){return g(e,t,n)<0}function E(e,t,n){return 0===g(e,t,n)}function y(e,t,n){return 0!==g(e,t,n)}function b(e,t,n){return g(e,t,n)>=0}function I(e,t,n){return g(e,t,n)<=0}function O(e,t,n,r){switch(t){case"===":return"object"==typeof e&&(e=e.version),"object"==typeof 
n&&(n=n.version),e===n;case"!==":return"object"==typeof e&&(e=e.version),"object"==typeof n&&(n=n.version),e!==n;case"":case"=":case"==":return E(e,n,r);case"!=":return y(e,n,r);case">":return m(e,n,r);case">=":return b(e,n,r);case"<":return v(e,n,r);case"<=":return I(e,n,r);default:throw new TypeError("Invalid operator: "+t)}}function R(e,t){if(t&&"object"==typeof t||(t={loose:!!t,includePrerelease:!1}),e instanceof R){if(e.loose===!!t.loose)return e;e=e.value}if(!(this instanceof R))return new R(e,t);n("comparator",e,t),this.options=t,this.loose=!!t.loose,this.parse(e),this.semver===w?this.value="":this.value=this.operator+this.semver.version,n("comp",this)}t.rcompareIdentifiers=function(e,t){return d(t,e)},t.major=function(e,t){return new p(e,t).major},t.minor=function(e,t){return new p(e,t).minor},t.patch=function(e,t){return new p(e,t).patch},t.compare=g,t.compareLoose=function(e,t){return g(e,t,!0)},t.compareBuild=function(e,t,n){var r=new p(e,n),i=new p(t,n);return r.compare(i)||r.compareBuild(i)},t.rcompare=function(e,t,n){return g(t,e,n)},t.sort=function(e,n){return e.sort((function(e,r){return t.compareBuild(e,r,n)}))},t.rsort=function(e,n){return e.sort((function(e,r){return t.compareBuild(r,e,n)}))},t.gt=m,t.lt=v,t.eq=E,t.neq=y,t.gte=b,t.lte=I,t.cmp=O,t.Comparator=R;var w={};function N(e,t){if(t&&"object"==typeof t||(t={loose:!!t,includePrerelease:!1}),e instanceof N)return e.loose===!!t.loose&&e.includePrerelease===!!t.includePrerelease?e:new N(e.raw,t);if(e instanceof R)return new N(e.value,t);if(!(this instanceof N))return new N(e,t);if(this.options=t,this.loose=!!t.loose,this.includePrerelease=!!t.includePrerelease,this.raw=e,this.set=e.split(/\s*\|\|\s*/).map((function(e){return this.parseRange(e.trim())}),this).filter((function(e){return e.length})),!this.set.length)throw new TypeError("Invalid SemVer Range: "+e);this.format()}function A(e,t){for(var n=!0,r=e.slice(),i=r.pop();n&&r.length;)n=r.every((function(e){return 
i.intersects(e,t)})),i=r.pop();return n}function S(e){return!e||"x"===e.toLowerCase()||"*"===e}function L(e,t,n,r,i,o,a,s,l,c,u,f,p){return((t=S(n)?"":S(r)?">="+n+".0.0":S(i)?">="+n+"."+r+".0":">="+t)+" "+(s=S(l)?"":S(c)?"<"+(+l+1)+".0.0":S(u)?"<"+l+"."+(+c+1)+".0":f?"<="+l+"."+c+"."+u+"-"+f:"<="+s)).trim()}function T(e,t,r){for(var i=0;i0){var o=e[i].semver;if(o.major===t.major&&o.minor===t.minor&&o.patch===t.patch)return!0}return!1}return!0}function x(e,t,n){try{t=new N(t,n)}catch(e){return!1}return t.test(e)}function _(e,t,n,r){var i,o,a,s,l;switch(e=new p(e,r),t=new N(t,r),n){case">":i=m,o=I,a=v,s=">",l=">=";break;case"<":i=v,o=b,a=m,s="<",l="<=";break;default:throw new TypeError('Must provide a hilo val of "<" or ">"')}if(x(e,t,r))return!1;for(var c=0;c=0.0.0")),f=f||e,h=h||e,i(e.semver,f.semver,r)?f=e:a(e.semver,h.semver,r)&&(h=e)})),f.operator===s||f.operator===l)return!1;if((!h.operator||h.operator===s)&&o(e,h.semver))return!1;if(h.operator===l&&a(e,h.semver))return!1}return!0}R.prototype.parse=function(e){var t=this.options.loose?o[s.COMPARATORLOOSE]:o[s.COMPARATOR],n=e.match(t);if(!n)throw new TypeError("Invalid comparator: "+e);this.operator=void 0!==n[1]?n[1]:"","="===this.operator&&(this.operator=""),n[2]?this.semver=new p(n[2],this.options.loose):this.semver=w},R.prototype.toString=function(){return this.value},R.prototype.test=function(e){if(n("Comparator.test",e,this.options.loose),this.semver===w||e===w)return!0;if("string"==typeof e)try{e=new p(e,this.options)}catch(e){return!1}return O(e,this.operator,this.semver,this.options)},R.prototype.intersects=function(e,t){if(!(e instanceof R))throw new TypeError("a Comparator is required");var n;if(t&&"object"==typeof t||(t={loose:!!t,includePrerelease:!1}),""===this.operator)return""===this.value||(n=new N(e.value,t),x(this.value,n,t));if(""===e.operator)return""===e.value||(n=new N(this.value,t),x(e.semver,n,t));var 
r=!(">="!==this.operator&&">"!==this.operator||">="!==e.operator&&">"!==e.operator),i=!("<="!==this.operator&&"<"!==this.operator||"<="!==e.operator&&"<"!==e.operator),o=this.semver.version===e.semver.version,a=!(">="!==this.operator&&"<="!==this.operator||">="!==e.operator&&"<="!==e.operator),s=O(this.semver,"<",e.semver,t)&&(">="===this.operator||">"===this.operator)&&("<="===e.operator||"<"===e.operator),l=O(this.semver,">",e.semver,t)&&("<="===this.operator||"<"===this.operator)&&(">="===e.operator||">"===e.operator);return r||i||o&&a||s||l},t.Range=N,N.prototype.format=function(){return this.range=this.set.map((function(e){return e.join(" ").trim()})).join("||").trim(),this.range},N.prototype.toString=function(){return this.range},N.prototype.parseRange=function(e){var t=this.options.loose;e=e.trim();var r=t?o[s.HYPHENRANGELOOSE]:o[s.HYPHENRANGE];e=e.replace(r,L),n("hyphen replace",e),e=e.replace(o[s.COMPARATORTRIM],"$1$2$3"),n("comparator trim",e,o[s.COMPARATORTRIM]),e=(e=(e=e.replace(o[s.TILDETRIM],"$1~")).replace(o[s.CARETTRIM],"$1^")).split(/\s+/).join(" ");var i=t?o[s.COMPARATORLOOSE]:o[s.COMPARATOR],a=e.split(" ").map((function(e){return function(e,t){return n("comp",e,t),e=function(e,t){return e.trim().split(/\s+/).map((function(e){return function(e,t){n("caret",e,t);var r=t.loose?o[s.CARETLOOSE]:o[s.CARET];return e.replace(r,(function(t,r,i,o,a){var s;return n("caret",e,t,r,i,o,a),S(r)?s="":S(i)?s=">="+r+".0.0 <"+(+r+1)+".0.0":S(o)?s="0"===r?">="+r+"."+i+".0 <"+r+"."+(+i+1)+".0":">="+r+"."+i+".0 <"+(+r+1)+".0.0":a?(n("replaceCaret pr",a),s="0"===r?"0"===i?">="+r+"."+i+"."+o+"-"+a+" <"+r+"."+i+"."+(+o+1):">="+r+"."+i+"."+o+"-"+a+" <"+r+"."+(+i+1)+".0":">="+r+"."+i+"."+o+"-"+a+" <"+(+r+1)+".0.0"):(n("no pr"),s="0"===r?"0"===i?">="+r+"."+i+"."+o+" <"+r+"."+i+"."+(+o+1):">="+r+"."+i+"."+o+" <"+r+"."+(+i+1)+".0":">="+r+"."+i+"."+o+" <"+(+r+1)+".0.0"),n("caret return",s),s}))}(e,t)})).join(" ")}(e,t),n("caret",e),e=function(e,t){return 
e.trim().split(/\s+/).map((function(e){return function(e,t){var r=t.loose?o[s.TILDELOOSE]:o[s.TILDE];return e.replace(r,(function(t,r,i,o,a){var s;return n("tilde",e,t,r,i,o,a),S(r)?s="":S(i)?s=">="+r+".0.0 <"+(+r+1)+".0.0":S(o)?s=">="+r+"."+i+".0 <"+r+"."+(+i+1)+".0":a?(n("replaceTilde pr",a),s=">="+r+"."+i+"."+o+"-"+a+" <"+r+"."+(+i+1)+".0"):s=">="+r+"."+i+"."+o+" <"+r+"."+(+i+1)+".0",n("tilde return",s),s}))}(e,t)})).join(" ")}(e,t),n("tildes",e),e=function(e,t){return n("replaceXRanges",e,t),e.split(/\s+/).map((function(e){return function(e,t){e=e.trim();var r=t.loose?o[s.XRANGELOOSE]:o[s.XRANGE];return e.replace(r,(function(r,i,o,a,s,l){n("xRange",e,r,i,o,a,s,l);var c=S(o),u=c||S(a),f=u||S(s),p=f;return"="===i&&p&&(i=""),l=t.includePrerelease?"-0":"",c?r=">"===i||"<"===i?"<0.0.0-0":"*":i&&p?(u&&(a=0),s=0,">"===i?(i=">=",u?(o=+o+1,a=0,s=0):(a=+a+1,s=0)):"<="===i&&(i="<",u?o=+o+1:a=+a+1),r=i+o+"."+a+"."+s+l):u?r=">="+o+".0.0"+l+" <"+(+o+1)+".0.0"+l:f&&(r=">="+o+"."+a+".0"+l+" <"+o+"."+(+a+1)+".0"+l),n("xRange return",r),r}))}(e,t)})).join(" ")}(e,t),n("xrange",e),e=function(e,t){return n("replaceStars",e,t),e.trim().replace(o[s.STAR],"")}(e,t),n("stars",e),e}(e,this.options)}),this).join(" ").split(/\s+/);return this.options.loose&&(a=a.filter((function(e){return!!e.match(i)}))),a=a.map((function(e){return new R(e,this.options)}),this)},N.prototype.intersects=function(e,t){if(!(e instanceof N))throw new TypeError("a Range is required");return this.set.some((function(n){return A(n,t)&&e.set.some((function(e){return A(e,t)&&n.every((function(n){return e.every((function(e){return n.intersects(e,t)}))}))}))}))},t.toComparators=function(e,t){return new N(e,t).set.map((function(e){return e.map((function(e){return e.value})).join(" ").trim().split(" ")}))},N.prototype.test=function(e){if(!e)return!1;if("string"==typeof e)try{e=new p(e,this.options)}catch(e){return!1}for(var 
t=0;t":0===t.prerelease.length?t.patch++:t.prerelease.push(0),t.raw=t.format();case"":case">=":n&&!m(n,t)||(n=t);break;case"<":case"<=":break;default:throw new Error("Unexpected operation: "+e.operator)}}))}if(n&&e.test(n))return n;return null},t.validRange=function(e,t){try{return new N(e,t).range||"*"}catch(e){return null}},t.ltr=function(e,t,n){return _(e,t,"<",n)},t.gtr=function(e,t,n){return _(e,t,">",n)},t.outside=_,t.prerelease=function(e,t){var n=f(e,t);return n&&n.prerelease.length?n.prerelease:null},t.intersects=function(e,t,n){return e=new N(e,n),t=new N(t,n),e.intersects(t)},t.coerce=function(e,t){if(e instanceof p)return e;"number"==typeof e&&(e=String(e));if("string"!=typeof e)return null;var n=null;if((t=t||{}).rtl){for(var r;(r=o[s.COERCERTL].exec(e))&&(!n||n.index+n[0].length!==e.length);)n&&r.index+r[0].length===n.index+n[0].length||(n=r),o[s.COERCERTL].lastIndex=r.index+r[1].length+r[2].length;o[s.COERCERTL].lastIndex=-1}else n=e.match(o[s.COERCE]);if(null===n)return null;return f(n[2]+"."+(n[3]||"0")+"."+(n[4]||"0"),t)}})),Re=(Oe.SEMVER_SPEC_VERSION,Oe.re,Oe.src,Oe.tokens,Oe.parse,Oe.valid,Oe.clean,Oe.SemVer,Oe.inc,Oe.diff,Oe.compareIdentifiers,Oe.rcompareIdentifiers,Oe.major,Oe.minor,Oe.patch,Oe.compare,Oe.compareLoose,Oe.compareBuild,Oe.rcompare,Oe.sort,Oe.rsort,Oe.gt,Oe.lt,Oe.eq,Oe.neq,Oe.gte,Oe.lte,Oe.cmp,Oe.Comparator,Oe.Range,Oe.toComparators,Oe.satisfies),we=(Oe.maxSatisfying,Oe.minSatisfying,Oe.minVersion,Oe.validRange,Oe.ltr,Oe.gtr,Oe.outside,Oe.prerelease,Oe.intersects,Oe.coerce,be((function(e,t){Object.defineProperty(t,"__esModule",{value:!0}),t.default=function(e){var t=/\/schema\/([\w-]+)\/([\w\.\-]+)\.json$/g.exec(e).slice(1,3);return{library:t[0],version:t[1]}}}))),Ne=(Ie=we)&&Ie.__esModule&&Object.prototype.hasOwnProperty.call(Ie,"default")?Ie.default:Ie;const 
Ae={background:"#333",title:{color:"#fff"},style:{"guide-label":{fill:"#fff"},"guide-title":{fill:"#fff"}},axis:{domainColor:"#fff",gridColor:"#888",tickColor:"#fff"}},Se={background:"#fff",arc:{fill:"#4572a7"},area:{fill:"#4572a7"},line:{stroke:"#4572a7",strokeWidth:2},path:{stroke:"#4572a7"},rect:{fill:"#4572a7"},shape:{stroke:"#4572a7"},symbol:{fill:"#4572a7",strokeWidth:1.5,size:50},axis:{bandPosition:.5,grid:!0,gridColor:"#000000",gridOpacity:1,gridWidth:.5,labelPadding:10,tickSize:5,tickWidth:.5},axisBand:{grid:!1,tickExtra:!0},legend:{labelBaseline:"middle",labelFontSize:11,symbolSize:50,symbolType:"square"},range:{category:["#4572a7","#aa4643","#8aa453","#71598e","#4598ae","#d98445","#94aace","#d09393","#b9cc98","#a99cbc"]}},Le={arc:{fill:"#30a2da"},area:{fill:"#30a2da"},axis:{domainColor:"#cbcbcb",grid:!0,gridColor:"#cbcbcb",gridWidth:1,labelColor:"#999",labelFontSize:10,titleColor:"#333",tickColor:"#cbcbcb",tickSize:10,titleFontSize:14,titlePadding:10,labelPadding:4},axisBand:{grid:!1},background:"#f0f0f0",group:{fill:"#f0f0f0"},legend:{labelColor:"#333",labelFontSize:11,padding:1,symbolSize:30,symbolType:"square",titleColor:"#333",titleFontSize:14,titlePadding:10},line:{stroke:"#30a2da",strokeWidth:2},path:{stroke:"#30a2da",strokeWidth:.5},rect:{fill:"#30a2da"},range:{category:["#30a2da","#fc4f30","#e5ae38","#6d904f","#8b8b8b","#b96db8","#ff9e27","#56cc60","#52d2ca","#52689e","#545454","#9fe4f8"],diverging:["#cc0020","#e77866","#f6e7e1","#d6e8ed","#91bfd9","#1d78b5"],heatmap:["#d6e8ed","#cee0e5","#91bfd9","#549cc6","#1d78b5"]},point:{filled:!0,shape:"circle"},shape:{stroke:"#30a2da"},style:{bar:{binSpacing:2,fill:"#30a2da",stroke:null}},title:{anchor:"start",fontSize:24,fontWeight:600,offset:20}},Te={group:{fill:"#e5e5e5"},arc:{fill:"#000"},area:{fill:"#000"},line:{stroke:"#000"},path:{stroke:"#000"},rect:{fill:"#000"},shape:{stroke:"#000"},symbol:{fill:"#000",size:40},axis:{domain:!1,grid:!0,gridColor:"#FFFFFF",gridOpacity:1,labelColor:"#7F7F7F",labelPad
ding:4,tickColor:"#7F7F7F",tickSize:5.67,titleFontSize:16,titleFontWeight:"normal"},legend:{labelBaseline:"middle",labelFontSize:11,symbolSize:40},range:{category:["#000000","#7F7F7F","#1A1A1A","#999999","#333333","#B0B0B0","#4D4D4D","#C9C9C9","#666666","#DCDCDC"]}},xe="Benton Gothic Bold, sans-serif",_e={"category-6":["#ec8431","#829eb1","#c89d29","#3580b1","#adc839","#ab7fb4"],"fire-7":["#fbf2c7","#f9e39c","#f8d36e","#f4bb6a","#e68a4f","#d15a40","#ab4232"],"fireandice-6":["#e68a4f","#f4bb6a","#f9e39c","#dadfe2","#a6b7c6","#849eae"],"ice-7":["#edefee","#dadfe2","#c4ccd2","#a6b7c6","#849eae","#607785","#47525d"]},Ce={background:"#ffffff",title:{anchor:"start",color:"#000000",font:xe,fontSize:22,fontWeight:"normal"},arc:{fill:"#82c6df"},area:{fill:"#82c6df"},line:{stroke:"#82c6df",strokeWidth:2},path:{stroke:"#82c6df"},rect:{fill:"#82c6df"},shape:{stroke:"#82c6df"},symbol:{fill:"#82c6df",size:30},axis:{labelFont:"Benton Gothic, sans-serif",labelFontSize:11.5,labelFontWeight:"normal",titleFont:xe,titleFontSize:13,titleFontWeight:"normal"},axisX:{labelAngle:0,labelPadding:4,tickSize:3},axisY:{labelBaseline:"middle",maxExtent:45,minExtent:45,tickSize:2,titleAlign:"left",titleAngle:0,titleX:-45,titleY:-11},legend:{labelFont:"Benton Gothic, 
sans-serif",labelFontSize:11.5,symbolType:"square",titleFont:xe,titleFontSize:13,titleFontWeight:"normal"},range:{category:_e["category-6"],diverging:_e["fireandice-6"],heatmap:_e["fire-7"],ordinal:_e["fire-7"],ramp:_e["fire-7"]}},Pe={background:"#f9f9f9",arc:{fill:"#ab5787"},area:{fill:"#ab5787"},line:{stroke:"#ab5787"},path:{stroke:"#ab5787"},rect:{fill:"#ab5787"},shape:{stroke:"#ab5787"},symbol:{fill:"#ab5787",size:30},axis:{domainColor:"#979797",domainWidth:.5,gridWidth:.2,labelColor:"#979797",tickColor:"#979797",tickWidth:.2,titleColor:"#979797"},axisBand:{grid:!1},axisX:{grid:!0,tickSize:10},axisY:{domain:!1,grid:!0,tickSize:0},legend:{labelFontSize:11,padding:1,symbolSize:30,symbolType:"square"},range:{category:["#ab5787","#51b2e5","#703c5c","#168dd9","#d190b6","#00609f","#d365ba","#154866","#666666","#c4c4c4"]}},De={background:"#fff",arc:{fill:"#3e5c69"},area:{fill:"#3e5c69"},line:{stroke:"#3e5c69"},path:{stroke:"#3e5c69"},rect:{fill:"#3e5c69"},shape:{stroke:"#3e5c69"},symbol:{fill:"#3e5c69"},axis:{domainWidth:.5,grid:!0,labelPadding:2,tickSize:5,tickWidth:.5,titleFontWeight:"normal"},axisBand:{grid:!1},axisX:{gridWidth:.2},axisY:{gridDash:[3],gridWidth:.4},legend:{labelFontSize:11,padding:1,symbolType:"square"},range:{category:["#3e5c69","#6793a6","#182429","#0570b0","#3690c0","#74a9cf","#a6bddb","#e2ddf2"]}},ke="2.4.0";var Fe=Object.freeze({version:ke,dark:Ae,excel:Se,fivethirtyeight:Le,ggplot2:Te,latimes:Ce,quartz:Pe,vox:De}),Me="#vg-tooltip-element {\n visibility: hidden;\n padding: 8px;\n position: fixed;\n z-index: 1000;\n font-family: sans-serif;\n font-size: 11px;\n border-radius: 3px;\n box-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);\n /* The default theme is the light theme. 
*/\n background-color: rgba(255, 255, 255, 0.95);\n border: 1px solid #d9d9d9;\n color: black; }\n #vg-tooltip-element.visible {\n visibility: visible; }\n #vg-tooltip-element h2 {\n margin-top: 0;\n margin-bottom: 10px;\n font-size: 13px; }\n #vg-tooltip-element table {\n border-spacing: 0; }\n #vg-tooltip-element table tr {\n border: none; }\n #vg-tooltip-element table tr td {\n overflow: hidden;\n text-overflow: ellipsis;\n padding-top: 2px;\n padding-bottom: 2px; }\n #vg-tooltip-element table tr td.key {\n color: #808080;\n max-width: 150px;\n text-align: right;\n padding-right: 4px; }\n #vg-tooltip-element table tr td.value {\n display: block;\n max-width: 300px;\n max-height: 7em;\n text-align: left; }\n #vg-tooltip-element.dark-theme {\n background-color: rgba(32, 32, 32, 0.9);\n border: 1px solid #f5f5f5;\n color: white; }\n #vg-tooltip-element.dark-theme td.key {\n color: #bfbfbf; }\n";const je="vg-tooltip-element",$e={offsetX:10,offsetY:10,id:je,styleId:"vega-tooltip-style",theme:"light",disableDefaultStyle:!1,sanitize:function(e){return String(e).replace(/&/g,"&").replace(/t&&c(),a=t=n+1):"]"===r&&(a||ze("Access path missing open bracket: "+e),a>0&&c(),a=0,t=n+1):n>t?c():t=n+1}return a&&ze("Access path missing closing bracket: "+e),o&&ze("Access path missing closing quote: "+e),n>t&&(n++,c()),i}(e),r="return _["+n.map(Ve).join("][")+"];";Ge(Function("_",r),[e=1===n.length?n[0]:e],t||e)}("id"),Ge((function(e){return e}),We,"identity"),Ge((function(){return 0}),We,"zero"),Ge((function(){return 1}),We,"one"),Ge((function(){return!0}),We,"true"),Ge((function(){return!1}),We,"false"),function(e,t){var n={};for(var r in e)Object.prototype.hasOwnProperty.call(e,r)&&t.indexOf(r)<0&&(n[r]=e[r]);if(null!=e&&"function"==typeof Object.getOwnPropertySymbols){var i=0;for(r=Object.getOwnPropertySymbols(e);ie?"[Object]":t.indexOf(r)>=0?"[Circular]":(t.push(r),r)}}(t))}class Ye{constructor(e){this.options=Object.assign(Object.assign({},$e),e);const 
t=this.options.id;if(this.call=this.tooltipHandler.bind(this),!this.options.disableDefaultStyle&&!document.getElementById(this.options.styleId)){const e=document.createElement("style");e.setAttribute("id",this.options.styleId),e.innerHTML=function(e){if(!/^[A-Za-z]+[-:.\w]*$/.test(e))throw new Error("Invalid HTML ID");return Me.toString().replace(je,e)}(t);const n=document.head;n.childNodes.length>0?n.insertBefore(e,n.childNodes[0]):n.appendChild(e)}this.el=document.getElementById(t),this.el||(this.el=document.createElement("div"),this.el.setAttribute("id",t),this.el.classList.add("vg-tooltip"),document.body.appendChild(this.el))}tooltipHandler(e,t,n,r){if(null==r||""===r)return void this.el.classList.remove("visible",`${this.options.theme}-theme`);this.el.innerHTML=function(e,t,n){if(Ue(e))return`[${e.map(e=>t(Xe(e)?e:qe(e,n))).join(", ")}]`;if(Be(e)){let r="";const i=e,{title:o}=i,a=He(i,["title"]);o&&(r+=`

${t(o)}

`);const s=Object.keys(a);if(s.length>0){r+="";for(const e of s){let i=a[e];void 0!==i&&(Be(i)&&(i=qe(i,n)),r+=``)}r+="
${t(e)}:${t(i)}
"}return r||"{}"}return t(e)}(r,this.options.sanitize,this.options.maxDepth),this.el.classList.add("visible",`${this.options.theme}-theme`);const{x:i,y:o}=function(e,t,n,r){let i=e.clientX+n;i+t.width>window.innerWidth&&(i=+e.clientX-n-t.width);let o=e.clientY+r;return o+t.height>window.innerHeight&&(o=+e.clientY-r-t.height),{x:i,y:o}}(t,this.el.getBoundingClientRect(),this.options.offsetX,this.options.offsetY);this.el.setAttribute("style",`top: ${o}px; left: ${i}px`)}}var Je='.vega-embed {\n position: relative;\n display: inline-block;\n padding-right: 38px; }\n .vega-embed details:not([open]) > :not(summary) {\n display: none !important; }\n .vega-embed summary {\n list-style: none;\n position: absolute;\n top: 0;\n right: 0;\n padding: 6px;\n z-index: 1000;\n background: white;\n box-shadow: 1px 1px 3px rgba(0, 0, 0, 0.1);\n color: #1b1e23;\n border: 1px solid #aaa;\n border-radius: 999px;\n opacity: 0.2;\n transition: opacity 0.4s ease-in;\n outline: none;\n cursor: pointer;\n line-height: 0px; }\n .vega-embed summary::-webkit-details-marker {\n display: none; }\n .vega-embed summary:active {\n box-shadow: #aaa 0px 0px 0px 1px inset; }\n .vega-embed summary svg {\n width: 14px;\n height: 14px; }\n .vega-embed details[open] summary {\n opacity: 0.7; }\n .vega-embed:hover summary,\n .vega-embed:focus summary {\n opacity: 1 !important;\n transition: opacity 0.2s ease; }\n .vega-embed .vega-actions {\n position: absolute;\n top: 35px;\n right: -9px;\n display: flex;\n flex-direction: column;\n padding-bottom: 8px;\n padding-top: 8px;\n border-radius: 4px;\n box-shadow: 0 2px 8px 0 rgba(0, 0, 0, 0.2);\n border: 1px solid #d9d9d9;\n background: white;\n animation-duration: 0.15s;\n animation-name: scale-in;\n animation-timing-function: cubic-bezier(0.2, 0, 0.13, 1.5); }\n .vega-embed .vega-actions a {\n padding: 8px 16px;\n font-family: sans-serif;\n font-size: 14px;\n font-weight: 600;\n white-space: nowrap;\n color: #434a56;\n text-decoration: none; }\n .vega-embed 
.vega-actions a:hover {\n background-color: #f7f7f9;\n color: black; }\n .vega-embed .vega-actions::before, .vega-embed .vega-actions::after {\n content: "";\n display: inline-block;\n position: absolute; }\n .vega-embed .vega-actions::before {\n left: auto;\n right: 14px;\n top: -16px;\n border: 8px solid #0000;\n border-bottom-color: #d9d9d9; }\n .vega-embed .vega-actions::after {\n left: auto;\n right: 15px;\n top: -14px;\n border: 7px solid #0000;\n border-bottom-color: #fff; }\n\n.vega-embed-wrapper {\n max-width: 100%;\n overflow: scroll;\n padding-right: 14px; }\n\n@keyframes scale-in {\n from {\n opacity: 0;\n transform: scale(0.6); }\n to {\n opacity: 1;\n transform: scale(1); } }\n';const Ze=e;let Ke=t;const Qe=window;void 0===Ke&&Qe.vl&&Qe.vl.compile&&(Ke=Qe.vl);const et={CLICK_TO_VIEW_ACTIONS:"Click to view actions",COMPILED_ACTION:"View Compiled Vega",EDITOR_ACTION:"Open in Vega Editor",PNG_ACTION:"Save as PNG",SOURCE_ACTION:"View Source",SVG_ACTION:"Save as SVG"},tt={vega:"Vega","vega-lite":"Vega-Lite"},nt={vega:Ze.version,"vega-lite":Ke?Ke.version:"not available"},rt={vega:e=>e,"vega-lite":(e,t)=>Ke.compile(e,{config:t}).spec},it='\n\n \n \n \n';function ot(e,t,n,r){const i=`${t}
`,o=`
${n}`,a=window.open("");a.document.write(i+e+o),a.document.title=`${tt[r]} JSON Source`}function at(t,n,r={}){return ce(this,void 0,void 0,(function*(){const i=(o=r.loader)&&"load"in o?r.loader:Ze.loader(r.loader);var o;if(Ze.isString(n)){const e=yield i.load(n);return at(t,JSON.parse(e),r)}let a=(r=ve(r,n.usermeta&&n.usermeta.embedOptions||{})).config||{};if(Ze.isString(a)){const e=yield i.load(a);return at(t,n,Object.assign(Object.assign({},r),{config:JSON.parse(e)}))}const s=e.isBoolean(r.actions)?r.actions:ve({export:{svg:!0,png:!0},source:!0,compiled:!0,editor:!0},r.actions||{}),l=Object.assign(Object.assign({},et),r.i18n),c=r.renderer||"canvas",u=r.logLevel||Ze.Warn,f=r.downloadFileName||"visualization";if(!1!==r.defaultStyle){const e="vega-embed-style";if(!document.getElementById(e)){const t=document.createElement("style");t.id=e,t.innerText=void 0===r.defaultStyle||!0===r.defaultStyle?Je.toString():r.defaultStyle,document.head.appendChild(t)}}r.theme&&(a=ve(Fe[r.theme],a));const p=function(e,t){if(e.$schema){const n=Ne(e.$schema);t&&t!==n.library&&console.warn(`The given visualization spec is written in ${tt[n.library]}, but mode argument sets ${tt[t]||t}.`);const r=n.library;return Re(nt[r],`^${n.version.slice(1)}`)||console.warn(`The input spec uses ${tt[r]} ${n.version}, but the current version of ${tt[r]} is v${nt[r]}.`),r}return"mark"in e||"encoding"in e||"layer"in e||"hconcat"in e||"vconcat"in e||"facet"in e||"repeat"in e?"vega-lite":"marks"in e||"signals"in e||"scales"in e||"axes"in e?"vega":t||"vega"}(n,r.mode);let h=rt[p](n,a);if("vega-lite"===p&&h.$schema){const e=Ne(h.$schema);Re(nt.vega,`^${e.version.slice(1)}`)||console.warn(`The compiled spec uses Vega ${e.version}, but current version is v${nt.vega}.`)}const d=function(e){return"string"==typeof e?new ae([[document.querySelector(e)]],[document.documentElement]):new ae([[e]],oe)}(t).classed("vega-embed",!0).html(""),g=r.patch;if(g)if(g instanceof Function)h=g(h);else if(Ze.isString(g)){const 
e=yield i.load(g);h=ve(h,JSON.parse(e))}else h=ve(h,g);const m=Ze.parse(h,"vega-lite"===p?{}:a),v=new Ze.View(m,{loader:i,logLevel:u,renderer:c});if(!1!==r.tooltip){let e;e="function"==typeof r.tooltip?r.tooltip:new Ye(!0===r.tooltip?{}:r.tooltip).call,v.tooltip(e)}let{hover:E}=r;if(void 0===E&&(E="vega"===p),E){const{hoverSet:e,updateSet:t}="boolean"==typeof E?{}:E;v.hover(e,t)}if(r&&(r.width&&v.width(r.width),r.height&&v.height(r.height),r.padding&&v.padding(r.padding)),yield v.initialize(t).runAsync(),!1!==s){let e=d;if(!1!==r.defaultStyle){const t=d.append("details").attr("title",l.CLICK_TO_VIEW_ACTIONS);e=t,t.insert("summary").html(it);const n=t.node();document.addEventListener("click",e=>{n.contains(e.target)||n.removeAttribute("open")})}const t=e.insert("div").attr("class","vega-actions");if(!0===s||!1!==s.export)for(const e of["svg","png"])if(!0===s||!0===s.export||s.export[e]){const n=l[`${e.toUpperCase()}_ACTION`];t.append("a").text(n).attr("href","#").attr("target","_blank").attr("download",`${f}.${e}`).on("mousedown",(function(){v.toImageURL(e,r.scaleFactor).then(e=>{this.href=e}).catch(e=>{throw e}),Z.preventDefault()}))}if(!0!==s&&!1===s.source||t.append("a").text(l.SOURCE_ACTION).attr("href","#").on("mousedown",()=>{ot(ye(n),r.sourceHeader||"",r.sourceFooter||"",p),Z.preventDefault()}),"vega-lite"!==p||!0!==s&&!1===s.compiled||t.append("a").text(l.COMPILED_ACTION).attr("href","#").on("mousedown",()=>{ot(ye(h),r.sourceHeader||"",r.sourceFooter||"","vega"),Z.preventDefault()}),!0===s||!1!==s.editor){const e=r.editorUrl||"https://vega.github.io/editor/";t.append("a").text(l.EDITOR_ACTION).attr("href","#").on("mousedown",()=>{!function(e,t,n){const r=e.open(t),i=250;let o=~~(1e4/i);e.addEventListener("message",(function t(n){n.source===r&&(o=0,e.removeEventListener("message",t,!1))}),!1),setTimeout((function 
e(){o<=0||(r.postMessage(n,"*"),setTimeout(e,i),o-=1)}),i)}(window,e,{config:a,mode:p,renderer:c,spec:ye(n)}),Z.preventDefault()})}}return{view:v,spec:n,vgSpec:h}}))}function st(e,t={}){return ce(this,void 0,void 0,(function*(){const n=document.createElement("div");n.classList.add("vega-embed-wrapper");const r=document.createElement("div");n.appendChild(r);const i=!0===t.actions||!1===t.actions?t.actions:Object.assign({export:!0,source:!1,compiled:!0,editor:!0},t.actions||{}),o=yield at(r,e,Object.assign({actions:i},t||{}));return n.value=o.view,n}))}String.prototype.startsWith||(String.prototype.startsWith=function(e,t){return this.substr(!t||t<0?0:+t,e.length)===e});const lt=(...t)=>t.length>1&&(e.isString(t[0])&&!function(e){return e.startsWith("http://")||e.startsWith("https://")||e.startsWith("//")}(t[0])||function(e){return e instanceof se||"object"==typeof HTMLElement?e instanceof HTMLElement:e&&"object"==typeof e&&null!==e&&1===e.nodeType&&"string"==typeof e.nodeName}(t[0])||3===t.length)?at(t[0],t[1],t[2]):st(t[0],t[1]);return lt.vegaLite=Ke,lt.vl=Ke,lt.container=st,lt.embed=at,lt.vega=Ze,lt.default=at,lt.version=le,lt})); ================================================ FILE: docs/_static/js/vega-lite@5.js ================================================ !function(e,t){"object"==typeof exports&&"undefined"!=typeof module?t(exports,require("vega")):"function"==typeof define&&define.amd?define(["exports","vega"],t):t((e="undefined"!=typeof globalThis?globalThis:e||self).vegaLite={},e.vega)}(this,(function(e,t){"use strict";var n="5.16.3";function i(e){return!!e.or}function r(e){return!!e.and}function o(e){return!!e.not}function a(e,t){if(o(e))a(e.not,t);else if(r(e))for(const n of e.and)a(n,t);else if(i(e))for(const n of e.or)a(n,t);else t(e)}function s(e,t){return o(e)?{not:s(e.not,t)}:r(e)?{and:e.and.map((e=>s(e,t)))}:i(e)?{or:e.or.map((e=>s(e,t)))}:t(e)}const l=structuredClone;function c(e){throw new Error(e)}function u(e,n){const i={};for(const r of 
n)t.hasOwnProperty(e,r)&&(i[r]=e[r]);return i}function f(e,t){const n={...e};for(const e of t)delete n[e];return n}function d(e){if(t.isNumber(e))return e;const n=t.isString(e)?e:X(e);if(n.length<250)return n;let i=0;for(let e=0;e1?t-1:0),i=1;i0===t?e:`[${e}]`)),r=e.map(((t,n)=>e.slice(0,n+1).join("")));for(const e of r)n.add(e)}return n}function k(e,t){return void 0===e||void 0===t||$(w(e),w(t))}function S(e){return 0===D(e).length}Set.prototype.toJSON=function(){return`Set(${[...this].map((e=>X(e))).join(",")})`};const D=Object.keys,F=Object.values,z=Object.entries;function O(e){return!0===e||!1===e}function _(e){const t=e.replace(/\W/g,"_");return(e.match(/^\d+/)?"_":"")+t}function N(e,t){return o(e)?`!(${N(e.not,t)})`:r(e)?`(${e.and.map((e=>N(e,t))).join(") && (")})`:i(e)?`(${e.or.map((e=>N(e,t))).join(") || (")})`:t(e)}function C(e,t){if(0===t.length)return!0;const n=t.shift();return n in e&&C(e[n],t)&&delete e[n],S(e)}function P(e){return e.charAt(0).toUpperCase()+e.substr(1)}function A(e){let n=arguments.length>1&&void 0!==arguments[1]?arguments[1]:"datum";const i=t.splitAccessPath(e),r=[];for(let e=1;e<=i.length;e++){const o=`[${i.slice(0,e).map(t.stringValue).join("][")}]`;r.push(`${n}${o}`)}return r.join(" && ")}function j(e){return`${arguments.length>1&&void 0!==arguments[1]?arguments[1]:"datum"}[${t.stringValue(t.splitAccessPath(e).join("."))}]`}function T(e){return e.replace(/(\[|\]|\.|'|")/g,"\\$1")}function E(e){return`${t.splitAccessPath(e).map(T).join("\\.")}`}function M(e,t,n){return e.replace(new RegExp(t.replace(/[-/\\^$*+?.()|[\]{}]/g,"\\$&"),"g"),n)}function L(e){return`${t.splitAccessPath(e).join(".")}`}function q(e){return e?t.splitAccessPath(e).length:0}function U(){for(var e=arguments.length,t=new Array(e),n=0;nfn(e[t])?_(`_${t}_${z(e[t])}`):_(`_${t}_${e[t]}`))).join("")}function ln(e){return!0===e||un(e)&&!e.binned}function cn(e){return"binned"===e||un(e)&&!0===e.binned}function un(e){return t.isObject(e)}function fn(e){return 
e?.param}function dn(e){switch(e){case Q:case J:case ye:case me:case pe:case ge:case we:case be:case xe:case $e:case he:return 6;case ke:return 4;default:return 10}}function mn(e){return!!e?.expr}function pn(e){const t=D(e||{}),n={};for(const i of t)n[i]=Sn(e[i]);return n}function gn(e){const{anchor:t,frame:n,offset:i,orient:r,angle:o,limit:a,color:s,subtitleColor:l,subtitleFont:c,subtitleFontSize:f,subtitleFontStyle:d,subtitleFontWeight:m,subtitleLineHeight:p,subtitlePadding:g,...h}=e,y={...t?{anchor:t}:{},...n?{frame:n}:{},...i?{offset:i}:{},...r?{orient:r}:{},...void 0!==o?{angle:o}:{},...void 0!==a?{limit:a}:{}},v={...l?{subtitleColor:l}:{},...c?{subtitleFont:c}:{},...f?{subtitleFontSize:f}:{},...d?{subtitleFontStyle:d}:{},...m?{subtitleFontWeight:m}:{},...p?{subtitleLineHeight:p}:{},...g?{subtitlePadding:g}:{}};return{titleMarkConfig:{...h,...s?{fill:s}:{}},subtitleMarkConfig:u(e,["align","baseline","dx","dy","limit"]),nonMarkTitleProperties:y,subtitle:v}}function hn(e){return t.isString(e)||t.isArray(e)&&t.isString(e[0])}function yn(e){return!!e?.signal}function vn(e){return!!e.step}function bn(e){return!t.isArray(e)&&("field"in e&&"data"in e)}const 
xn=D({aria:1,description:1,ariaRole:1,ariaRoleDescription:1,blend:1,opacity:1,fill:1,fillOpacity:1,stroke:1,strokeCap:1,strokeWidth:1,strokeOpacity:1,strokeDash:1,strokeDashOffset:1,strokeJoin:1,strokeOffset:1,strokeMiterLimit:1,startAngle:1,endAngle:1,padAngle:1,innerRadius:1,outerRadius:1,size:1,shape:1,interpolate:1,tension:1,orient:1,align:1,baseline:1,text:1,dir:1,dx:1,dy:1,ellipsis:1,limit:1,radius:1,theta:1,angle:1,font:1,fontSize:1,fontWeight:1,fontStyle:1,lineBreak:1,lineHeight:1,cursor:1,href:1,tooltip:1,cornerRadius:1,cornerRadiusTopLeft:1,cornerRadiusTopRight:1,cornerRadiusBottomLeft:1,cornerRadiusBottomRight:1,aspect:1,width:1,height:1,url:1,smooth:1}),$n={arc:1,area:1,group:1,image:1,line:1,path:1,rect:1,rule:1,shape:1,symbol:1,text:1,trail:1},wn=["cornerRadius","cornerRadiusTopLeft","cornerRadiusTopRight","cornerRadiusBottomLeft","cornerRadiusBottomRight"];function kn(e){const n=t.isArray(e.condition)?e.condition.map(Dn):Dn(e.condition);return{...Sn(e),condition:n}}function Sn(e){if(mn(e)){const{expr:t,...n}=e;return{signal:t,...n}}return e}function Dn(e){if(mn(e)){const{expr:t,...n}=e;return{signal:t,...n}}return e}function Fn(e){if(mn(e)){const{expr:t,...n}=e;return{signal:t,...n}}return yn(e)?e:void 0!==e?{value:e}:void 0}function zn(e){return yn(e)?e.signal:t.stringValue(e.value)}function On(e){return yn(e)?e.signal:null==e?null:t.stringValue(e)}function _n(e,t,n){for(const i of n){const n=Pn(i,t.markDef,t.config);void 0!==n&&(e[i]=Fn(n))}return e}function Nn(e){return[].concat(e.type,e.style??[])}function Cn(e,t,n){let i=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{};const{vgChannel:r,ignoreVgConfig:o}=i;return r&&void 0!==t[r]?t[r]:void 0!==t[e]?t[e]:!o||r&&r!==e?Pn(e,t,n,i):void 0}function Pn(e,t,n){let{vgChannel:i}=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{};return U(i?An(e,t,n.style):void 0,An(e,t,n.style),i?n[t.type][i]:void 0,n[t.type][e],i?n.mark[i]:n.mark[e])}function An(e,t,n){return jn(e,Nn(t),n)}function 
jn(e,n,i){let r;n=t.array(n);for(const t of n){const n=i[t];n&&void 0!==n[e]&&(r=n[e])}return r}function Tn(e,n){return t.array(e).reduce(((e,t)=>(e.field.push(oa(t,n)),e.order.push(t.sort??"ascending"),e)),{field:[],order:[]})}function En(e,t){const n=[...e];return t.forEach((e=>{for(const t of n)if(Y(t,e))return;n.push(e)})),n}function Mn(e,n){return Y(e,n)||!n?e:e?[...t.array(e),...t.array(n)].join(", "):n}function Ln(e,t){const n=e.value,i=t.value;if(null==n||null===i)return{explicit:e.explicit,value:null};if((hn(n)||yn(n))&&(hn(i)||yn(i)))return{explicit:e.explicit,value:Mn(n,i)};if(hn(n)||yn(n))return{explicit:e.explicit,value:n};if(hn(i)||yn(i))return{explicit:e.explicit,value:i};if(!(hn(n)||yn(n)||hn(i)||yn(i)))return{explicit:e.explicit,value:En(n,i)};throw new Error("It should never reach here")}function qn(e,t,n){return(t=function(e){var t=function(e,t){if("object"!=typeof e||null===e)return e;var n=e[Symbol.toPrimitive];if(void 0!==n){var i=n.call(e,t||"default");if("object"!=typeof i)return i;throw new TypeError("@@toPrimitive must return a primitive value.")}return("string"===t?String:Number)(e)}(e,"string");return"symbol"==typeof t?t:String(t)}(t))in e?Object.defineProperty(e,t,{value:n,enumerable:!0,configurable:!0,writable:!0}):e[t]=n,e}function Un(e,t,n){return function(e,t,n){if(t.set)t.set.call(e,n);else{if(!t.writable)throw new TypeError("attempted to set read only private field");t.value=n}}(e,Rn(e,t,"set"),n),n}function Rn(e,t,n){if(!t.has(e))throw new TypeError("attempted to "+n+" private field on non-instance");return t.get(e)}function Wn(e,t,n){!function(e,t){if(t.has(e))throw new TypeError("Cannot initialize the same private elements twice on an object")}(e,t),t.set(e,n)}function Bn(e){return`Invalid specification ${X(e)}. 
Make sure the specification includes at least one of the following properties: "mark", "layer", "facet", "hconcat", "vconcat", "concat", or "repeat".`}const In='Autosize "fit" only works for single views and layered views.';function Hn(e){return`${"width"==e?"Width":"Height"} "container" only works for single views and layered views.`}function Vn(e){return`${"width"==e?"Width":"Height"} "container" only works well with autosize "fit" or "fit-${"width"==e?"x":"y"}".`}function Gn(e){return e?`Dropping "fit-${e}" because spec has discrete ${rt(e)}.`:'Dropping "fit" because spec has discrete size.'}function Yn(e){return`Unknown field for ${e}. Cannot calculate view size.`}function Xn(e){return`Cannot project a selection on encoding channel "${e}", which has no field.`}function Qn(e,t){return`Cannot project a selection on encoding channel "${e}" as it uses an aggregate function ("${t}").`}function Jn(e){return`Selection not supported for ${e} yet.`}const Kn="The same selection must be used to override scale domains in a layered view.";function Zn(e){return`The "columns" property cannot be used when "${e}" has nested row/column.`}function ei(e,t,n){return`An ancestor parsed field "${e}" as ${n} but a child wants to parse the field as ${t}.`}function ti(e){return`Config.customFormatTypes is not true, thus custom format type and format for channel ${e} are dropped.`}function ni(e){return`${e}Offset dropped because ${e} is continuous`}function ii(e){return`Invalid field type "${e}".`}function ri(e,t){const{fill:n,stroke:i}=t;return`Dropping color ${e} as the plot also has ${n&&i?"fill and stroke":n?"fill":"stroke"}.`}function oi(e,t){return`Dropping ${X(e)} from channel "${t}" since it does not contain any data field, datum, value, or signal.`}function ai(e,t,n){return`${e} dropped as it is incompatible with "${t}"${n?` when ${n}`:""}.`}function si(e){return`${e} encoding should be discrete (ordinal / nominal / binned).`}function li(e){return`${e} encoding should be 
discrete (ordinal / nominal / binned) or use a discretizing scale (e.g. threshold).`}function ci(e,t){return`Using discrete channel "${e}" to encode "${t}" field can be misleading as it does not encode ${"ordinal"===t?"order":"magnitude"}.`}function ui(e){return`Using unaggregated domain with raw field has no effect (${X(e)}).`}function fi(e){return`Unaggregated domain not applicable for "${e}" since it produces values outside the origin domain of the source data.`}function di(e){return`Unaggregated domain is currently unsupported for log scale (${X(e)}).`}function mi(e,t,n){return`${n}-scale's "${t}" is dropped as it does not work with ${e} scale.`}function pi(e){return`The step for "${e}" is dropped because the ${"width"===e?"x":"y"} is continuous.`}const gi="Domains that should be unioned has conflicting sort properties. Sort will be set to true.";function hi(e,t){return`Invalid ${e}: ${X(t)}.`}function yi(e){return`1D error band does not support ${e}.`}function vi(e){return`Channel ${e} is required for "binned" bin.`}const bi=t.logger(t.Warn);let xi=bi;function $i(){xi.warn(...arguments)}function wi(e){if(e&&t.isObject(e))for(const t of Ni)if(t in e)return!0;return!1}const ki=["january","february","march","april","may","june","july","august","september","october","november","december"],Si=ki.map((e=>e.substr(0,3))),Di=["sunday","monday","tuesday","wednesday","thursday","friday","saturday"],Fi=Di.map((e=>e.substr(0,3)));function zi(e,n){const i=[];if(n&&void 0!==e.day&&D(e).length>1&&($i(function(e){return`Dropping day from datetime ${X(e)} as day cannot be combined with other units.`}(e)),delete(e=l(e)).day),void 0!==e.year?i.push(e.year):i.push(2012),void 0!==e.month){const r=n?function(e){if(V(e)&&(e=+e),t.isNumber(e))return e-1;{const t=e.toLowerCase(),n=ki.indexOf(t);if(-1!==n)return n;const i=t.substr(0,3),r=Si.indexOf(i);if(-1!==r)return r;throw new Error(hi("month",e))}}(e.month):e.month;i.push(r)}else if(void 0!==e.quarter){const 
r=n?function(e){if(V(e)&&(e=+e),t.isNumber(e))return e>4&&$i(hi("quarter",e)),e-1;throw new Error(hi("quarter",e))}(e.quarter):e.quarter;i.push(t.isNumber(r)?3*r:`${r}*3`)}else i.push(0);if(void 0!==e.date)i.push(e.date);else if(void 0!==e.day){const r=n?function(e){if(V(e)&&(e=+e),t.isNumber(e))return e%7;{const t=e.toLowerCase(),n=Di.indexOf(t);if(-1!==n)return n;const i=t.substr(0,3),r=Fi.indexOf(i);if(-1!==r)return r;throw new Error(hi("day",e))}}(e.day):e.day;i.push(t.isNumber(r)?r+1:`${r}+1`)}else i.push(1);for(const t of["hours","minutes","seconds","milliseconds"]){const n=e[t];i.push(void 0===n?0:n)}return i}function Oi(e){const t=zi(e,!0).join(", ");return e.utc?`utc(${t})`:`datetime(${t})`}const _i={year:1,quarter:1,month:1,week:1,day:1,dayofyear:1,date:1,hours:1,minutes:1,seconds:1,milliseconds:1},Ni=D(_i);function Ci(e){return t.isObject(e)?e.binned:Pi(e)}function Pi(e){return e&&e.startsWith("binned")}function Ai(e){return e.startsWith("utc")}const ji={"year-month":"%b %Y ","year-month-date":"%b %d, %Y "};function Ti(e){return Ni.filter((t=>Mi(e,t)))}function Ei(e){const t=Ti(e);return t[t.length-1]}function Mi(e,t){const n=e.indexOf(t);return!(n<0)&&(!(n>0&&"seconds"===t&&"i"===e.charAt(n-1))&&(!(e.length>n+3&&"day"===t&&"o"===e.charAt(n+3))&&!(n>0&&"year"===t&&"f"===e.charAt(n-1))))}function Li(e,t){let{end:n}=arguments.length>2&&void 0!==arguments[2]?arguments[2]:{end:!1};const i=A(t),r=Ai(e)?"utc":"";let o;const a={};for(const t of Ni)Mi(e,t)&&(a[t]="quarter"===(s=t)?`(${r}quarter(${i})-1)`:`${r}${s}(${i})`,o=t);var s;return n&&(a[o]+="+1"),function(e){const t=zi(e,!1).join(", ");return e.utc?`utc(${t})`:`datetime(${t})`}(a)}function qi(e){if(!e)return;return`timeUnitSpecifier(${X(Ti(e))}, ${X(ji)})`}function Ui(e){if(!e)return;let n;return t.isString(e)?n=Pi(e)?{unit:e.substring(6),binned:!0}:{unit:e}:t.isObject(e)&&(n={...e,...e.unit?{unit:e.unit}:{}}),Ai(n.unit)&&(n.utc=!0,n.unit=n.unit.substring(3)),n}function Ri(e){let 
t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:e=>e;const n=Ui(e),i=Ei(n.unit);if(i&&"day"!==i){const e={year:2001,month:1,date:1,hours:0,minutes:0,seconds:0,milliseconds:0},{step:r,part:o}=Bi(i,n.step);return`${t(Oi({...e,[o]:+e[o]+r}))} - ${t(Oi(e))}`}}const Wi={year:1,month:1,date:1,hours:1,minutes:1,seconds:1,milliseconds:1};function Bi(e){let t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:1;if(function(e){return!!Wi[e]}(e))return{part:e,step:t};switch(e){case"day":case"dayofyear":return{part:"date",step:t};case"quarter":return{part:"month",step:3*t};case"week":return{part:"date",step:7*t}}}function Ii(e){return!!e?.field&&void 0!==e.equal}function Hi(e){return!!e?.field&&void 0!==e.lt}function Vi(e){return!!e?.field&&void 0!==e.lte}function Gi(e){return!!e?.field&&void 0!==e.gt}function Yi(e){return!!e?.field&&void 0!==e.gte}function Xi(e){if(e?.field){if(t.isArray(e.range)&&2===e.range.length)return!0;if(yn(e.range))return!0}return!1}function Qi(e){return!!e?.field&&(t.isArray(e.oneOf)||t.isArray(e.in))}function Ji(e){return Qi(e)||Ii(e)||Xi(e)||Hi(e)||Gi(e)||Vi(e)||Yi(e)}function Ki(e,t){return wa(e,{timeUnit:t,wrapTime:!0})}function Zi(e){let t=!(arguments.length>1&&void 0!==arguments[1])||arguments[1];const{field:n}=e,i=Ui(e.timeUnit),{unit:r,binned:o}=i||{},a=oa(e,{expr:"datum"}),s=r?`time(${o?a:Li(r,n)})`:a;if(Ii(e))return`${s}===${Ki(e.equal,r)}`;if(Hi(e)){return`${s}<${Ki(e.lt,r)}`}if(Gi(e)){return`${s}>${Ki(e.gt,r)}`}if(Vi(e)){return`${s}<=${Ki(e.lte,r)}`}if(Yi(e)){return`${s}>=${Ki(e.gte,r)}`}if(Qi(e))return`indexof([${function(e,t){return e.map((e=>Ki(e,t)))}(e.oneOf,r).join(",")}], ${s}) !== -1`;if(function(e){return!!e?.field&&void 0!==e.valid}(e))return er(s,e.valid);if(Xi(e)){const{range:n}=e,i=yn(n)?{signal:`${n.signal}[0]`}:n[0],o=yn(n)?{signal:`${n.signal}[1]`}:n[1];if(null!==i&&null!==o&&t)return"inrange("+s+", ["+Ki(i,r)+", "+Ki(o,r)+"])";const a=[];return null!==i&&a.push(`${s} >= ${Ki(i,r)}`),null!==o&&a.push(`${s} <= 
${Ki(o,r)}`),a.length>0?a.join(" && "):"true"}throw new Error(`Invalid field predicate: ${X(e)}`)}function er(e){return!(arguments.length>1&&void 0!==arguments[1])||arguments[1]?`isValid(${e}) && isFinite(+${e})`:`!isValid(${e}) || !isFinite(+${e})`}function tr(e){return Ji(e)&&e.timeUnit?{...e,timeUnit:Ui(e.timeUnit)}:e}function nr(e){return"quantitative"===e||"temporal"===e}function ir(e){return"ordinal"===e||"nominal"===e}const rr="quantitative",or="ordinal",ar="temporal",sr="nominal",lr="geojson";const cr={LINEAR:"linear",LOG:"log",POW:"pow",SQRT:"sqrt",SYMLOG:"symlog",IDENTITY:"identity",SEQUENTIAL:"sequential",TIME:"time",UTC:"utc",QUANTILE:"quantile",QUANTIZE:"quantize",THRESHOLD:"threshold",BIN_ORDINAL:"bin-ordinal",ORDINAL:"ordinal",POINT:"point",BAND:"band"},ur={linear:"numeric",log:"numeric",pow:"numeric",sqrt:"numeric",symlog:"numeric",identity:"numeric",sequential:"numeric",time:"time",utc:"time",ordinal:"ordinal","bin-ordinal":"bin-ordinal",point:"ordinal-position",band:"ordinal-position",quantile:"discretizing",quantize:"discretizing",threshold:"discretizing"};function fr(e,t){const n=ur[e],i=ur[t];return n===i||"ordinal-position"===n&&"time"===i||"ordinal-position"===i&&"time"===n}const dr={linear:0,log:1,pow:1,sqrt:1,symlog:1,identity:1,sequential:1,time:0,utc:0,point:10,band:11,ordinal:0,"bin-ordinal":0,quantile:0,quantize:0,threshold:0};function mr(e){return dr[e]}const pr=new Set(["linear","log","pow","sqrt","symlog"]),gr=new Set([...pr,"time","utc"]);function hr(e){return pr.has(e)}const yr=new Set(["quantile","quantize","threshold"]),vr=new Set([...gr,...yr,"sequential","identity"]),br=new Set(["ordinal","bin-ordinal","point","band"]);function xr(e){return br.has(e)}function $r(e){return vr.has(e)}function wr(e){return gr.has(e)}function kr(e){return yr.has(e)}function Sr(e){return 
e?.param}const{type:Dr,domain:Fr,range:zr,rangeMax:Or,rangeMin:_r,scheme:Nr,...Cr}={type:1,domain:1,domainMax:1,domainMin:1,domainMid:1,domainRaw:1,align:1,range:1,rangeMax:1,rangeMin:1,scheme:1,bins:1,reverse:1,round:1,clamp:1,nice:1,base:1,exponent:1,constant:1,interpolate:1,zero:1,padding:1,paddingInner:1,paddingOuter:1},Pr=D(Cr);function Ar(e,t){switch(t){case"type":case"domain":case"reverse":case"range":return!0;case"scheme":case"interpolate":return!["point","band","identity"].includes(e);case"bins":return!["point","band","identity","ordinal"].includes(e);case"round":return wr(e)||"band"===e||"point"===e;case"padding":case"rangeMin":case"rangeMax":return wr(e)||["point","band"].includes(e);case"paddingOuter":case"align":return["point","band"].includes(e);case"paddingInner":return"band"===e;case"domainMax":case"domainMid":case"domainMin":case"domainRaw":case"clamp":return wr(e);case"nice":return wr(e)||"quantize"===e||"threshold"===e;case"exponent":return"pow"===e;case"base":return"log"===e;case"constant":return"symlog"===e;case"zero":return $r(e)&&!p(["log","time","utc","threshold","quantile"],e)}}function jr(e,t){switch(t){case"interpolate":case"scheme":case"domainMid":return qe(e)?void 0:`Cannot use the scale property "${t}" with non-color channel.`;case"align":case"type":case"bins":case"domain":case"domainMax":case"domainMin":case"domainRaw":case"range":case"base":case"exponent":case"constant":case"nice":case"padding":case"paddingInner":case"paddingOuter":case"rangeMax":case"rangeMin":case"reverse":case"round":case"clamp":case"zero":return}}const Tr={arc:"arc",area:"area",bar:"bar",image:"image",line:"line",point:"point",rect:"rect",rule:"rule",text:"text",tick:"tick",trail:"trail",circle:"circle",square:"square",geoshape:"geoshape"},Er=Tr.arc,Mr=Tr.area,Lr=Tr.bar,qr=Tr.image,Ur=Tr.line,Rr=Tr.point,Wr=Tr.rect,Br=Tr.rule,Ir=Tr.text,Hr=Tr.tick,Vr=Tr.trail,Gr=Tr.circle,Yr=Tr.square,Xr=Tr.geoshape;function 
Qr(e){return["line","area","trail"].includes(e)}function Jr(e){return["rect","bar","image","arc"].includes(e)}const Kr=new Set(D(Tr));function Zr(e){return e.type}const eo=["stroke","strokeWidth","strokeDash","strokeDashOffset","strokeOpacity","strokeJoin","strokeMiterLimit","fill","fillOpacity"],to=D({color:1,filled:1,invalid:1,order:1,radius2:1,theta2:1,timeUnitBandSize:1,timeUnitBandPosition:1}),no=D({mark:1,arc:1,area:1,bar:1,circle:1,image:1,line:1,point:1,rect:1,rule:1,square:1,text:1,tick:1,trail:1,geoshape:1});function io(e){return e&&null!=e.band}const ro={horizontal:["cornerRadiusTopRight","cornerRadiusBottomRight"],vertical:["cornerRadiusTopLeft","cornerRadiusTopRight"]},oo={binSpacing:1,continuousBandSize:5,minBandSize:.25,timeUnitBandPosition:.5},ao={binSpacing:0,continuousBandSize:5,minBandSize:.25,timeUnitBandPosition:.5};function so(e){const{channel:t,channelDef:n,markDef:i,scale:r,config:o}=e,a=mo(e);return Ho(n)&&!rn(n.aggregate)&&r&&wr(r.get("type"))?function(e){let{fieldDef:t,channel:n,markDef:i,ref:r,config:o}=e;if(Qr(i.type))return r;const a=Cn("invalid",i,o);if(null===a)return[lo(t,n),r];return r}({fieldDef:n,channel:t,markDef:i,ref:a,config:o}):a}function lo(e,t){return{test:co(e,!0),..."y"===tt(t)?{field:{group:"height"}}:{value:0}}}function co(e){let n=!(arguments.length>1&&void 0!==arguments[1])||arguments[1];return er(t.isString(e)?e:oa(e,{expr:"datum"}),!n)}function uo(e,t,n,i){const r={};if(t&&(r.scale=t),Go(e)){const{datum:t}=e;wi(t)?r.signal=Oi(t):yn(t)?r.signal=t.signal:mn(t)?r.signal=t.expr:r.value=t}else r.field=oa(e,n);if(i){const{offset:e,band:t}=i;e&&(r.offset=e),t&&(r.band=t)}return r}function fo(e){let{scaleName:t,fieldOrDatumDef:n,fieldOrDatumDef2:i,offset:r,startSuffix:o,endSuffix:a="end",bandPosition:s=.5}=e;const l=!yn(s)&&01&&void 0!==arguments[1]?arguments[1]:{},n=e.field;const i=t.prefix;let r=t.suffix,o="";if(function(e){return"count"===e.aggregate}(e))n=B("count");else{let i;if(!t.nofn)if(function(e){return"op"in 
e}(e))i=e.op;else{const{bin:a,aggregate:s,timeUnit:l}=e;ln(a)?(i=sn(a),r=(t.binSuffix??"")+(t.suffix??"")):s?en(s)?(o=`["${n}"]`,n=`argmax_${s.argmax}`):Zt(s)?(o=`["${n}"]`,n=`argmin_${s.argmin}`):i=String(s):l&&!Ci(l)&&(i=function(e){const{utc:t,...n}=Ui(e);return n.unit?(t?"utc":"")+D(n).map((e=>_(`${"unit"===e?"":`_${e}_`}${n[e]}`))).join(""):(t?"utc":"")+"timeunit"+D(n).map((e=>_(`_${e}_${n[e]}`))).join("")}(l),r=(!["range","mid"].includes(t.binSuffix)&&t.binSuffix||"")+(t.suffix??""))}i&&(n=n?`${i}_${n}`:i)}return r&&(n=`${n}_${r}`),i&&(n=`${i}_${n}`),t.forAs?L(n):t.expr?j(n,t.expr)+o:E(n)+o}function aa(e){switch(e.type){case"nominal":case"ordinal":case"geojson":return!0;case"quantitative":return Ho(e)&&!!e.bin;case"temporal":return!1}throw new Error(ii(e.type))}const sa=(e,t)=>{switch(t.fieldTitle){case"plain":return e.field;case"functional":return function(e){const{aggregate:t,bin:n,timeUnit:i,field:r}=e;if(en(t))return`${r} for argmax(${t.argmax})`;if(Zt(t))return`${r} for argmin(${t.argmin})`;const o=i&&!Ci(i)?Ui(i):void 0,a=t||o?.unit||o?.maxbins&&"timeunit"||ln(n)&&"bin";return a?`${a.toUpperCase()}(${r})`:r}(e);default:return function(e,t){const{field:n,bin:i,timeUnit:r,aggregate:o}=e;if("count"===o)return t.countTitle;if(ln(i))return`${n} (binned)`;if(r&&!Ci(r)){const e=Ui(r)?.unit;if(e)return`${n} (${Ti(e).join("-")})`}else if(o)return en(o)?`${n} for max ${o.argmax}`:Zt(o)?`${n} for min ${o.argmin}`:`${P(o)} of ${n}`;return n}(e,t)}};let la=sa;function ca(e){la=e}function ua(e,t,n){let{allowDisabling:i,includeDefault:r=!0}=n;const o=fa(e)?.title;if(!Ho(e))return o??e.title;const a=e,s=r?da(a,t):void 0;return i?U(o,a.title,s):o??a.title??s}function fa(e){return ta(e)&&e.axis?e.axis:na(e)&&e.legend?e.legend:jo(e)&&e.header?e.header:void 0}function da(e,t){return la(e,t)}function ma(e){if(ia(e)){const{format:t,formatType:n}=e;return{format:t,formatType:n}}{const t=fa(e)??{},{format:n,formatType:i}=t;return{format:n,formatType:i}}}function pa(e){return 
Ho(e)?e:Bo(e)?e.condition:void 0}function ga(e){return Jo(e)?e:Io(e)?e.condition:void 0}function ha(e,n,i){let r=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{};if(t.isString(e)||t.isNumber(e)||t.isBoolean(e)){return $i(function(e,t,n){return`Channel ${e} is a ${t}. Converted to {value: ${X(n)}}.`}(n,t.isString(e)?"string":t.isNumber(e)?"number":"boolean",e)),{value:e}}return Jo(e)?ya(e,n,i,r):Io(e)?{...e,condition:ya(e.condition,n,i,r)}:e}function ya(e,n,i,r){if(ia(e)){const{format:t,formatType:o,...a}=e;if(go(o)&&!i.customFormatTypes)return $i(ti(n)),ya(a,n,i,r)}else{const t=ta(e)?"axis":na(e)?"legend":jo(e)?"header":null;if(t&&e[t]){const{format:o,formatType:a,...s}=e[t];if(go(a)&&!i.customFormatTypes)return $i(ti(n)),ya({...e,[t]:s},n,i,r)}}return Ho(e)?va(e,n,r):function(e){let n=e.type;if(n)return e;const{datum:i}=e;return n=t.isNumber(i)?"quantitative":t.isString(i)?"nominal":wi(i)?"temporal":void 0,{...e,type:n}}(e)}function va(e,n){let{compositeMark:i=!1}=arguments.length>2&&void 0!==arguments[2]?arguments[2]:{};const{aggregate:r,timeUnit:o,bin:a,field:s}=e,l={...e};if(i||!r||tn(r)||en(r)||Zt(r)||($i(function(e){return`Invalid aggregation operator "${e}".`}(r)),delete l.aggregate),o&&(l.timeUnit=Ui(o)),s&&(l.field=`${s}`),ln(a)&&(l.bin=ba(a,n)),cn(a)&&!zt(n)&&$i(function(e){return`Channel ${e} should not be used with "binned" bin.`}(n)),Ko(l)){const{type:e}=l,t=function(e){if(e)switch(e=e.toLowerCase()){case"q":case rr:return"quantitative";case"t":case ar:return"temporal";case"o":case or:return"ordinal";case"n":case sr:return"nominal";case lr:return"geojson"}}(e);e!==t&&(l.type=t),"quantitative"!==e&&rn(r)&&($i(function(e,t){return`Invalid field type "${e}" for aggregate: "${t}", using "quantitative" instead.`}(e,r)),l.type="quantitative")}else if(!et(n)){const 
e=function(e,n){switch(n){case"latitude":case"longitude":return"quantitative";case"row":case"column":case"facet":case"shape":case"strokeDash":return"nominal";case"order":return"ordinal"}if(Mo(e)&&t.isArray(e.sort))return"ordinal";const{aggregate:i,bin:r,timeUnit:o}=e;if(o)return"temporal";if(r||i&&!en(i)&&!Zt(i))return"quantitative";if(ea(e)&&e.scale?.type)switch(ur[e.scale.type]){case"numeric":case"discretizing":return"quantitative";case"time":return"temporal"}return"nominal"}(l,n);l.type=e}if(Ko(l)){const{compatible:e,warning:t}=function(e,t){const n=e.type;if("geojson"===n&&"shape"!==t)return{compatible:!1,warning:`Channel ${t} should not be used with a geojson data.`};switch(t){case Q:case J:case K:return aa(e)?xa:{compatible:!1,warning:si(t)};case Z:case ee:case ie:case re:case me:case pe:case ge:case Se:case Fe:case ze:case Oe:case _e:case Ne:case ve:case se:case oe:case Ce:return xa;case ue:case de:case ce:case fe:return n!==rr?{compatible:!1,warning:`Channel ${t} should be used with a quantitative field only, not ${e.type} field.`}:xa;case be:case xe:case $e:case we:case ye:case le:case ae:case te:case ne:return"nominal"!==n||e.sort?xa:{compatible:!1,warning:`Channel ${t} should not be used with an unsorted discrete field.`};case he:case ke:return aa(e)||ea(i=e)&&kr(i.scale?.type)?xa:{compatible:!1,warning:li(t)};case De:return"nominal"!==e.type||"sort"in e?xa:{compatible:!1,warning:"Channel order is inappropriate for nominal field, which has no inherent order."}}var i}(l,n)||{};!1===e&&$i(t)}if(Mo(l)&&t.isString(l.sort)){const{sort:e}=l;if(_o(e))return{...l,sort:{encoding:e}};const t=e.substr(1);if("-"===e.charAt(0)&&_o(t))return{...l,sort:{encoding:t,order:"descending"}}}if(jo(l)){const{header:e}=l;if(e){const{orient:t,...n}=e;if(t)return{...l,header:{...n,labelOrient:e.labelOrient||t,titleOrient:e.titleOrient||t}}}}return l}function ba(e,n){return t.isBoolean(e)?{maxbins:dn(n)}:"binned"===e?{binned:!0}:e.maxbins||e.step?e:{...e,maxbins:dn(n)}}const 
xa={compatible:!0};function $a(e){const{formatType:t}=ma(e);return"time"===t||!t&&((n=e)&&("temporal"===n.type||Ho(n)&&!!n.timeUnit));var n}function wa(e,n){let{timeUnit:i,type:r,wrapTime:o,undefinedIfExprNotRequired:a}=n;const s=i&&Ui(i)?.unit;let l,c=s||"temporal"===r;return mn(e)?l=e.expr:yn(e)?l=e.signal:wi(e)?(c=!0,l=Oi(e)):(t.isString(e)||t.isNumber(e))&&c&&(l=`datetime(${X(e)})`,function(e){return!!_i[e]}(s)&&(t.isNumber(e)&&e<1e4||t.isString(e)&&isNaN(Date.parse(e)))&&(l=Oi({[s]:e}))),l?o&&c?`time(${l})`:l:a?void 0:X(e)}function ka(e,t){const{type:n}=e;return t.map((t=>{const i=wa(t,{timeUnit:Ho(e)&&!Ci(e.timeUnit)?e.timeUnit:void 0,type:n,undefinedIfExprNotRequired:!0});return void 0!==i?{signal:i}:t}))}function Sa(e,t){return ln(e.bin)?Ht(t)&&["ordinal","nominal"].includes(e.type):(console.warn("Only call this method for binned field defs."),!1)}const Da={labelAlign:{part:"labels",vgProp:"align"},labelBaseline:{part:"labels",vgProp:"baseline"},labelColor:{part:"labels",vgProp:"fill"},labelFont:{part:"labels",vgProp:"font"},labelFontSize:{part:"labels",vgProp:"fontSize"},labelFontStyle:{part:"labels",vgProp:"fontStyle"},labelFontWeight:{part:"labels",vgProp:"fontWeight"},labelOpacity:{part:"labels",vgProp:"opacity"},labelOffset:null,labelPadding:null,gridColor:{part:"grid",vgProp:"stroke"},gridDash:{part:"grid",vgProp:"strokeDash"},gridDashOffset:{part:"grid",vgProp:"strokeDashOffset"},gridOpacity:{part:"grid",vgProp:"opacity"},gridWidth:{part:"grid",vgProp:"strokeWidth"},tickColor:{part:"ticks",vgProp:"stroke"},tickDash:{part:"ticks",vgProp:"strokeDash"},tickDashOffset:{part:"ticks",vgProp:"strokeDashOffset"},tickOpacity:{part:"ticks",vgProp:"opacity"},tickSize:null,tickWidth:{part:"ticks",vgProp:"strokeWidth"}};function Fa(e){return e?.condition}const 
za=["domain","grid","labels","ticks","title"],Oa={grid:"grid",gridCap:"grid",gridColor:"grid",gridDash:"grid",gridDashOffset:"grid",gridOpacity:"grid",gridScale:"grid",gridWidth:"grid",orient:"main",bandPosition:"both",aria:"main",description:"main",domain:"main",domainCap:"main",domainColor:"main",domainDash:"main",domainDashOffset:"main",domainOpacity:"main",domainWidth:"main",format:"main",formatType:"main",labelAlign:"main",labelAngle:"main",labelBaseline:"main",labelBound:"main",labelColor:"main",labelFlush:"main",labelFlushOffset:"main",labelFont:"main",labelFontSize:"main",labelFontStyle:"main",labelFontWeight:"main",labelLimit:"main",labelLineHeight:"main",labelOffset:"main",labelOpacity:"main",labelOverlap:"main",labelPadding:"main",labels:"main",labelSeparation:"main",maxExtent:"main",minExtent:"main",offset:"both",position:"main",tickCap:"main",tickColor:"main",tickDash:"main",tickDashOffset:"main",tickMinStep:"both",tickOffset:"both",tickOpacity:"main",tickRound:"both",ticks:"main",tickSize:"main",tickWidth:"both",title:"main",titleAlign:"main",titleAnchor:"main",titleAngle:"main",titleBaseline:"main",titleColor:"main",titleFont:"main",titleFontSize:"main",titleFontStyle:"main",titleFontWeight:"main",titleLimit:"main",titleLineHeight:"main",titleOpacity:"main",titlePadding:"main",titleX:"main",titleY:"main",encode:"both",scale:"both",tickBand:"both",tickCount:"both",tickExtra:"both",translate:"both",values:"both",zindex:"both"},_a={orient:1,aria:1,bandPosition:1,description:1,domain:1,domainCap:1,domainColor:1,domainDash:1,domainDashOffset:1,domainOpacity:1,domainWidth:1,format:1,formatType:1,grid:1,gridCap:1,gridColor:1,gridDash:1,gridDashOffset:1,gridOpacity:1,gridWidth:1,labelAlign:1,labelAngle:1,labelBaseline:1,labelBound:1,labelColor:1,labelFlush:1,labelFlushOffset:1,labelFont:1,labelFontSize:1,labelFontStyle:1,labelFontWeight:1,labelLimit:1,labelLineHeight:1,labelOffset:1,labelOpacity:1,labelOverlap:1,labelPadding:1,labels:1,labelSeparation:1,maxEx
tent:1,minExtent:1,offset:1,position:1,tickBand:1,tickCap:1,tickColor:1,tickCount:1,tickDash:1,tickDashOffset:1,tickExtra:1,tickMinStep:1,tickOffset:1,tickOpacity:1,tickRound:1,ticks:1,tickSize:1,tickWidth:1,title:1,titleAlign:1,titleAnchor:1,titleAngle:1,titleBaseline:1,titleColor:1,titleFont:1,titleFontSize:1,titleFontStyle:1,titleFontWeight:1,titleLimit:1,titleLineHeight:1,titleOpacity:1,titlePadding:1,titleX:1,titleY:1,translate:1,values:1,zindex:1},Na={..._a,style:1,labelExpr:1,encoding:1};function Ca(e){return!!Na[e]}const Pa=D({axis:1,axisBand:1,axisBottom:1,axisDiscrete:1,axisLeft:1,axisPoint:1,axisQuantitative:1,axisRight:1,axisTemporal:1,axisTop:1,axisX:1,axisXBand:1,axisXDiscrete:1,axisXPoint:1,axisXQuantitative:1,axisXTemporal:1,axisY:1,axisYBand:1,axisYDiscrete:1,axisYPoint:1,axisYQuantitative:1,axisYTemporal:1});function Aa(e){return"mark"in e}class ja{constructor(e,t){this.name=e,this.run=t}hasMatchingType(e){return!!Aa(e)&&(Zr(t=e.mark)?t.type:t)===this.name;var t}}function Ta(e,n){const i=e&&e[n];return!!i&&(t.isArray(i)?g(i,(e=>!!e.field)):Ho(i)||Bo(i))}function Ea(e,n){const i=e&&e[n];return!!i&&(t.isArray(i)?g(i,(e=>!!e.field)):Ho(i)||Go(i)||Io(i))}function Ma(e,t){if(zt(t)){const n=e[t];if((Ho(n)||Go(n))&&(ir(n.type)||Ho(n)&&n.timeUnit)){return Ea(e,at(t))}}return!1}function La(e){return g(Be,(n=>{if(Ta(e,n)){const i=e[n];if(t.isArray(i))return g(i,(e=>!!e.aggregate));{const e=pa(i);return e&&!!e.aggregate}}return!1}))}function qa(e,t){const n=[],i=[],r=[],o=[],a={};return Wa(e,((s,l)=>{if(Ho(s)){const{field:c,aggregate:u,bin:f,timeUnit:d,...m}=s;if(u||d||f){const e=fa(s),p=e?.title;let g=oa(s,{forAs:!0});const h={...p?[]:{title:ua(s,t,{allowDisabling:!0})},...m,field:g};if(u){let e;if(en(u)?(e="argmax",g=oa({op:"argmax",field:u.argmax},{forAs:!0}),h.field=`${g}.${c}`):Zt(u)?(e="argmin",g=oa({op:"argmin",field:u.argmin},{forAs:!0}),h.field=`${g}.${c}`):"boxplot"!==u&&"errorbar"!==u&&"errorband"!==u&&(e=u),e){const 
t={op:e,as:g};c&&(t.field=c),o.push(t)}}else if(n.push(g),Ko(s)&&ln(f)){if(i.push({bin:f,field:c,as:g}),n.push(oa(s,{binSuffix:"end"})),Sa(s,l)&&n.push(oa(s,{binSuffix:"range"})),zt(l)){const e={field:`${g}_end`};a[`${l}2`]=e}h.bin="binned",et(l)||(h.type=rr)}else if(d&&!Ci(d)){r.push({timeUnit:d,field:c,as:g});const e=Ko(s)&&s.type!==ar&&"time";e&&(l===Se||l===Oe?h.formatType=e:!function(e){return!!kt[e]}(l)?zt(l)&&(h.axis={formatType:e,...h.axis}):h.legend={formatType:e,...h.legend})}a[l]=h}else n.push(c),a[l]=e[l]}else a[l]=e[l]})),{bins:i,timeUnits:r,aggregate:o,groupby:n,encoding:a}}function Ua(e,t,n){const i=Vt(t,n);if(!i)return!1;if("binned"===i){const n=e[t===te?Z:ee];return!!(Ho(n)&&Ho(e[t])&&cn(n.bin))}return!0}function Ra(e,t){const n={};for(const i of D(e)){const r=ha(e[i],i,t,{compositeMark:!0});n[i]=r}return n}function Wa(e,n,i){if(e)for(const r of D(e)){const o=e[r];if(t.isArray(o))for(const e of o)n.call(i,e,r);else n.call(i,o,r)}}function Ba(e,n){return D(n).reduce(((i,r)=>{switch(r){case Z:case ee:case _e:case Ce:case Ne:case te:case ne:case ie:case re:case se:case le:case oe:case ae:case ce:case ue:case fe:case de:case Se:case he:case ve:case Oe:return i;case De:if("line"===e||"trail"===e)return i;case Fe:case ze:{const e=n[r];if(t.isArray(e)||Ho(e))for(const n of t.array(e))n.aggregate||i.push(oa(n,{}));return i}case ye:if("trail"===e)return i;case me:case pe:case ge:case be:case xe:case $e:case ke:case we:{const e=pa(n[r]);return e&&!e.aggregate&&i.push(oa(e,{})),i}}}),[])}function Ia(e,n,i){let r=!(arguments.length>3&&void 0!==arguments[3])||arguments[3];if("tooltip"in i)return{tooltip:i.tooltip};return{tooltip:[...e.map((e=>{let{fieldPrefix:t,titlePrefix:i}=e;const o=r?` of ${Ha(n)}`:"";return{field:t+n.field,type:n.type,title:yn(i)?{signal:`${i}"${escape(o)}"`}:i+o}})),...b(function(e){const n=[];for(const i of D(e))if(Ta(e,i)){const r=e[i],o=t.array(r);for(const e of o)Ho(e)?n.push(e):Bo(e)&&n.push(e.condition)}return 
n}(i).map(ra),d)]}}function Ha(e){const{title:t,field:n}=e;return U(t,n)}function Va(e,n,i,r,o){const{scale:a,axis:s}=i;return l=>{let{partName:c,mark:u,positionPrefix:f,endPositionPrefix:d,extraEncoding:m={}}=l;const p=Ha(i);return Ga(e,c,o,{mark:u,encoding:{[n]:{field:`${f}_${i.field}`,type:i.type,...void 0!==p?{title:p}:{},...void 0!==a?{scale:a}:{},...void 0!==s?{axis:s}:{}},...t.isString(d)?{[`${n}2`]:{field:`${d}_${i.field}`}}:{},...r,...m}})}}function Ga(e,n,i,r){const{clip:o,color:a,opacity:s}=e,l=e.type;return e[n]||void 0===e[n]&&i[n]?[{...r,mark:{...i[n],...o?{clip:o}:{},...a?{color:a}:{},...s?{opacity:s}:{},...Zr(r.mark)?r.mark:{type:r.mark},style:`${l}-${String(n)}`,...t.isBoolean(e[n])?{}:e[n]}}]:[]}function Ya(e,t,n){const{encoding:i}=e,r="vertical"===t?"y":"x",o=i[r],a=i[`${r}2`],s=i[`${r}Error`],l=i[`${r}Error2`];return{continuousAxisChannelDef:Xa(o,n),continuousAxisChannelDef2:Xa(a,n),continuousAxisChannelDefError:Xa(s,n),continuousAxisChannelDefError2:Xa(l,n),continuousAxis:r}}function Xa(e,t){if(e?.aggregate){const{aggregate:n,...i}=e;return n!==t&&$i(function(e,t){return`Continuous axis should not have customized aggregation function ${e}; ${t} already agregates the axis.`}(n,t)),i}return e}function Qa(e,t){const{mark:n,encoding:i}=e,{x:r,y:o}=i;if(Zr(n)&&n.orient)return n.orient;if(Yo(r)){if(Yo(o)){const e=Ho(r)&&r.aggregate,n=Ho(o)&&o.aggregate;if(e||n!==t){if(n||e!==t){if(e===t&&n===t)throw new Error("Both x and y cannot have aggregate");return $a(o)&&!$a(r)?"horizontal":"vertical"}return"horizontal"}return"vertical"}return"horizontal"}if(Yo(o))return"vertical";throw new Error(`Need a valid continuous axis for ${t}s`)}const Ja="boxplot",Ka=new ja(Ja,es);function Za(e){return t.isNumber(e)?"tukey":e}function es(e,n){let{config:i}=n;e={...e,encoding:Ra(e.encoding,i)};const{mark:r,encoding:o,params:a,projection:s,...l}=e,c=Zr(r)?r:{type:r};a&&$i(Jn("boxplot"));const 
u=c.extent??i.boxplot.extent,d=Cn("size",c,i),m=c.invalid,p=Za(u),{bins:g,timeUnits:h,transform:y,continuousAxisChannelDef:v,continuousAxis:b,groupby:x,aggregate:$,encodingWithoutContinuousAxis:w,ticksOrient:k,boxOrient:D,customTooltipWithoutAggregatedField:F}=function(e,n,i){const r=Qa(e,Ja),{continuousAxisChannelDef:o,continuousAxis:a}=Ya(e,r,Ja),s=o.field,l=Za(n),c=[...ts(s),{op:"median",field:s,as:`mid_box_${s}`},{op:"min",field:s,as:("min-max"===l?"lower_whisker_":"min_")+s},{op:"max",field:s,as:("min-max"===l?"upper_whisker_":"max_")+s}],u="min-max"===l||"tukey"===l?[]:[{calculate:`datum["upper_box_${s}"] - datum["lower_box_${s}"]`,as:`iqr_${s}`},{calculate:`min(datum["upper_box_${s}"] + datum["iqr_${s}"] * ${n}, datum["max_${s}"])`,as:`upper_whisker_${s}`},{calculate:`max(datum["lower_box_${s}"] - datum["iqr_${s}"] * ${n}, datum["min_${s}"])`,as:`lower_whisker_${s}`}],{[a]:f,...d}=e.encoding,{customTooltipWithoutAggregatedField:m,filteredEncoding:p}=function(e){const{tooltip:n,...i}=e;if(!n)return{filteredEncoding:i};let r,o;if(t.isArray(n)){for(const e of n)e.aggregate?(r||(r=[]),r.push(e)):(o||(o=[]),o.push(e));r&&(i.tooltip=r)}else n.aggregate?i.tooltip=n:o=n;return 
t.isArray(o)&&1===o.length&&(o=o[0]),{customTooltipWithoutAggregatedField:o,filteredEncoding:i}}(d),{bins:g,timeUnits:h,aggregate:y,groupby:v,encoding:b}=qa(p,i),x="vertical"===r?"horizontal":"vertical",$=r,w=[...g,...h,{aggregate:[...y,...c],groupby:v},...u];return{bins:g,timeUnits:h,transform:w,groupby:v,aggregate:y,continuousAxisChannelDef:o,continuousAxis:a,encodingWithoutContinuousAxis:b,ticksOrient:x,boxOrient:$,customTooltipWithoutAggregatedField:m}}(e,u,i),{color:z,size:O,..._}=w,N=e=>Va(c,b,v,e,i.boxplot),C=N(_),P=N(w),A=N({..._,...O?{size:O}:{}}),j=Ia([{fieldPrefix:"min-max"===p?"upper_whisker_":"max_",titlePrefix:"Max"},{fieldPrefix:"upper_box_",titlePrefix:"Q3"},{fieldPrefix:"mid_box_",titlePrefix:"Median"},{fieldPrefix:"lower_box_",titlePrefix:"Q1"},{fieldPrefix:"min-max"===p?"lower_whisker_":"min_",titlePrefix:"Min"}],v,w),T={type:"tick",color:"black",opacity:1,orient:k,invalid:m,aria:!1},E="min-max"===p?j:Ia([{fieldPrefix:"upper_whisker_",titlePrefix:"Upper Whisker"},{fieldPrefix:"lower_whisker_",titlePrefix:"Lower 
Whisker"}],v,w),M=[...C({partName:"rule",mark:{type:"rule",invalid:m,aria:!1},positionPrefix:"lower_whisker",endPositionPrefix:"lower_box",extraEncoding:E}),...C({partName:"rule",mark:{type:"rule",invalid:m,aria:!1},positionPrefix:"upper_box",endPositionPrefix:"upper_whisker",extraEncoding:E}),...C({partName:"ticks",mark:T,positionPrefix:"lower_whisker",extraEncoding:E}),...C({partName:"ticks",mark:T,positionPrefix:"upper_whisker",extraEncoding:E})],L=[..."tukey"!==p?M:[],...P({partName:"box",mark:{type:"bar",...d?{size:d}:{},orient:D,invalid:m,ariaRoleDescription:"box"},positionPrefix:"lower_box",endPositionPrefix:"upper_box",extraEncoding:j}),...A({partName:"median",mark:{type:"tick",invalid:m,...t.isObject(i.boxplot.median)&&i.boxplot.median.color?{color:i.boxplot.median.color}:{},...d?{size:d}:{},orient:k,aria:!1},positionPrefix:"mid_box",extraEncoding:j})];if("min-max"===p)return{...l,transform:(l.transform??[]).concat(y),layer:L};const q=`datum["lower_box_${v.field}"]`,U=`datum["upper_box_${v.field}"]`,R=`(${U} - ${q})`,W=`${q} - ${u} * ${R}`,B=`${U} + ${u} * ${R}`,I=`datum["${v.field}"]`,H={joinaggregate:ts(v.field),groupby:x},V={transform:[{filter:`(${W} <= ${I}) && (${I} <= ${B})`},{aggregate:[{op:"min",field:v.field,as:`lower_whisker_${v.field}`},{op:"max",field:v.field,as:`upper_whisker_${v.field}`},{op:"min",field:`lower_box_${v.field}`,as:`lower_box_${v.field}`},{op:"max",field:`upper_box_${v.field}`,as:`upper_box_${v.field}`},...$],groupby:x}],layer:M},{tooltip:G,...Y}=_,{scale:X,axis:Q}=v,J=Ha(v),K=f(Q,["title"]),Z=Ga(c,"outliers",i.boxplot,{transform:[{filter:`(${I} < ${W}) || (${I} > ${B})`}],mark:"point",encoding:{[b]:{field:v.field,type:v.type,...void 0!==J?{title:J}:{},...void 0!==X?{scale:X}:{},...S(K)?{}:{axis:K}},...Y,...z?{color:z}:{},...F?{tooltip:F}:{}}})[0];let ee;const te=[...g,...h,H];return Z?ee={transform:te,layer:[Z,V]}:(ee=V,ee.transform.unshift(...te)),{...l,layer:[ee,{transform:y,layer:L}]}}function 
ts(e){return[{op:"q1",field:e,as:`lower_box_${e}`},{op:"q3",field:e,as:`upper_box_${e}`}]}const ns="errorbar",is=new ja(ns,rs);function rs(e,t){let{config:n}=t;e={...e,encoding:Ra(e.encoding,n)};const{transform:i,continuousAxisChannelDef:r,continuousAxis:o,encodingWithoutContinuousAxis:a,ticksOrient:s,markDef:l,outerSpec:c,tooltipEncoding:u}=as(e,ns,n);delete a.size;const f=Va(l,o,r,a,n.errorbar),d=l.thickness,m=l.size,p={type:"tick",orient:s,aria:!1,...void 0!==d?{thickness:d}:{},...void 0!==m?{size:m}:{}},g=[...f({partName:"ticks",mark:p,positionPrefix:"lower",extraEncoding:u}),...f({partName:"ticks",mark:p,positionPrefix:"upper",extraEncoding:u}),...f({partName:"rule",mark:{type:"rule",ariaRoleDescription:"errorbar",...void 0!==d?{size:d}:{}},positionPrefix:"lower",endPositionPrefix:"upper",extraEncoding:u})];return{...c,transform:i,...g.length>1?{layer:g}:{...g[0]}}}function os(e,t){const{encoding:n}=e;if(function(e){return(Jo(e.x)||Jo(e.y))&&!Jo(e.x2)&&!Jo(e.y2)&&!Jo(e.xError)&&!Jo(e.xError2)&&!Jo(e.yError)&&!Jo(e.yError2)}(n))return{orient:Qa(e,t),inputType:"raw"};const i=function(e){return Jo(e.x2)||Jo(e.y2)}(n),r=function(e){return Jo(e.xError)||Jo(e.xError2)||Jo(e.yError)||Jo(e.yError2)}(n),o=n.x,a=n.y;if(i){if(r)throw new Error(`${t} cannot be both type aggregated-upper-lower and aggregated-error`);const e=n.x2,i=n.y2;if(Jo(e)&&Jo(i))throw new Error(`${t} cannot have both x2 and y2`);if(Jo(e)){if(Yo(o))return{orient:"horizontal",inputType:"aggregated-upper-lower"};throw new Error(`Both x and x2 have to be quantitative in ${t}`)}if(Jo(i)){if(Yo(a))return{orient:"vertical",inputType:"aggregated-upper-lower"};throw new Error(`Both y and y2 have to be quantitative in ${t}`)}throw new Error("No ranged axis")}{const e=n.xError,i=n.xError2,r=n.yError,s=n.yError2;if(Jo(i)&&!Jo(e))throw new Error(`${t} cannot have xError2 without xError`);if(Jo(s)&&!Jo(r))throw new Error(`${t} cannot have yError2 without yError`);if(Jo(e)&&Jo(r))throw new Error(`${t} cannot have 
both xError and yError with both are quantiative`);if(Jo(e)){if(Yo(o))return{orient:"horizontal",inputType:"aggregated-error"};throw new Error("All x, xError, and xError2 (if exist) have to be quantitative")}if(Jo(r)){if(Yo(a))return{orient:"vertical",inputType:"aggregated-error"};throw new Error("All y, yError, and yError2 (if exist) have to be quantitative")}throw new Error("No ranged axis")}}function as(e,t,n){const{mark:i,encoding:r,params:o,projection:a,...s}=e,l=Zr(i)?i:{type:i};o&&$i(Jn(t));const{orient:c,inputType:u}=os(e,t),{continuousAxisChannelDef:f,continuousAxisChannelDef2:d,continuousAxisChannelDefError:m,continuousAxisChannelDefError2:p,continuousAxis:g}=Ya(e,c,t),{errorBarSpecificAggregate:h,postAggregateCalculates:y,tooltipSummary:v,tooltipTitleWithFieldName:b}=function(e,t,n,i,r,o,a,s){let l=[],c=[];const u=t.field;let f,d=!1;if("raw"===o){const t=e.center?e.center:e.extent?"iqr"===e.extent?"median":"mean":s.errorbar.center,n=e.extent?e.extent:"mean"===t?"stderr":"iqr";if("median"===t!=("iqr"===n)&&$i(function(e,t,n){return`${e} is not usually used with ${t} for ${n}.`}(t,n,a)),"stderr"===n||"stdev"===n)l=[{op:n,field:u,as:`extent_${u}`},{op:t,field:u,as:`center_${u}`}],c=[{calculate:`datum["center_${u}"] + datum["extent_${u}"]`,as:`upper_${u}`},{calculate:`datum["center_${u}"] - datum["extent_${u}"]`,as:`lower_${u}`}],f=[{fieldPrefix:"center_",titlePrefix:P(t)},{fieldPrefix:"upper_",titlePrefix:ss(t,n,"+")},{fieldPrefix:"lower_",titlePrefix:ss(t,n,"-")}],d=!0;else{let 
e,t,i;"ci"===n?(e="mean",t="ci0",i="ci1"):(e="median",t="q1",i="q3"),l=[{op:t,field:u,as:`lower_${u}`},{op:i,field:u,as:`upper_${u}`},{op:e,field:u,as:`center_${u}`}],f=[{fieldPrefix:"upper_",titlePrefix:ua({field:u,aggregate:i,type:"quantitative"},s,{allowDisabling:!1})},{fieldPrefix:"lower_",titlePrefix:ua({field:u,aggregate:t,type:"quantitative"},s,{allowDisabling:!1})},{fieldPrefix:"center_",titlePrefix:ua({field:u,aggregate:e,type:"quantitative"},s,{allowDisabling:!1})}]}}else{(e.center||e.extent)&&$i((m=e.center,`${(p=e.extent)?"extent ":""}${p&&m?"and ":""}${m?"center ":""}${p&&m?"are ":"is "}not needed when data are aggregated.`)),"aggregated-upper-lower"===o?(f=[],c=[{calculate:`datum["${n.field}"]`,as:`upper_${u}`},{calculate:`datum["${u}"]`,as:`lower_${u}`}]):"aggregated-error"===o&&(f=[{fieldPrefix:"",titlePrefix:u}],c=[{calculate:`datum["${u}"] + datum["${i.field}"]`,as:`upper_${u}`}],r?c.push({calculate:`datum["${u}"] + datum["${r.field}"]`,as:`lower_${u}`}):c.push({calculate:`datum["${u}"] - datum["${i.field}"]`,as:`lower_${u}`}));for(const e of c)f.push({fieldPrefix:e.as.substring(0,6),titlePrefix:M(M(e.calculate,'datum["',""),'"]',"")})}var m,p;return{postAggregateCalculates:c,errorBarSpecificAggregate:l,tooltipSummary:f,tooltipTitleWithFieldName:d}}(l,f,d,m,p,u,t,n),{[g]:x,["x"===g?"x2":"y2"]:$,["x"===g?"xError":"yError"]:w,["x"===g?"xError2":"yError2"]:k,...S}=r,{bins:D,timeUnits:F,aggregate:z,groupby:O,encoding:_}=qa(S,n),N=[...z,...h],C="raw"!==u?[]:O,A=Ia(v,f,_,b);return{transform:[...s.transform??[],...D,...F,...0===N.length?[]:[{aggregate:N,groupby:C}],...y],groupby:C,continuousAxisChannelDef:f,continuousAxis:g,encodingWithoutContinuousAxis:_,ticksOrient:"vertical"===c?"horizontal":"vertical",markDef:l,outerSpec:s,tooltipEncoding:A}}function ss(e,t,n){return`${P(e)} ${n} ${t}`}const ls="errorband",cs=new ja(ls,us);function 
us(e,t){let{config:n}=t;e={...e,encoding:Ra(e.encoding,n)};const{transform:i,continuousAxisChannelDef:r,continuousAxis:o,encodingWithoutContinuousAxis:a,markDef:s,outerSpec:l,tooltipEncoding:c}=as(e,ls,n),u=s,f=Va(u,o,r,a,n.errorband),d=void 0!==e.encoding.x&&void 0!==e.encoding.y;let m={type:d?"area":"rect"},p={type:d?"line":"rule"};const g={...u.interpolate?{interpolate:u.interpolate}:{},...u.tension&&u.interpolate?{tension:u.tension}:{}};return d?(m={...m,...g,ariaRoleDescription:"errorband"},p={...p,...g,aria:!1}):u.interpolate?$i(yi("interpolate")):u.tension&&$i(yi("tension")),{...l,transform:i,layer:[...f({partName:"band",mark:m,positionPrefix:"lower",endPositionPrefix:"upper",extraEncoding:c}),...f({partName:"borders",mark:p,positionPrefix:"lower",extraEncoding:c}),...f({partName:"borders",mark:p,positionPrefix:"upper",extraEncoding:c})]}}const fs={};function ds(e,t,n){const i=new ja(e,t);fs[e]={normalizer:i,parts:n}}ds(Ja,es,["box","median","outliers","rule","ticks"]),ds(ns,rs,["ticks","rule"]),ds(ls,us,["band","borders"]);const 
ms=["gradientHorizontalMaxLength","gradientHorizontalMinLength","gradientVerticalMaxLength","gradientVerticalMinLength","unselectedOpacity"],ps={titleAlign:"align",titleAnchor:"anchor",titleAngle:"angle",titleBaseline:"baseline",titleColor:"color",titleFont:"font",titleFontSize:"fontSize",titleFontStyle:"fontStyle",titleFontWeight:"fontWeight",titleLimit:"limit",titleLineHeight:"lineHeight",titleOrient:"orient",titlePadding:"offset"},gs={labelAlign:"align",labelAnchor:"anchor",labelAngle:"angle",labelBaseline:"baseline",labelColor:"color",labelFont:"font",labelFontSize:"fontSize",labelFontStyle:"fontStyle",labelFontWeight:"fontWeight",labelLimit:"limit",labelLineHeight:"lineHeight",labelOrient:"orient",labelPadding:"offset"},hs=D(ps),ys=D(gs),vs=D({header:1,headerRow:1,headerColumn:1,headerFacet:1}),bs=["size","shape","fill","stroke","strokeDash","strokeWidth","opacity"],xs="_vgsid_",$s={point:{on:"click",fields:[xs],toggle:"event.shiftKey",resolve:"global",clear:"dblclick"},interval:{on:"[pointerdown, window:pointerup] > window:pointermove!",encodings:["x","y"],translate:"[pointerdown, window:pointerup] > window:pointermove!",zoom:"wheel!",mark:{fill:"#333",fillOpacity:.125,stroke:"white"},resolve:"global",clear:"dblclick"}};function ws(e){return"legend"===e||!!e?.legend}function ks(e){return ws(e)&&t.isObject(e)}function Ss(e){return!!e?.select}function Ds(e){const t=[];for(const n of e||[]){if(Ss(n))continue;const{expr:e,bind:i,...r}=n;if(i&&e){const n={...r,bind:i,init:e};t.push(n)}else{const n={...r,...e?{update:e}:{},...i?{bind:i}:{}};t.push(n)}}return t}function Fs(e){return"concat"in e}function zs(e){return"vconcat"in e}function Os(e){return"hconcat"in e}function _s(e){let{step:t,offsetIsDiscrete:n}=e;return n?t.for??"offset":"position"}function Ns(e){return t.isObject(e)&&void 0!==e.step}function Cs(e){return e.view||e.width||e.height}const Ps=D({align:1,bounds:1,center:1,columns:1,spacing:1});function As(e,t){return 
e[t]??e["width"===t?"continuousWidth":"continuousHeight"]}function js(e,t){const n=Ts(e,t);return Ns(n)?n.step:Es}function Ts(e,t){return U(e[t]??e["width"===t?"discreteWidth":"discreteHeight"],{step:e.step})}const Es=20,Ms={background:"white",padding:5,timeFormat:"%b %d, %Y",countTitle:"Count of Records",view:{continuousWidth:200,continuousHeight:200,step:Es},mark:{color:"#4c78a8",invalid:"filter",timeUnitBandSize:1},arc:{},area:{},bar:oo,circle:{},geoshape:{},image:{},line:{},point:{},rect:ao,rule:{color:"black"},square:{},text:{color:"black"},tick:{thickness:1},trail:{},boxplot:{size:14,extent:1.5,box:{},median:{color:"white"},outliers:{},rule:{},ticks:null},errorbar:{center:"mean",rule:!0,ticks:!1},errorband:{band:{opacity:.3},borders:!1},scale:{pointPadding:.5,barBandPaddingInner:.1,rectBandPaddingInner:0,bandWithNestedOffsetPaddingInner:.2,bandWithNestedOffsetPaddingOuter:.2,minBandSize:2,minFontSize:8,maxFontSize:40,minOpacity:.3,maxOpacity:.8,minSize:9,minStrokeWidth:1,maxStrokeWidth:4,quantileCount:4,quantizeCount:4,zero:!0},projection:{},legend:{gradientHorizontalMaxLength:200,gradientHorizontalMinLength:100,gradientVerticalMaxLength:200,gradientVerticalMinLength:64,unselectedOpacity:.35},header:{titlePadding:10,labelPadding:10},headerColumn:{},headerRow:{},headerFacet:{},selection:$s,style:{},title:{},facet:{spacing:20},concat:{spacing:20},normalizedNumberFormat:".0%"},Ls=["#4c78a8","#f58518","#e45756","#72b7b2","#54a24b","#eeca3b","#b279a2","#ff9da6","#9d755d","#bab0ac"],qs={text:11,guideLabel:10,guideTitle:11,groupTitle:13,groupSubtitle:12},Us={blue:Ls[0],orange:Ls[1],red:Ls[2],teal:Ls[3],green:Ls[4],yellow:Ls[5],purple:Ls[6],pink:Ls[7],brown:Ls[8],gray0:"#000",gray1:"#111",gray2:"#222",gray3:"#333",gray4:"#444",gray5:"#555",gray6:"#666",gray7:"#777",gray8:"#888",gray9:"#999",gray10:"#aaa",gray11:"#bbb",gray12:"#ccc",gray13:"#ddd",gray14:"#eee",gray15:"#fff"};function Rs(e){const t=D(e||{}),n={};for(const i of t){const 
t=e[i];n[i]=Fa(t)?kn(t):Sn(t)}return n}const Ws=[...no,...Pa,...vs,"background","padding","legend","lineBreak","scale","style","title","view"];function Bs(){let e=arguments.length>0&&void 0!==arguments[0]?arguments[0]:{};const{color:n,font:i,fontSize:r,selection:o,...a}=e,s=t.mergeConfig({},l(Ms),i?function(e){return{text:{font:e},style:{"guide-label":{font:e},"guide-title":{font:e},"group-title":{font:e},"group-subtitle":{font:e}}}}(i):{},n?function(){let e=arguments.length>0&&void 0!==arguments[0]?arguments[0]:{};return{signals:[{name:"color",value:t.isObject(e)?{...Us,...e}:Us}],mark:{color:{signal:"color.blue"}},rule:{color:{signal:"color.gray0"}},text:{color:{signal:"color.gray0"}},style:{"guide-label":{fill:{signal:"color.gray0"}},"guide-title":{fill:{signal:"color.gray0"}},"group-title":{fill:{signal:"color.gray0"}},"group-subtitle":{fill:{signal:"color.gray0"}},cell:{stroke:{signal:"color.gray8"}}},axis:{domainColor:{signal:"color.gray13"},gridColor:{signal:"color.gray8"},tickColor:{signal:"color.gray13"}},range:{category:[{signal:"color.blue"},{signal:"color.orange"},{signal:"color.red"},{signal:"color.teal"},{signal:"color.green"},{signal:"color.yellow"},{signal:"color.purple"},{signal:"color.pink"},{signal:"color.brown"},{signal:"color.grey8"}]}}}(n):{},r?function(e){return{signals:[{name:"fontSize",value:t.isObject(e)?{...qs,...e}:qs}],text:{fontSize:{signal:"fontSize.text"}},style:{"guide-label":{fontSize:{signal:"fontSize.guideLabel"}},"guide-title":{fontSize:{signal:"fontSize.guideTitle"}},"group-title":{fontSize:{signal:"fontSize.groupTitle"}},"group-subtitle":{fontSize:{signal:"fontSize.groupSubtitle"}}}}}(r):{},a||{});o&&t.writeConfig(s,"selection",o,!0);const c=f(s,Ws);for(const e of["background","lineBreak","padding"])s[e]&&(c[e]=Sn(s[e]));for(const e of no)s[e]&&(c[e]=pn(s[e]));for(const e of Pa)s[e]&&(c[e]=Rs(s[e]));for(const e of vs)s[e]&&(c[e]=pn(s[e]));return 
s.legend&&(c.legend=pn(s.legend)),s.scale&&(c.scale=pn(s.scale)),s.style&&(c.style=function(e){const t=D(e),n={};for(const i of t)n[i]=Rs(e[i]);return n}(s.style)),s.title&&(c.title=pn(s.title)),s.view&&(c.view=pn(s.view)),c}const Is=new Set(["view",...Kr]),Hs=["color","fontSize","background","padding","facet","concat","numberFormat","numberFormatType","normalizedNumberFormat","normalizedNumberFormatType","timeFormat","countTitle","header","axisQuantitative","axisTemporal","axisDiscrete","axisPoint","axisXBand","axisXPoint","axisXDiscrete","axisXQuantitative","axisXTemporal","axisYBand","axisYPoint","axisYDiscrete","axisYQuantitative","axisYTemporal","scale","selection","overlay"],Vs={view:["continuousWidth","continuousHeight","discreteWidth","discreteHeight","step"],area:["line","point"],bar:["binSpacing","continuousBandSize","discreteBandSize","minBandSize"],rect:["binSpacing","continuousBandSize","discreteBandSize","minBandSize"],line:["point"],tick:["bandSize","thickness"]};function Gs(e){e=l(e);for(const t of Hs)delete e[t];if(e.axis)for(const t in e.axis)Fa(e.axis[t])&&delete e.axis[t];if(e.legend)for(const t of ms)delete e.legend[t];if(e.mark){for(const t of to)delete e.mark[t];e.mark.tooltip&&t.isObject(e.mark.tooltip)&&delete e.mark.tooltip}e.params&&(e.signals=(e.signals||[]).concat(Ds(e.params)),delete e.params);for(const t of Is){for(const n of to)delete e[t][n];const n=Vs[t];if(n)for(const i of n)delete e[t][i];Ys(e,t)}for(const t of D(fs))delete e[t];!function(e){const{titleMarkConfig:t,subtitleMarkConfig:n,subtitle:i}=gn(e.title);S(t)||(e.style["group-title"]={...e.style["group-title"],...t});S(n)||(e.style["group-subtitle"]={...e.style["group-subtitle"],...n});S(i)?delete e.title:e.title=i}(e);for(const n in e)t.isObject(e[n])&&S(e[n])&&delete e[n];return S(e)?void 0:e}function Ys(e,t,n,i){"view"===t&&(n="cell");const r={...i?e[t][i]:e[t],...e.style[n??t]};S(r)||(e.style[n??t]=r),i||delete e[t]}function Xs(e){return"layer"in e}class 
Qs{map(e,t){return To(e)?this.mapFacet(e,t):function(e){return"repeat"in e}(e)?this.mapRepeat(e,t):Os(e)?this.mapHConcat(e,t):zs(e)?this.mapVConcat(e,t):Fs(e)?this.mapConcat(e,t):this.mapLayerOrUnit(e,t)}mapLayerOrUnit(e,t){if(Xs(e))return this.mapLayer(e,t);if(Aa(e))return this.mapUnit(e,t);throw new Error(Bn(e))}mapLayer(e,t){return{...e,layer:e.layer.map((e=>this.mapLayerOrUnit(e,t)))}}mapHConcat(e,t){return{...e,hconcat:e.hconcat.map((e=>this.map(e,t)))}}mapVConcat(e,t){return{...e,vconcat:e.vconcat.map((e=>this.map(e,t)))}}mapConcat(e,t){const{concat:n,...i}=e;return{...i,concat:n.map((e=>this.map(e,t)))}}mapFacet(e,t){return{...e,spec:this.map(e.spec,t)}}mapRepeat(e,t){return{...e,spec:this.map(e.spec,t)}}}const Js={zero:1,center:1,normalize:1};const Ks=new Set([Er,Lr,Mr,Br,Rr,Gr,Yr,Ur,Ir,Hr]),Zs=new Set([Lr,Mr,Er]);function el(e){return Ho(e)&&"quantitative"===Vo(e)&&!e.bin}function tl(e,t,n){let{orient:i,type:r}=n;const o="x"===t?"y":"radius",a="x"===t&&["bar","area"].includes(r),s=e[t],l=e[o];if(Ho(s)&&Ho(l))if(el(s)&&el(l)){if(s.stack)return t;if(l.stack)return o;const e=Ho(s)&&!!s.aggregate;if(e!==(Ho(l)&&!!l.aggregate))return e?t:o;if(a){if("vertical"===i)return o;if("horizontal"===i)return t}}else{if(el(s))return t;if(el(l))return o}else{if(el(s)){if(a&&"vertical"===i)return;return t}if(el(l)){if(a&&"horizontal"===i)return;return o}}}function nl(e,n){const i=Zr(e)?e:{type:e},r=i.type;if(!Ks.has(r))return null;const o=tl(n,"x",i)||tl(n,"theta",i);if(!o)return null;const a=n[o],s=Ho(a)?oa(a,{}):void 0,l=function(e){switch(e){case"x":return"y";case"y":return"x";case"theta":return"radius";case"radius":return"theta"}}(o),c=[],u=new Set;if(n[l]){const e=n[l],t=Ho(e)?oa(e,{}):void 0;t&&t!==s&&(c.push(l),u.add(t))}const f="x"===l?"xOffset":"yOffset",d=n[f],m=Ho(d)?oa(d,{}):void 0;m&&m!==s&&(c.push(f),u.add(m));const p=St.reduce(((e,i)=>{if("tooltip"!==i&&Ta(n,i)){const r=n[i];for(const n of t.array(r)){const t=pa(n);if(t.aggregate)continue;const 
r=oa(t,{});r&&u.has(r)||e.push({channel:i,fieldDef:t})}}return e}),[]);let g;return void 0!==a.stack?g=t.isBoolean(a.stack)?a.stack?"zero":null:a.stack:Zs.has(r)&&(g="zero"),g&&g in Js?La(n)&&0===p.length?null:a?.scale?.type&&a?.scale?.type!==cr.LINEAR?(a?.stack&&$i(function(e){return`Cannot stack non-linear scale (${e}).`}(a.scale.type)),null):Jo(n[it(o)])?(void 0!==a.stack&&$i(`Cannot stack "${h=o}" if there is already "${h}2".`),null):(Ho(a)&&a.aggregate&&!on.has(a.aggregate)&&$i(`Stacking is applied even though the aggregate function is non-summative ("${a.aggregate}").`),{groupbyChannels:c,groupbyFields:u,fieldChannel:o,impute:null!==a.impute&&Qr(r),stackBy:p,offset:g}):null;var h}function il(e,t,n){const i=pn(e),r=Cn("orient",i,n);if(i.orient=function(e,t,n){switch(e){case Rr:case Gr:case Yr:case Ir:case Wr:case qr:return}const{x:i,y:r,x2:o,y2:a}=t;switch(e){case Lr:if(Ho(i)&&(cn(i.bin)||Ho(r)&&r.aggregate&&!i.aggregate))return"vertical";if(Ho(r)&&(cn(r.bin)||Ho(i)&&i.aggregate&&!r.aggregate))return"horizontal";if(a||o){if(n)return n;if(!o)return(Ho(i)&&i.type===rr&&!ln(i.bin)||Qo(i))&&Ho(r)&&cn(r.bin)?"horizontal":"vertical";if(!a)return(Ho(r)&&r.type===rr&&!ln(r.bin)||Qo(r))&&Ho(i)&&cn(i.bin)?"vertical":"horizontal"}case Br:if(o&&(!Ho(i)||!cn(i.bin))&&a&&(!Ho(r)||!cn(r.bin)))return;case Mr:if(a)return Ho(r)&&cn(r.bin)?"horizontal":"vertical";if(o)return Ho(i)&&cn(i.bin)?"vertical":"horizontal";if(e===Br){if(i&&!r)return"vertical";if(r&&!i)return"horizontal"}case Ur:case Hr:{const t=Xo(i),o=Xo(r);if(n)return n;if(t&&!o)return"tick"!==e?"horizontal":"vertical";if(!t&&o)return"tick"!==e?"vertical":"horizontal";if(t&&o)return"vertical";{const e=Ko(i)&&i.type===ar,t=Ko(r)&&r.type===ar;if(e&&!t)return"vertical";if(!e&&t)return"horizontal"}return}}return"vertical"}(i.type,t,r),void 0!==r&&r!==i.orient&&$i(`Specified orient "${i.orient}" overridden with "${r}".`),"bar"===i.type&&i.orient){const e=Cn("cornerRadiusEnd",i,n);if(void 0!==e){const 
n="horizontal"===i.orient&&t.x2||"vertical"===i.orient&&t.y2?["cornerRadius"]:ro[i.orient];for(const t of n)i[t]=e;void 0!==i.cornerRadiusEnd&&delete i.cornerRadiusEnd}}void 0===Cn("opacity",i,n)&&(i.opacity=function(e,t){if(p([Rr,Hr,Gr,Yr],e)&&!La(t))return.7;return}(i.type,t));return void 0===Cn("cursor",i,n)&&(i.cursor=function(e,t,n){if(t.href||e.href||Cn("href",e,n))return"pointer";return e.cursor}(i,t,n)),i}function rl(e){const{point:t,line:n,...i}=e;return D(i).length>1?i:i.type}function ol(e){for(const t of["line","area","rule","trail"])e[t]&&(e={...e,[t]:f(e[t],["point","line"])});return e}function al(e){let n=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{},i=arguments.length>2?arguments[2]:void 0;return"transparent"===e.point?{opacity:0}:e.point?t.isObject(e.point)?e.point:{}:void 0!==e.point?null:n.point||i.shape?t.isObject(n.point)?n.point:{}:void 0}function sl(e){let t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{};return e.line?!0===e.line?{}:e.line:void 0!==e.line?null:t.line?!0===t.line?{}:t.line:void 0}class ll{constructor(){qn(this,"name","path-overlay")}hasMatchingType(e,t){if(Aa(e)){const{mark:n,encoding:i}=e,r=Zr(n)?n:{type:n};switch(r.type){case"line":case"rule":case"trail":return!!al(r,t[r.type],i);case"area":return!!al(r,t[r.type],i)||!!sl(r,t[r.type])}}return!1}run(e,t,n){const{config:i}=t,{params:r,projection:o,mark:a,name:s,encoding:l,...c}=e,d=Ra(l,i),m=Zr(a)?a:{type:a},p=al(m,i[m.type],d),g="area"===m.type&&sl(m,i[m.type]),h=[{name:s,...r?{params:r}:{},mark:rl({..."area"===m.type&&void 0===m.opacity&&void 0===m.fillOpacity?{opacity:.7}:{},...m}),encoding:f(d,["shape"])}],y=nl(il(m,d,i),d);let v=d;if(y){const{fieldChannel:e,offset:t}=y;v={...d,[e]:{...d[e],...t?{stack:t}:{}}}}return 
v=f(v,["y2","x2"]),g&&h.push({...o?{projection:o}:{},mark:{type:"line",...u(m,["clip","interpolate","tension","tooltip"]),...g},encoding:v}),p&&h.push({...o?{projection:o}:{},mark:{type:"point",opacity:1,filled:!0,...u(m,["clip","tooltip"]),...p},encoding:v}),n({...c,layer:h},{...t,config:ol(i)})}}function cl(e,t){return t?Ao(e)?gl(e,t):dl(e,t):e}function ul(e,t){return t?gl(e,t):e}function fl(e,n,i){const r=n[e];return(o=r)&&!t.isString(o)&&"repeat"in o?r.repeat in i?{...n,[e]:i[r.repeat]}:void $i(function(e){return`Unknown repeated value "${e}".`}(r.repeat)):n;var o}function dl(e,t){if(void 0!==(e=fl("field",e,t))){if(null===e)return null;if(Mo(e)&&Co(e.sort)){const n=fl("field",e.sort,t);e={...e,...n?{sort:n}:{}}}return e}}function ml(e,t){if(Ho(e))return dl(e,t);{const n=fl("datum",e,t);return n===e||n.type||(n.type="nominal"),n}}function pl(e,t){if(!Jo(e)){if(Io(e)){const n=ml(e.condition,t);if(n)return{...e,condition:n};{const{condition:t,...n}=e;return n}}return e}{const n=ml(e,t);if(n)return n;if(Wo(e))return{condition:e.condition}}}function gl(e,n){const i={};for(const r in e)if(t.hasOwnProperty(e,r)){const o=e[r];if(t.isArray(o))i[r]=o.map((e=>pl(e,n))).filter((e=>e));else{const e=pl(o,n);void 0!==e&&(i[r]=e)}}return i}class hl{constructor(){qn(this,"name","RuleForRangedLine")}hasMatchingType(e){if(Aa(e)){const{encoding:t,mark:n}=e;if("line"===n||Zr(n)&&"line"===n.type)for(const e of Ze){const n=t[tt(e)];if(t[e]&&(Ho(n)&&!cn(n.bin)||Go(n)))return!0}}return!1}run(e,n,i){const{encoding:r,mark:o}=e;var a,s;return $i((a=!!r.x2,s=!!r.y2,`Line mark is for continuous lines and thus cannot be used with ${a&&s?"x2 and y2":a?"x2":"y2"}. 
We will use the rule mark (line segments) instead.`)),i({...e,mark:t.isObject(o)?{...o,type:"rule"}:"rule"},n)}}function yl(e){let{parentEncoding:n,encoding:i={},layer:r}=e,o={};if(n){const e=new Set([...D(n),...D(i)]);for(const a of e){const e=i[a],s=n[a];if(Jo(e)){const t={...s,...e};o[a]=t}else Io(e)?o[a]={...e,condition:{...s,...e.condition}}:e||null===e?o[a]=e:(r||Zo(s)||yn(s)||Jo(s)||t.isArray(s))&&(o[a]=s)}}else o=i;return!o||S(o)?void 0:o}function vl(e){const{parentProjection:t,projection:n}=e;return t&&n&&$i(function(e){const{parentProjection:t,projection:n}=e;return`Layer's shared projection ${X(t)} is overridden by a child projection ${X(n)}.`}({parentProjection:t,projection:n})),n??t}function bl(e){return"filter"in e}function xl(e){return"lookup"in e}function $l(e){return"pivot"in e}function wl(e){return"density"in e}function kl(e){return"quantile"in e}function Sl(e){return"regression"in e}function Dl(e){return"loess"in e}function Fl(e){return"sample"in e}function zl(e){return"window"in e}function Ol(e){return"joinaggregate"in e}function _l(e){return"flatten"in e}function Nl(e){return"calculate"in e}function Cl(e){return"bin"in e}function Pl(e){return"impute"in e}function Al(e){return"timeUnit"in e}function jl(e){return"aggregate"in e}function Tl(e){return"stack"in e}function El(e){return"fold"in e}function Ml(e){return"extent"in e&&!("density"in e)}function Ll(e,t){const{transform:n,...i}=e;if(n){return{...i,transform:n.map((e=>{if(bl(e))return{filter:Rl(e,t)};if(Cl(e)&&un(e.bin))return{...e,bin:Ul(e.bin)};if(xl(e)){const{selection:t,...n}=e.from;return t?{...e,from:{param:t,...n}}:e}return e}))}}return e}function ql(e,n){const i=l(e);if(Ho(i)&&un(i.bin)&&(i.bin=Ul(i.bin)),ea(i)&&i.scale?.domain?.selection){const{selection:e,...t}=i.scale.domain;i.scale.domain={...t,...e?{param:e}:{}}}if(Wo(i))if(t.isArray(i.condition))i.condition=i.condition.map((e=>{const{selection:t,param:i,test:r,...o}=e;return 
i?e:{...o,test:Rl(e,n)}}));else{const{selection:e,param:t,test:r,...o}=ql(i.condition,n);i.condition=t?i.condition:{...o,test:Rl(i.condition,n)}}return i}function Ul(e){const t=e.extent;if(t?.selection){const{selection:n,...i}=t;return{...e,extent:{...i,param:n}}}return e}function Rl(e,t){const n=e=>s(e,(e=>{const n={param:e,empty:t.emptySelections[e]??!0};return t.selectionPredicates[e]??=[],t.selectionPredicates[e].push(n),n}));return e.selection?n(e.selection):s(e.test||e.filter,(e=>e.selection?n(e.selection):e))}class Wl extends Qs{map(e,t){const n=t.selections??[];if(e.params&&!Aa(e)){const t=[];for(const i of e.params)Ss(i)?n.push(i):t.push(i);e.params=t}return t.selections=n,super.map(e,t)}mapUnit(e,n){const i=n.selections;if(!i||!i.length)return e;const r=(n.path??[]).concat(e.name),o=[];for(const n of i)if(n.views&&n.views.length)for(const i of n.views)(t.isString(i)&&(i===e.name||r.includes(i))||t.isArray(i)&&i.map((e=>r.indexOf(e))).every(((e,t,n)=>-1!==e&&(0===t||e>n[t-1]))))&&o.push(n);else o.push(n);return o.length&&(e.params=o),e}}for(const e of["mapFacet","mapRepeat","mapHConcat","mapVConcat","mapLayer"]){const t=Wl.prototype[e];Wl.prototype[e]=function(e,n){return t.call(this,e,Bl(e,n))}}function Bl(e,t){return e.name?{...t,path:(t.path??[]).concat(e.name)}:t}function Il(e,t){void 0===t&&(t=Bs(e.config));const n=function(e){let t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{};const n={config:t};return Gl.map(Hl.map(Vl.map(e,n),n),n)}(e,t),{width:i,height:r}=e,o=function(e,t,n){let{width:i,height:r}=t;const o=Aa(e)||Xs(e),a={};o?"container"==i&&"container"==r?(a.type="fit",a.contains="padding"):"container"==i?(a.type="fit-x",a.contains="padding"):"container"==r&&(a.type="fit-y",a.contains="padding"):("container"==i&&($i(Hn("width")),i=void 0),"container"==r&&($i(Hn("height")),r=void 0));const 
s={type:"pad",...a,...n?Yl(n.autosize):{},...Yl(e.autosize)};"fit"!==s.type||o||($i(In),s.type="pad");"container"==i&&"fit"!=s.type&&"fit-x"!=s.type&&$i(Vn("width"));"container"==r&&"fit"!=s.type&&"fit-y"!=s.type&&$i(Vn("height"));if(Y(s,{type:"pad"}))return;return s}(n,{width:i,height:r,autosize:e.autosize},t);return{...n,...o?{autosize:o}:{}}}const Hl=new class extends Qs{constructor(){super(...arguments),qn(this,"nonFacetUnitNormalizers",[Ka,is,cs,new ll,new hl])}map(e,t){if(Aa(e)){const n=Ta(e.encoding,Q),i=Ta(e.encoding,J),r=Ta(e.encoding,K);if(n||i||r)return this.mapFacetedUnit(e,t)}return super.map(e,t)}mapUnit(e,t){const{parentEncoding:n,parentProjection:i}=t,r=ul(e.encoding,t.repeater),o={...e,...e.name?{name:[t.repeaterPrefix,e.name].filter((e=>e)).join("_")}:{},...r?{encoding:r}:{}};if(n||i)return this.mapUnitWithParentEncodingOrProjection(o,t);const a=this.mapLayerOrUnit.bind(this);for(const e of this.nonFacetUnitNormalizers)if(e.hasMatchingType(o,t.config))return e.run(o,t,a);return o}mapRepeat(e,n){return function(e){return!t.isArray(e.repeat)&&e.repeat.layer}(e)?this.mapLayerRepeat(e,n):this.mapNonLayerRepeat(e,n)}mapLayerRepeat(e,t){const{repeat:n,spec:i,...r}=e,{row:o,column:a,layer:s}=n,{repeater:l={},repeaterPrefix:c=""}=t;return o||a?this.mapRepeat({...e,repeat:{...o?{row:o}:{},...a?{column:a}:{}},spec:{repeat:{layer:s},spec:i}},t):{...r,layer:s.map((e=>{const n={...l,layer:e},r=`${(i.name?`${i.name}_`:"")+c}child__layer_${_(e)}`,o=this.mapLayerOrUnit(i,{...t,repeater:n,repeaterPrefix:r});return o.name=r,o}))}}mapNonLayerRepeat(e,n){const{repeat:i,spec:r,data:o,...a}=e;!t.isArray(i)&&e.columns&&(e=f(e,["columns"]),$i(Zn("repeat")));const s=[],{repeater:l={},repeaterPrefix:c=""}=n,u=!t.isArray(i)&&i.row||[l?l.row:null],d=!t.isArray(i)&&i.column||[l?l.column:null],m=t.isArray(i)&&i||[l?l.repeat:null];for(const e of m)for(const o of u)for(const a of d){const 
u={repeat:e,row:o,column:a,layer:l.layer},d=(r.name?`${r.name}_`:"")+c+"child__"+(t.isArray(i)?`${_(e)}`:(i.row?`row_${_(o)}`:"")+(i.column?`column_${_(a)}`:"")),m=this.map(r,{...n,repeater:u,repeaterPrefix:d});m.name=d,s.push(f(m,["data"]))}const p=t.isArray(i)?e.columns:i.column?i.column.length:1;return{data:r.data??o,align:"all",...a,columns:p,concat:s}}mapFacet(e,t){const{facet:n}=e;return Ao(n)&&e.columns&&(e=f(e,["columns"]),$i(Zn("facet"))),super.mapFacet(e,t)}mapUnitWithParentEncodingOrProjection(e,t){const{encoding:n,projection:i}=e,{parentEncoding:r,parentProjection:o,config:a}=t,s=vl({parentProjection:o,projection:i}),l=yl({parentEncoding:r,encoding:ul(n,t.repeater)});return this.mapUnit({...e,...s?{projection:s}:{},...l?{encoding:l}:{}},{config:a})}mapFacetedUnit(e,t){const{row:n,column:i,facet:r,...o}=e.encoding,{mark:a,width:s,projection:l,height:c,view:u,params:f,encoding:d,...m}=e,{facetMapping:p,layout:g}=this.getFacetMappingAndLayout({row:n,column:i,facet:r},t),h=ul(o,t.repeater);return this.mapFacet({...m,...g,facet:p,spec:{...s?{width:s}:{},...c?{height:c}:{},...u?{view:u}:{},...l?{projection:l}:{},mark:a,encoding:h,...f?{params:f}:{}}},t)}getFacetMappingAndLayout(e,t){const{row:n,column:i,facet:r}=e;if(n||i){r&&$i(`Facet encoding dropped as ${(o=[...n?[Q]:[],...i?[J]:[]]).join(" and ")} ${o.length>1?"are":"is"} also specified.`);const t={},a={};for(const n of[Q,J]){const i=e[n];if(i){const{align:e,center:r,spacing:o,columns:s,...l}=i;t[n]=l;for(const e of["align","center","spacing"])void 0!==i[e]&&(a[e]??={},a[e][n]=i[e])}}return{facetMapping:t,layout:a}}{const{align:e,center:n,spacing:i,columns:o,...a}=r;return{facetMapping:cl(a,t.repeater),layout:{...e?{align:e}:{},...n?{center:n}:{},...i?{spacing:i}:{},...o?{columns:o}:{}}}}var 
o}mapLayer(e,t){let{parentEncoding:n,parentProjection:i,...r}=t;const{encoding:o,projection:a,...s}=e,l={...r,parentEncoding:yl({parentEncoding:n,encoding:o,layer:!0}),parentProjection:vl({parentProjection:i,projection:a})};return super.mapLayer({...s,...e.name?{name:[l.repeaterPrefix,e.name].filter((e=>e)).join("_")}:{}},l)}},Vl=new class extends Qs{map(e,t){return t.emptySelections??={},t.selectionPredicates??={},e=Ll(e,t),super.map(e,t)}mapLayerOrUnit(e,t){if((e=Ll(e,t)).encoding){const n={};for(const[i,r]of z(e.encoding))n[i]=ql(r,t);e={...e,encoding:n}}return super.mapLayerOrUnit(e,t)}mapUnit(e,t){const{selection:n,...i}=e;return n?{...i,params:z(n).map((e=>{let[n,i]=e;const{init:r,bind:o,empty:a,...s}=i;"single"===s.type?(s.type="point",s.toggle=!1):"multi"===s.type&&(s.type="point"),t.emptySelections[n]="none"!==a;for(const e of F(t.selectionPredicates[n]??{}))e.empty="none"!==a;return{name:n,value:r,select:s,bind:o}}))}:e}},Gl=new Wl;function Yl(e){return t.isString(e)?{type:e}:e??{}}const Xl=["background","padding"];function Ql(e,t){const n={};for(const t of Xl)e&&void 0!==e[t]&&(n[t]=Sn(e[t]));return t&&(n.params=e.params),n}class Jl{constructor(){let e=arguments.length>0&&void 0!==arguments[0]?arguments[0]:{},t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{};this.explicit=e,this.implicit=t}clone(){return new Jl(l(this.explicit),l(this.implicit))}combine(){return{...this.explicit,...this.implicit}}get(e){return U(this.explicit[e],this.implicit[e])}getWithExplicit(e){return void 0!==this.explicit[e]?{explicit:!0,value:this.explicit[e]}:void 0!==this.implicit[e]?{explicit:!1,value:this.implicit[e]}:{explicit:!1,value:void 0}}setWithExplicit(e,t){let{value:n,explicit:i}=t;void 0!==n&&this.set(e,n,i)}set(e,t,n){return delete this[n?"implicit":"explicit"][e],this[n?"explicit":"implicit"][e]=t,this}copyKeyFromSplit(e,t){let{explicit:n,implicit:i}=t;void 0!==n[e]?this.set(e,n[e],!0):void 0!==i[e]&&this.set(e,i[e],!1)}copyKeyFromObject(e,t){void 
0!==t[e]&&this.set(e,t[e],!0)}copyAll(e){for(const t of D(e.combine())){const n=e.getWithExplicit(t);this.setWithExplicit(t,n)}}}function Kl(e){return{explicit:!0,value:e}}function Zl(e){return{explicit:!1,value:e}}function ec(e){return(t,n,i,r)=>{const o=e(t.value,n.value);return o>0?t:o<0?n:tc(t,n,i,r)}}function tc(e,t,n,i){return e.explicit&&t.explicit&&$i(function(e,t,n,i){return`Conflicting ${t.toString()} property "${e.toString()}" (${X(n)} and ${X(i)}). Using ${X(n)}.`}(n,i,e.value,t.value)),e}function nc(e,t,n,i){let r=arguments.length>4&&void 0!==arguments[4]?arguments[4]:tc;return void 0===e||void 0===e.value?t:e.explicit&&!t.explicit?e:t.explicit&&!e.explicit?t:Y(e.value,t.value)?e:r(e,t,n,i)}class ic extends Jl{constructor(){let e=arguments.length>0&&void 0!==arguments[0]?arguments[0]:{},t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{},n=arguments.length>2&&void 0!==arguments[2]&&arguments[2];super(e,t),this.explicit=e,this.implicit=t,this.parseNothing=n}clone(){const e=super.clone();return e.parseNothing=this.parseNothing,e}}function rc(e){return"url"in e}function oc(e){return"values"in e}function ac(e){return"name"in e&&!rc(e)&&!oc(e)&&!sc(e)}function sc(e){return e&&(lc(e)||cc(e)||uc(e))}function lc(e){return"sequence"in e}function cc(e){return"sphere"in e}function uc(e){return"graticule"in e}let fc=function(e){return e[e.Raw=0]="Raw",e[e.Main=1]="Main",e[e.Row=2]="Row",e[e.Column=3]="Column",e[e.Lookup=4]="Lookup",e}({});function dc(e){const{signals:t,hasLegend:n,index:i,...r}=e;return r.field=E(r.field),r}function mc(e){let n=!(arguments.length>1&&void 0!==arguments[1])||arguments[1],i=arguments.length>2&&void 0!==arguments[2]?arguments[2]:t.identity;if(t.isArray(e)){const t=e.map((e=>mc(e,n,i)));return n?`[${t.join(", ")}]`:t}return wi(e)?i(n?Oi(e):function(e){const t=zi(e,!0);return e.utc?+new Date(Date.UTC(...t)):+new Date(...t)}(e)):n?i(X(e)):e}function pc(e,n){for(const i of F(e.component.selection??{})){const r=i.name;let 
o=`${r}${Au}, ${"global"===i.resolve?"true":`{unit: ${Mu(e)}}`}`;for(const t of Eu)t.defined(i)&&(t.signals&&(n=t.signals(e,i,n)),t.modifyExpr&&(o=t.modifyExpr(e,i,o)));n.push({name:r+ju,on:[{events:{signal:i.name+Au},update:`modify(${t.stringValue(i.name+Pu)}, ${o})`}]})}return yc(n)}function gc(e,n){if(e.component.selection&&D(e.component.selection).length){const i=t.stringValue(e.getName("cell"));n.unshift({name:"facet",value:{},on:[{events:t.parseSelector("pointermove","scope"),update:`isTuple(facet) ? facet : group(${i}).datum`}]})}return yc(n)}function hc(e,t){for(const n of F(e.component.selection??{}))for(const i of Eu)i.defined(n)&&i.marks&&(t=i.marks(e,n,t));return t}function yc(e){return e.map((e=>(e.on&&!e.on.length&&delete e.on,e)))}class vc{constructor(e,t){this.debugName=t,qn(this,"_children",[]),qn(this,"_parent",null),qn(this,"_hash",void 0),e&&(this.parent=e)}clone(){throw new Error("Cannot clone node")}get parent(){return this._parent}set parent(e){this._parent=e,e&&e.addChild(this)}get children(){return this._children}numChildren(){return this._children.length}addChild(e,t){this._children.includes(e)?$i("Attempt to add the same child twice."):void 0!==t?this._children.splice(t,0,e):this._children.push(e)}removeChild(e){const t=this._children.indexOf(e);return this._children.splice(t,1),t}remove(){let e=this._parent.removeChild(this);for(const t of this._children)t._parent=this._parent,this._parent.addChild(t,e++)}insertAsParentOf(e){const t=e.parent;t.removeChild(this),this.parent=t,e.parent=this}swapWithParent(){const e=this._parent,t=e.parent;for(const t of this._children)t.parent=e;this._children=[],e.removeChild(this);const n=e.parent.removeChild(e);this._parent=t,t.addChild(this,n),e.parent=this}}class bc extends vc{clone(){const e=new this.constructor;return 
e.debugName=`clone_${this.debugName}`,e._source=this._source,e._name=`clone_${this._name}`,e.type=this.type,e.refCounts=this.refCounts,e.refCounts[e._name]=0,e}constructor(e,t,n,i){super(e,t),this.type=n,this.refCounts=i,qn(this,"_source",void 0),qn(this,"_name",void 0),this._source=this._name=t,this.refCounts&&!(this._name in this.refCounts)&&(this.refCounts[this._name]=0)}dependentFields(){return new Set}producedFields(){return new Set}hash(){return void 0===this._hash&&(this._hash=`Output ${W()}`),this._hash}getSource(){return this.refCounts[this._name]++,this._source}isRequired(){return!!this.refCounts[this._name]}setSource(e){this._source=e}}function xc(e){return void 0!==e.as}function $c(e){return`${e}_end`}class wc extends vc{clone(){return new wc(null,l(this.timeUnits))}constructor(e,t){super(e),this.timeUnits=t}static makeFromEncoding(e,t){const n=t.reduceFieldDef(((e,n,i)=>{const{field:r,timeUnit:o}=n;if(o){let a;if(Ci(o)){if(xm(t)){const{mark:e,markDef:i,config:s}=t,l=Lo({fieldDef:n,markDef:i,config:s});(Jr(e)||l)&&(a={timeUnit:Ui(o),field:r})}}else a={as:oa(n,{forAs:!0}),field:r,timeUnit:o};if(xm(t)){const{mark:e,markDef:r,config:o}=t,s=Lo({fieldDef:n,markDef:r,config:o});Jr(e)&&zt(i)&&.5!==s&&(a.rectBandPosition=s)}a&&(e[d(a)]=a)}return e}),{});return S(n)?null:new wc(e,n)}static makeFromTransform(e,t){const{timeUnit:n,...i}={...t},r={...i,timeUnit:Ui(n)};return new wc(e,{[d(r)]:r})}merge(e){this.timeUnits={...this.timeUnits};for(const t in e.timeUnits)this.timeUnits[t]||(this.timeUnits[t]=e.timeUnits[t]);for(const t of e.children)e.removeChild(t),t.parent=this;e.remove()}removeFormulas(e){const t={};for(const[n,i]of z(this.timeUnits)){const r=xc(i)?i.as:`${i.field}_end`;e.has(r)||(t[n]=i)}this.timeUnits=t}producedFields(){return new Set(F(this.timeUnits).map((e=>xc(e)?e.as:$c(e.field))))}dependentFields(){return new Set(F(this.timeUnits).map((e=>e.field)))}hash(){return`TimeUnit ${d(this.timeUnits)}`}assemble(){const e=[];for(const t of 
F(this.timeUnits)){const{rectBandPosition:n}=t,i=Ui(t.timeUnit);if(xc(t)){const{field:r,as:o}=t,{unit:a,utc:s,...l}=i,c=[o,`${o}_end`];e.push({field:E(r),type:"timeunit",...a?{units:Ti(a)}:{},...s?{timezone:"utc"}:{},...l,as:c}),e.push(...Fc(c,n,i))}else if(t){const{field:r}=t,o=r.replaceAll("\\.","."),a=Dc({timeUnit:i,field:o}),s=$c(o);e.push({type:"formula",expr:a,as:s}),e.push(...Fc([o,s],n,i))}}return e}}const kc="offsetted_rect_start",Sc="offsetted_rect_end";function Dc(e){let{timeUnit:t,field:n,reverse:i}=e;const{unit:r,utc:o}=t,a=Ei(r),{part:s,step:l}=Bi(a,t.step);return`${o?"utcOffset":"timeOffset"}('${s}', datum['${n}'], ${i?-l:l})`}function Fc(e,t,n){let[i,r]=e;if(void 0!==t&&.5!==t){const e=`datum['${i}']`,o=`datum['${r}']`;return[{type:"formula",expr:zc([Dc({timeUnit:n,field:i,reverse:!0}),e],t+.5),as:`${i}_${kc}`},{type:"formula",expr:zc([e,o],t+.5),as:`${i}_${Sc}`}]}return[]}function zc(e,t){let[n,i]=e;return`${1-t} * ${n} + ${t} * ${i}`}const Oc="_tuple_fields";class _c{constructor(){qn(this,"hasChannel",void 0),qn(this,"hasField",void 0),qn(this,"hasSelectionId",void 0),qn(this,"timeUnit",void 0),qn(this,"items",void 0);for(var e=arguments.length,t=new Array(e),n=0;n!0,parse:(e,n,i)=>{const r=n.name,o=n.project??=new _c,a={},s={},l=new Set,c=(e,t)=>{const n="visual"===t?e.channel:e.field;let i=_(`${r}_${n}`);for(let e=1;l.has(i);e++)i=_(`${r}_${n}_${e}`);return l.add(i),{[t]:i}},u=n.type,f=e.config.selection[u],m=void 0!==i.value?t.array(i.value):null;let{fields:p,encodings:g}=t.isObject(i.select)?i.select:{};if(!p&&!g&&m)for(const e of m)if(t.isObject(e))for(const t of D(e))Je[t]?(g||(g=[])).push(t):"interval"===u?($i('Interval selections should be initialized using "x", "y", "longitude", or "latitude" keys.'),g=f.encodings):(p??=[]).push(t);p||g||(g=f.encodings,"fields"in f&&(p=f.fields));for(const t of g??[]){const n=e.fieldDef(t);if(n){let 
i=n.field;if(n.aggregate){$i(Qn(t,n.aggregate));continue}if(!i){$i(Xn(t));continue}if(n.timeUnit&&!Ci(n.timeUnit)){i=e.vgField(t);const r={timeUnit:n.timeUnit,as:i,field:n.field};s[d(r)]=r}if(!a[i]){const r={field:i,channel:t,type:"interval"===u&&Ht(t)&&$r(e.getScaleComponent(t).get("type"))?"R":n.bin?"R-RE":"E",index:o.items.length};r.signals={...c(r,"data"),...c(r,"visual")},o.items.push(a[i]=r),o.hasField[i]=a[i],o.hasSelectionId=o.hasSelectionId||i===xs,Ee(t)?(r.geoChannel=t,r.channel=Te(t),o.hasChannel[r.channel]=a[i]):o.hasChannel[t]=a[i]}}else $i(Xn(t))}for(const e of p??[]){if(o.hasField[e])continue;const t={type:"E",field:e,index:o.items.length};t.signals={...c(t,"data")},o.items.push(t),o.hasField[e]=t,o.hasSelectionId=o.hasSelectionId||e===xs}m&&(n.init=m.map((e=>o.items.map((n=>t.isObject(e)?void 0!==e[n.geoChannel||n.channel]?e[n.geoChannel||n.channel]:e[n.field]:e))))),S(s)||(o.timeUnit=new wc(null,s))},signals:(e,t,n)=>{const i=t.name+Oc;return n.filter((e=>e.name===i)).length>0||t.project.hasSelectionId?n:n.concat({name:i,value:t.project.items.map(dc)})}},Cc={defined:e=>"interval"===e.type&&"global"===e.resolve&&e.bind&&"scales"===e.bind,parse:(e,t)=>{const n=t.scales=[];for(const i of t.project.items){const r=i.channel;if(!Ht(r))continue;const o=e.getScaleComponent(r),a=o?o.get("type"):void 0;o&&$r(a)?(o.set("selectionExtent",{param:t.name,field:i.field},!0),n.push(i)):$i("Scale bindings are currently only supported for scales with unbinned, continuous domains.")}},topLevelSignals:(e,n,i)=>{const r=n.scales.filter((e=>0===i.filter((t=>t.name===e.signals.data)).length));if(!e.parent||Ac(e)||0===r.length)return i;const o=i.filter((e=>e.name===n.name))[0];let a=o.update;if(a.indexOf(Tu)>=0)o.update=`{${r.map((e=>`${t.stringValue(E(e.field))}: ${e.signals.data}`)).join(", ")}}`;else{for(const e of r){const n=`${t.stringValue(E(e.field))}: ${e.signals.data}`;a.includes(n)||(a=`${a.substring(0,a.length-1)}, ${n}}`)}o.update=a}return 
i.concat(r.map((e=>({name:e.signals.data}))))},signals:(e,t,n)=>{if(e.parent&&!Ac(e))for(const e of t.scales){const t=n.filter((t=>t.name===e.signals.data))[0];t.push="outer",delete t.value,delete t.update}return n}};function Pc(e,n){return`domain(${t.stringValue(e.scaleName(n))})`}function Ac(e){return e.parent&&km(e.parent)&&(!e.parent.parent??Ac(e.parent.parent))}const jc="_brush",Tc="_scale_trigger",Ec="geo_interval_init_tick",Mc="_init",Lc={defined:e=>"interval"===e.type,parse:(e,n,i)=>{if(e.hasProjection){const e={...t.isObject(i.select)?i.select:{}};e.fields=[xs],e.encodings||(e.encodings=i.value?D(i.value):[ue,ce]),i.select={type:"interval",...e}}if(n.translate&&!Cc.defined(n)){const e=`!event.item || event.item.mark.name !== ${t.stringValue(n.name+jc)}`;for(const i of n.events){if(!i.between){$i(`${i} is not an ordered event stream for interval selections.`);continue}const n=t.array(i.between[0].filter??=[]);n.indexOf(e)<0&&n.push(e)}}},signals:(e,n,i)=>{const r=n.name,o=r+Au,a=F(n.project.hasChannel).filter((e=>e.channel===Z||e.channel===ee)),s=n.init?n.init[0]:null;if(i.push(...a.reduce(((i,r)=>i.concat(function(e,n,i,r){const o=!e.hasProjection,a=i.channel,s=i.signals.visual,l=t.stringValue(o?e.scaleName(a):e.projectionName()),c=e=>`scale(${l}, ${e})`,u=e.getSizeSignalRef(a===Z?"width":"height").signal,f=`${a}(unit)`,d=n.events.reduce(((e,t)=>[...e,{events:t.between[0],update:`[${f}, ${f}]`},{events:t,update:`[${s}[0], clamp(${f}, 0, ${u})]`}]),[]);if(o){const t=i.signals.data,o=Cc.defined(n),u=e.getScaleComponent(a),f=u?u.get("type"):void 0,m=r?{init:mc(r,!0,c)}:{value:[]};return d.push({events:{signal:n.name+Tc},update:$r(f)?`[${c(`${t}[0]`)}, ${c(`${t}[1]`)}]`:"[0, 0]"}),o?[{name:t,on:[]}]:[{name:s,...m,on:d},{name:t,...r?{init:mc(r)}:{},on:[{events:{signal:s},update:`${s}[0] === ${s}[1] ? 
null : invert(${l}, ${s})`}]}]}{const e=a===Z?0:1,t=n.name+Mc;return[{name:s,...r?{init:`[${t}[0][${e}], ${t}[1][${e}]]`}:{value:[]},on:d}]}}(e,n,r,s&&s[r.index]))),[])),e.hasProjection){const l=t.stringValue(e.projectionName()),c=e.projectionName()+"_center",{x:u,y:f}=n.project.hasChannel,d=u&&u.signals.visual,m=f&&f.signals.visual,p=u?s&&s[u.index]:`${c}[0]`,g=f?s&&s[f.index]:`${c}[1]`,h=t=>e.getSizeSignalRef(t).signal,y=`[[${d?d+"[0]":"0"}, ${m?m+"[0]":"0"}],[${d?d+"[1]":h("width")}, ${m?m+"[1]":h("height")}]]`;if(s&&(i.unshift({name:r+Mc,init:`[scale(${l}, [${u?p[0]:p}, ${f?g[0]:g}]), scale(${l}, [${u?p[1]:p}, ${f?g[1]:g}])]`}),!u||!f)){i.find((e=>e.name===c))||i.unshift({name:c,update:`invert(${l}, [${h("width")}/2, ${h("height")}/2])`})}const v=`vlSelectionTuples(${`intersect(${y}, {markname: ${t.stringValue(e.getName("marks"))}}, unit.mark)`}, ${`{unit: ${Mu(e)}}`})`,b=a.map((e=>e.signals.visual));return i.concat({name:o,on:[{events:[...b.length?[{signal:b.join(" || ")}]:[],...s?[{signal:Ec}]:[]],update:v}]})}{if(!Cc.defined(n)){const n=r+Tc,o=a.map((n=>{const i=n.channel,{data:r,visual:o}=n.signals,a=t.stringValue(e.scaleName(i)),s=$r(e.getScaleComponent(i).get("type"))?"+":"";return`(!isArray(${r}) || (${s}invert(${a}, ${o})[0] === ${s}${r}[0] && ${s}invert(${a}, ${o})[1] === ${s}${r}[1]))`}));o.length&&i.push({name:n,value:{},on:[{events:a.map((t=>({scale:e.scaleName(t.channel)}))),update:o.join(" && ")+` ? ${n} : {}`}]})}const l=a.map((e=>e.signals.data)),c=`unit: ${Mu(e)}, fields: ${r+Oc}, values`;return i.concat({name:o,...s?{init:`{${c}: ${mc(s)}}`}:{},...l.length?{on:[{events:[{signal:l.join(" || ")}],update:`${l.join(" && ")} ? {${c}: [${l}]} : null`}]}:{}})}},topLevelSignals:(e,t,n)=>{if(xm(e)&&e.hasProjection&&t.init){n.filter((e=>e.name===Ec)).length||n.unshift({name:Ec,value:null,on:[{events:"timer{1}",update:`${Ec} === null ? 
{} : ${Ec}`}]})}return n},marks:(e,n,i)=>{const r=n.name,{x:o,y:a}=n.project.hasChannel,s=o?.signals.visual,l=a?.signals.visual,c=`data(${t.stringValue(n.name+Pu)})`;if(Cc.defined(n)||!o&&!a)return i;const u={x:void 0!==o?{signal:`${s}[0]`}:{value:0},y:void 0!==a?{signal:`${l}[0]`}:{value:0},x2:void 0!==o?{signal:`${s}[1]`}:{field:{group:"width"}},y2:void 0!==a?{signal:`${l}[1]`}:{field:{group:"height"}}};if("global"===n.resolve)for(const t of D(u))u[t]=[{test:`${c}.length && ${c}[0].unit === ${Mu(e)}`,...u[t]},{value:0}];const{fill:f,fillOpacity:d,cursor:m,...p}=n.mark,g=D(p).reduce(((e,t)=>(e[t]=[{test:[void 0!==o&&`${s}[0] !== ${s}[1]`,void 0!==a&&`${l}[0] !== ${l}[1]`].filter((e=>e)).join(" && "),value:p[t]},{value:null}],e)),{});return[{name:`${r+jc}_bg`,type:"rect",clip:!0,encode:{enter:{fill:{value:f},fillOpacity:{value:d}},update:u}},...i,{name:r+jc,type:"rect",clip:!0,encode:{enter:{...m?{cursor:{value:m}}:{},fill:{value:"transparent"}},update:{...u,...g}}}]}};const qc={defined:e=>"point"===e.type,signals:(e,n,i)=>{const r=n.name,o=r+Oc,a=n.project,s="(item().isVoronoi ? datum.datum : datum)",l=F(e.component.selection??{}).reduce(((e,t)=>"interval"===t.type?e.concat(t.name+jc):e),[]).map((e=>`indexof(item().mark.name, '${e}') < 0`)).join(" && "),c="datum && item().mark.marktype !== 'group' && indexof(item().mark.role, 'legend') < 0"+(l?` && ${l}`:"");let u=`unit: ${Mu(e)}, `;if(n.project.hasSelectionId)u+=`${xs}: ${s}[${t.stringValue(xs)}]`;else{u+=`fields: ${o}, values: [${a.items.map((n=>{const i=e.fieldDef(n.channel);return i?.bin?`[${s}[${t.stringValue(e.vgField(n.channel,{}))}], ${s}[${t.stringValue(e.vgField(n.channel,{binSuffix:"end"}))}]]`:`${s}[${t.stringValue(n.field)}]`})).join(", ")}]`}const f=n.events;return i.concat([{name:r+Au,on:f?[{events:f,update:`${c} ? 
{${u}} : null`,force:!0}]:[]}])}};function Uc(e,n,i,r){const o=Wo(n)&&n.condition,a=r(n);if(o){return{[i]:[...t.array(o).map((t=>{const n=r(t);if(function(e){return e.param}(t)){const{param:i,empty:r}=t;return{test:Iu(e,{param:i,empty:r}),...n}}return{test:Vu(e,t.test),...n}})),...void 0!==a?[a]:[]]}}return void 0!==a?{[i]:a}:{}}function Rc(e){let t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:"text";const n=e.encoding[t];return Uc(e,n,t,(t=>Wc(t,e.config)))}function Wc(e,t){let n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:"datum";if(e){if(Zo(e))return Fn(e.value);if(Jo(e)){const{format:i,formatType:r}=ma(e);return vo({fieldOrDatumDef:e,format:i,formatType:r,expr:n,config:t})}}}function Bc(e){let n=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{};const{encoding:i,markDef:r,config:o,stack:a}=e,s=i.tooltip;if(t.isArray(s))return{tooltip:Hc({tooltip:s},a,o,n)};{const l=n.reactiveGeom?"datum.datum":"datum";return Uc(e,s,"tooltip",(e=>{const s=Wc(e,o,l);if(s)return s;if(null===e)return;let c=Cn("tooltip",r,o);return!0===c&&(c={content:"encoding"}),t.isString(c)?{value:c}:t.isObject(c)?yn(c)?c:"encoding"===c.content?Hc(i,a,o,n):{signal:l}:void 0}))}}function Ic(e,n,i){let{reactiveGeom:r}=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{};const o={...i,...i.tooltipFormat},a={},s=r?"datum.datum":"datum",l=[];function c(i,r){const c=tt(r),u=Ko(i)?i:{...i,type:e[c].type},f=u.title||da(u,o),d=t.array(f).join(", ").replaceAll(/"/g,'\\"');let m;if(zt(r)){const t="x"===r?"x2":"y2",n=pa(e[t]);if(cn(u.bin)&&n){const e=oa(u,{expr:s}),i=oa(n,{expr:s}),{format:r,formatType:l}=ma(u);m=Fo(e,i,r,l,o),a[t]=!0}}if((zt(r)||r===se||r===oe)&&n&&n.fieldChannel===r&&"normalize"===n.offset){const{format:e,formatType:t}=ma(u);m=vo({fieldOrDatumDef:u,format:e,formatType:t,expr:s,config:o,normalizeStack:!0}).signal}m??=Wc(u,o,s).signal,l.push({channel:r,key:d,value:m})}Wa(e,((e,t)=>{Ho(e)?c(e,t):Bo(e)&&c(e.condition,t)}));const 
u={};for(const{channel:e,key:t,value:n}of l)a[e]||u[t]||(u[t]=n);return u}function Hc(e,t,n){let{reactiveGeom:i}=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{};const r=Ic(e,t,n,{reactiveGeom:i}),o=z(r).map((e=>{let[t,n]=e;return`"${t}": ${n}`}));return o.length>0?{signal:`{${o.join(", ")}}`}:void 0}function Vc(e){const{markDef:t,config:n}=e,i=Cn("aria",t,n);return!1===i?{}:{...i?{aria:i}:{},...Gc(e),...Yc(e)}}function Gc(e){const{mark:t,markDef:n,config:i}=e;if(!1===i.aria)return{};const r=Cn("ariaRoleDescription",n,i);return null!=r?{ariaRoleDescription:{value:r}}:t in $n?{}:{ariaRoleDescription:{value:t}}}function Yc(e){const{encoding:t,markDef:n,config:i,stack:r}=e,o=t.description;if(o)return Uc(e,o,"description",(t=>Wc(t,e.config)));const a=Cn("description",n,i);if(null!=a)return{description:Fn(a)};if(!1===i.aria)return{};const s=Ic(t,r,i);return S(s)?void 0:{description:{signal:z(s).map(((e,t)=>{let[n,i]=e;return`"${t>0?"; ":""}${n}: " + (${i})`})).join(" + ")}}}function Xc(e,t){let n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:{};const{markDef:i,encoding:r,config:o}=t,{vgChannel:a}=n;let{defaultRef:s,defaultValue:l}=n;void 0===s&&(l??=Cn(e,i,o,{vgChannel:a,ignoreVgConfig:!0}),void 0!==l&&(s=Fn(l)));const c=r[e];return Uc(t,c,a??e,(n=>mo({channel:e,channelDef:n,markDef:i,config:o,scaleName:t.scaleName(e),scale:t.getScaleComponent(e),stack:null,defaultRef:s})))}function Qc(e){let t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{filled:void 0};const{markDef:n,encoding:i,config:r}=e,{type:o}=n,a=t.filled??Cn("filled",n,r),s=p(["bar","point","circle","square","geoshape"],o)?"transparent":void 0,l=Cn(!0===a?"color":void 0,n,r,{vgChannel:"fill"})??r.mark[!0===a&&"color"]??s,c=Cn(!1===a?"color":void 0,n,r,{vgChannel:"stroke"})??r.mark[!1===a&&"color"],u=a?"fill":"stroke",f={...l?{fill:Fn(l)}:{},...c?{stroke:Fn(c)}:{}};return n.color&&(a?n.fill:n.stroke)&&$i(ri("property",{fill:"fill"in n,stroke:"stroke"in 
n})),{...f,...Xc("color",e,{vgChannel:u,defaultValue:a?l:c}),...Xc("fill",e,{defaultValue:i.fill?l:void 0}),...Xc("stroke",e,{defaultValue:i.stroke?c:void 0})}}function Jc(e){const{encoding:t,mark:n}=e,i=t.order;return!Qr(n)&&Zo(i)?Uc(e,i,"zindex",(e=>Fn(e.value))):{}}function Kc(e){let{channel:t,markDef:n,encoding:i={},model:r,bandPosition:o}=e;const a=`${t}Offset`,s=n[a],l=i[a];if(("xOffset"===a||"yOffset"===a)&&l){return{offsetType:"encoding",offset:mo({channel:a,channelDef:l,markDef:n,config:r?.config,scaleName:r.scaleName(a),scale:r.getScaleComponent(a),stack:null,defaultRef:Fn(s),bandPosition:o})}}const c=n[a];return c?{offsetType:"visual",offset:c}:{}}function Zc(e,t,n){let{defaultPos:i,vgChannel:r}=n;const{encoding:o,markDef:a,config:s,stack:l}=t,c=o[e],u=o[it(e)],f=t.scaleName(e),d=t.getScaleComponent(e),{offset:m,offsetType:p}=Kc({channel:e,markDef:a,encoding:o,model:t,bandPosition:.5}),g=eu({model:t,defaultPos:i,channel:e,scaleName:f,scale:d}),h=!c&&zt(e)&&(o.latitude||o.longitude)?{field:t.getName(e)}:function(e){const{channel:t,channelDef:n,scaleName:i,stack:r,offset:o,markDef:a}=e;if(Jo(n)&&r&&t===r.fieldChannel){if(Ho(n)){let e=n.bandPosition;if(void 0!==e||"text"!==a.type||"radius"!==t&&"theta"!==t||(e=.5),void 0!==e)return fo({scaleName:i,fieldOrDatumDef:n,startSuffix:"start",bandPosition:e,offset:o})}return uo(n,i,{suffix:"end"},{offset:o})}return so(e)}({channel:e,channelDef:c,channel2Def:u,markDef:a,config:s,scaleName:f,scale:d,stack:l,offset:m,defaultRef:g,bandPosition:"encoding"===p?0:void 0});return h?{[r||e]:h}:void 0}function eu(e){let{model:t,defaultPos:n,channel:i,scaleName:r,scale:o}=e;const{markDef:a,config:s}=t;return()=>{const e=tt(i),l=nt(i),c=Cn(i,a,s,{vgChannel:l});if(void 0!==c)return po(i,c);switch(n){case"zeroOrMin":case"zeroOrMax":if(r){const e=o.get("type");if(p([cr.LOG,cr.TIME,cr.UTC],e));else 
if(o.domainDefinitelyIncludesZero())return{scale:r,value:0}}if("zeroOrMin"===n)return"y"===e?{field:{group:"height"}}:{value:0};switch(e){case"radius":return{signal:`min(${t.width.signal},${t.height.signal})/2`};case"theta":return{signal:"2*PI"};case"x":return{field:{group:"width"}};case"y":return{value:0}}break;case"mid":return{...t[rt(i)],mult:.5}}}}const tu={left:"x",center:"xc",right:"x2"},nu={top:"y",middle:"yc",bottom:"y2"};function iu(e,t,n){let i=arguments.length>3&&void 0!==arguments[3]?arguments[3]:"middle";if("radius"===e||"theta"===e)return nt(e);const r="x"===e?"align":"baseline",o=Cn(r,t,n);let a;return yn(o)?($i(function(e){return`The ${e} for range marks cannot be an expression`}(r)),a=void 0):a=o,"x"===e?tu[a||("top"===i?"left":"center")]:nu[a||i]}function ru(e,t,n){let{defaultPos:i,defaultPos2:r,range:o}=n;return o?ou(e,t,{defaultPos:i,defaultPos2:r}):Zc(e,t,{defaultPos:i})}function ou(e,t,n){let{defaultPos:i,defaultPos2:r}=n;const{markDef:o,config:a}=t,s=it(e),l=rt(e),c=function(e,t,n){const{encoding:i,mark:r,markDef:o,stack:a,config:s}=e,l=tt(n),c=rt(n),u=nt(n),f=i[l],d=e.scaleName(l),m=e.getScaleComponent(l),{offset:p}=Kc(n in i||n in o?{channel:n,markDef:o,encoding:i,model:e}:{channel:l,markDef:o,encoding:i,model:e});if(!f&&("x2"===n||"y2"===n)&&(i.latitude||i.longitude)){const t=rt(n),i=e.markDef[t];return null!=i?{[t]:{value:i}}:{[u]:{field:e.getName(n)}}}const g=function(e){let{channel:t,channelDef:n,channel2Def:i,markDef:r,config:o,scaleName:a,scale:s,stack:l,offset:c,defaultRef:u}=e;if(Jo(n)&&l&&t.charAt(0)===l.fieldChannel.charAt(0))return uo(n,a,{suffix:"start"},{offset:c});return so({channel:t,channelDef:i,scaleName:a,scale:s,stack:l,markDef:r,config:o,offset:c,defaultRef:u})}({channel:n,channelDef:f,channel2Def:i[n],markDef:o,config:s,scaleName:d,scale:m,stack:a,offset:p,defaultRef:void 0});if(void 0!==g)return{[u]:g};return 
au(n,o)||au(n,{[n]:An(n,o,s.style),[c]:An(c,o,s.style)})||au(n,s[r])||au(n,s.mark)||{[u]:eu({model:e,defaultPos:t,channel:n,scaleName:d,scale:m})()}}(t,r,s);return{...Zc(e,t,{defaultPos:i,vgChannel:c[l]?iu(e,o,a):nt(e)}),...c}}function au(e,t){const n=rt(e),i=nt(e);if(void 0!==t[i])return{[i]:po(e,t[i])};if(void 0!==t[e])return{[i]:po(e,t[e])};if(t[n]){const i=t[n];if(!io(i))return{[n]:po(e,i)};$i(function(e){return`Position range does not support relative band size for ${e}.`}(n))}}function su(e,n){const{config:i,encoding:r,markDef:o}=e,a=o.type,s=it(n),l=rt(n),c=r[n],u=r[s],f=e.getScaleComponent(n),d=f?f.get("type"):void 0,m=o.orient,p=r[l]??r.size??Cn("size",o,i,{vgChannel:l}),g=ot(n),h="bar"===a&&("x"===n?"vertical"===m:"horizontal"===m);return!Ho(c)||!(ln(c.bin)||cn(c.bin)||c.timeUnit&&!u)||p&&!io(p)||r[g]||xr(d)?(Jo(c)&&xr(d)||h)&&!u?function(e,n,i){const{markDef:r,encoding:o,config:a,stack:s}=i,l=r.orient,c=i.scaleName(n),u=i.getScaleComponent(n),f=rt(n),d=it(n),m=ot(n),p=i.scaleName(m),g=i.getScaleComponent(at(n)),h="horizontal"===l&&"y"===n||"vertical"===l&&"x"===n;let y;(o.size||r.size)&&(h?y=Xc("size",i,{vgChannel:f,defaultRef:Fn(r.size)}):$i(function(e){return`Cannot apply size to non-oriented mark "${e}".`}(r.type)));const v=!!y,b=qo({channel:n,fieldDef:e,markDef:r,config:a,scaleType:(u||g)?.get("type"),useVlSizeChannel:h});y=y||{[f]:lu(f,p||c,g||u,a,b,!!e,r.type)};const x="band"===(u||g)?.get("type")&&io(b)&&!v?"top":"middle",$=iu(n,r,a,x),w="xc"===$||"yc"===$,{offset:k,offsetType:S}=Kc({channel:n,markDef:r,encoding:o,model:i,bandPosition:w?.5:0}),D=so({channel:n,channelDef:e,markDef:r,config:a,scaleName:c,scale:u,stack:s,offset:k,defaultRef:eu({model:i,defaultPos:"mid",channel:n,scaleName:c,scale:u}),bandPosition:w?"encoding"===S?0:.5:yn(b)?{signal:`(1-${b})/2`}:io(b)?(1-b.band)/2:0});if(f)return{[$]:D,...y};{const 
e=nt(d),n=y[f],i=k?{...n,offset:k}:n;return{[$]:D,[e]:t.isArray(D)?[D[0],{...D[1],offset:i}]:{...D,offset:i}}}}(c,n,e):ou(n,e,{defaultPos:"zeroOrMax",defaultPos2:"zeroOrMin"}):function(e){let{fieldDef:t,fieldDef2:n,channel:i,model:r}=e;const{config:o,markDef:a,encoding:s}=r,l=r.getScaleComponent(i),c=r.scaleName(i),u=l?l.get("type"):void 0,f=l.get("reverse"),d=qo({channel:i,fieldDef:t,markDef:a,config:o,scaleType:u}),m=r.component.axes[i]?.[0],p=m?.get("translate")??.5,g=zt(i)?Cn("binSpacing",a,o)??0:0,h=it(i),y=nt(i),v=nt(h),b=Pn("minBandSize",a,o),{offset:x}=Kc({channel:i,markDef:a,encoding:s,model:r,bandPosition:0}),{offset:$}=Kc({channel:h,markDef:a,encoding:s,model:r,bandPosition:0}),w=function(e){let{scaleName:t,fieldDef:n}=e;const i=oa(n,{expr:"datum"});return`abs(scale("${t}", ${oa(n,{expr:"datum",suffix:"end"})}) - scale("${t}", ${i}))`}({fieldDef:t,scaleName:c}),k=cu(i,g,f,p,x,b,w),S=cu(h,g,f,p,$??x,b,w),D=yn(d)?{signal:`(1-${d.signal})/2`}:io(d)?(1-d.band)/2:.5,F=Lo({fieldDef:t,fieldDef2:n,markDef:a,config:o});if(ln(t.bin)||t.timeUnit){const e=t.timeUnit&&.5!==F;return{[v]:uu({fieldDef:t,scaleName:c,bandPosition:D,offset:S,useRectOffsetField:e}),[y]:uu({fieldDef:t,scaleName:c,bandPosition:yn(D)?{signal:`1-${D.signal}`}:1-D,offset:k,useRectOffsetField:e})}}if(cn(t.bin)){const e=uo(t,c,{},{offset:S});if(Ho(n))return{[v]:e,[y]:uo(n,c,{},{offset:k})};if(un(t.bin)&&t.bin.step)return{[v]:e,[y]:{signal:`scale("${c}", ${oa(t,{expr:"datum"})} + ${t.bin.step})`,offset:k}}}return void $i(vi(h))}({fieldDef:c,fieldDef2:u,channel:n,model:e})}function lu(e,n,i,r,o,a,s){if(io(o)){if(!i)return{mult:o.band,field:{group:e}};{const e=i.get("type");if("band"===e){let e=`bandwidth('${n}')`;1!==o.band&&(e=`${o.band} * ${e}`);const t=Pn("minBandSize",{type:s},r);return{signal:t?`max(${On(t)}, ${e})`:e}}1!==o.band&&($i(function(e){return`Cannot use the relative band size with ${e} scale.`}(e)),o=void 0)}}else{if(yn(o))return o;if(o)return{value:o}}if(i){const 
e=i.get("range");if(vn(e)&&t.isNumber(e.step))return{value:e.step-2}}if(!a){const{bandPaddingInner:n,barBandPaddingInner:i,rectBandPaddingInner:o}=r.scale,a=U(n,"bar"===s?i:o);if(yn(a))return{signal:`(1 - (${a.signal})) * ${e}`};if(t.isNumber(a))return{signal:`${1-a} * ${e}`}}return{value:js(r.view,e)-2}}function cu(e,t,n,i,r,o,a){if(Ae(e))return 0;const s="x"===e||"y2"===e,l=s?-t/2:t/2;if(yn(n)||yn(r)||yn(i)||o){const e=On(n),t=On(r),c=On(i),u=On(o),f=o?`(${a} < ${u} ? ${s?"":"-"}0.5 * (${u} - (${a})) : ${l})`:l;return{signal:(c?`${c} + `:"")+(e?`(${e} ? -1 : 1) * `:"")+(t?`(${t} + ${f})`:f)}}return r=r||0,i+(n?-r-l:+r+l)}function uu(e){let{fieldDef:t,scaleName:n,bandPosition:i,offset:r,useRectOffsetField:o}=e;return fo({scaleName:n,fieldOrDatumDef:t,bandPosition:i,offset:r,...o?{startSuffix:kc,endSuffix:Sc}:{}})}const fu=new Set(["aria","width","height"]);function du(e,t){const{fill:n,stroke:i}="include"===t.color?Qc(e):{};return{...pu(e.markDef,t),...mu(e,"fill",n),...mu(e,"stroke",i),...Xc("opacity",e),...Xc("fillOpacity",e),...Xc("strokeOpacity",e),...Xc("strokeWidth",e),...Xc("strokeDash",e),...Jc(e),...Bc(e),...Rc(e,"href"),...Vc(e)}}function mu(e,n,i){const{config:r,mark:o,markDef:a}=e;if("hide"===Cn("invalid",a,r)&&i&&!Qr(o)){const r=function(e,t){let{invalid:n=!1,channels:i}=t;const r=i.reduce(((t,n)=>{const i=e.getScaleComponent(n);if(i){const r=i.get("type"),o=e.vgField(n,{expr:"datum"});o&&$r(r)&&(t[o]=!0)}return t}),{}),o=D(r);if(o.length>0){const e=n?"||":"&&";return o.map((e=>co(e,n))).join(` ${e} `)}return}(e,{invalid:!0,channels:It});if(r)return{[n]:[{test:r,value:null},...t.array(i)]}}return i?{[n]:i}:{}}function pu(e,t){return xn.reduce(((n,i)=>(fu.has(i)||void 0===e[i]||"ignore"===t[i]||(n[i]=Fn(e[i])),n)),{})}function gu(e){const{config:t,markDef:n}=e;if(Cn("invalid",n,t)){const t=function(e,t){let{invalid:n=!1,channels:i}=t;const r=i.reduce(((t,n)=>{const i=e.getScaleComponent(n);if(i){const 
r=i.get("type"),o=e.vgField(n,{expr:"datum",binSuffix:e.stack?.impute?"mid":void 0});o&&$r(r)&&(t[o]=!0)}return t}),{}),o=D(r);if(o.length>0){const e=n?"||":"&&";return o.map((e=>co(e,n))).join(` ${e} `)}return}(e,{channels:Ft});if(t)return{defined:{signal:t}}}return{}}function hu(e,t){if(void 0!==t)return{[e]:Fn(t)}}const yu="voronoi",vu={defined:e=>"point"===e.type&&e.nearest,parse:(e,t)=>{if(t.events)for(const n of t.events)n.markname=e.getName(yu)},marks:(e,t,n)=>{const{x:i,y:r}=t.project.hasChannel,o=e.mark;if(Qr(o))return $i(`The "nearest" transform is not supported for ${o} marks.`),n;const a={name:e.getName(yu),type:"path",interactive:!0,from:{data:e.getName("marks")},encode:{update:{fill:{value:"transparent"},strokeWidth:{value:.35},stroke:{value:"transparent"},isVoronoi:{value:!0},...Bc(e,{reactiveGeom:!0})}},transform:[{type:"voronoi",x:{expr:i||!r?"datum.datum.x || 0":"0"},y:{expr:r||!i?"datum.datum.y || 0":"0"},size:[e.getSizeSignalRef("width"),e.getSizeSignalRef("height")]}]};let s=0,l=!1;return n.forEach(((t,n)=>{const i=t.name??"";i===e.component.mark[0].name?s=n:i.indexOf(yu)>=0&&(l=!0)})),l||n.splice(s+1,0,a),n}},bu={defined:e=>"point"===e.type&&"global"===e.resolve&&e.bind&&"scales"!==e.bind&&!ws(e.bind),parse:(e,t,n)=>qu(t,n),topLevelSignals:(e,n,i)=>{const r=n.name,o=n.project,a=n.bind,s=n.init&&n.init[0],l=vu.defined(n)?"(item().isVoronoi ? datum.datum : datum)":"datum";return o.items.forEach(((e,o)=>{const c=_(`${r}_${e.field}`);i.filter((e=>e.name===c)).length||i.unshift({name:c,...s?{init:mc(s[o])}:{value:null},on:n.events?[{events:n.events,update:`datum && item().mark.marktype !== 'group' ? ${l}[${t.stringValue(e.field)}] : null`}]:[],bind:a[e.field]??a[e.channel]??a})})),i},signals:(e,t,n)=>{const i=t.name,r=t.project,o=n.filter((e=>e.name===i+Au))[0],a=i+Oc,s=r.items.map((e=>_(`${i}_${e.field}`))),l=s.map((e=>`${e} !== null`)).join(" && ");return s.length&&(o.update=`${l} ? 
{fields: ${a}, values: [${s.join(", ")}]} : null`),delete o.value,delete o.on,n}},xu="_toggle",$u={defined:e=>"point"===e.type&&!!e.toggle,signals:(e,t,n)=>n.concat({name:t.name+xu,value:!1,on:[{events:t.events,update:t.toggle}]}),modifyExpr:(e,t)=>{const n=t.name+Au,i=t.name+xu;return`${i} ? null : ${n}, `+("global"===t.resolve?`${i} ? null : true, `:`${i} ? null : {unit: ${Mu(e)}}, `)+`${i} ? ${n} : null`}},wu={defined:e=>void 0!==e.clear&&!1!==e.clear,parse:(e,n)=>{n.clear&&(n.clear=t.isString(n.clear)?t.parseSelector(n.clear,"view"):n.clear)},topLevelSignals:(e,t,n)=>{if(bu.defined(t))for(const e of t.project.items){const i=n.findIndex((n=>n.name===_(`${t.name}_${e.field}`)));-1!==i&&n[i].on.push({events:t.clear,update:"null"})}return n},signals:(e,t,n)=>{function i(e,i){-1!==e&&n[e].on&&n[e].on.push({events:t.clear,update:i})}if("interval"===t.type)for(const e of t.project.items){const t=n.findIndex((t=>t.name===e.signals.visual));if(i(t,"[0, 0]"),-1===t){i(n.findIndex((t=>t.name===e.signals.data)),"null")}}else{let e=n.findIndex((e=>e.name===t.name+Au));i(e,"null"),$u.defined(t)&&(e=n.findIndex((e=>e.name===t.name+xu)),i(e,"false"))}return n}},ku={defined:e=>{const t="global"===e.resolve&&e.bind&&ws(e.bind),n=1===e.project.items.length&&e.project.items[0].field!==xs;return t&&!n&&$i("Legend bindings are only supported for selections over an individual field or encoding channel."),t&&n},parse:(e,n,i)=>{const r=l(i);if(r.select=t.isString(r.select)?{type:r.select,toggle:n.toggle}:{...r.select,toggle:n.toggle},qu(n,r),t.isObject(i.select)&&(i.select.on||i.select.clear)){const e='event.item && indexof(event.item.mark.role, "legend") < 0';for(const i of n.events)i.filter=t.array(i.filter??[]),i.filter.includes(e)||i.filter.push(e)}const o=ks(n.bind)?n.bind.legend:"click",a=t.isString(o)?t.parseSelector(o,"view"):t.array(o);n.bind={legend:{merge:a}}},topLevelSignals:(e,t,n)=>{const i=t.name,r=ks(t.bind)&&t.bind.legend,o=e=>t=>{const n=l(t);return 
n.markname=e,n};for(const e of t.project.items){if(!e.hasLegend)continue;const a=`${_(e.field)}_legend`,s=`${i}_${a}`;if(0===n.filter((e=>e.name===s)).length){const e=r.merge.map(o(`${a}_symbols`)).concat(r.merge.map(o(`${a}_labels`))).concat(r.merge.map(o(`${a}_entries`)));n.unshift({name:s,...t.init?{}:{value:null},on:[{events:e,update:"isDefined(datum.value) ? datum.value : item().items[0].items[0].datum.value",force:!0},{events:r.merge,update:`!event.item || !datum ? null : ${s}`,force:!0}]})}}return n},signals:(e,t,n)=>{const i=t.name,r=t.project,o=n.find((e=>e.name===i+Au)),a=i+Oc,s=r.items.filter((e=>e.hasLegend)).map((e=>_(`${i}_${_(e.field)}_legend`))),l=`${s.map((e=>`${e} !== null`)).join(" && ")} ? {fields: ${a}, values: [${s.join(", ")}]} : null`;t.events&&s.length>0?o.on.push({events:s.map((e=>({signal:e}))),update:l}):s.length>0&&(o.update=l,delete o.value,delete o.on);const c=n.find((e=>e.name===i+xu)),u=ks(t.bind)&&t.bind.legend;return c&&(t.events?c.on.push({...c.on[0],events:u}):c.on[0].events=u),n}};const Su="_translate_anchor",Du="_translate_delta",Fu={defined:e=>"interval"===e.type&&e.translate,signals:(e,n,i)=>{const r=n.name,o=Cc.defined(n),a=r+Su,{x:s,y:l}=n.project.hasChannel;let c=t.parseSelector(n.translate,"scope");return o||(c=c.map((e=>(e.between[0].markname=r+jc,e)))),i.push({name:a,value:{},on:[{events:c.map((e=>e.between[0])),update:"{x: x(unit), y: y(unit)"+(void 0!==s?`, extent_x: ${o?Pc(e,Z):`slice(${s.signals.visual})`}`:"")+(void 0!==l?`, extent_y: ${o?Pc(e,ee):`slice(${l.signals.visual})`}`:"")+"}"}]},{name:r+Du,value:{},on:[{events:c,update:`{x: ${a}.x - x(unit), y: ${a}.y - y(unit)}`}]}),void 0!==s&&zu(e,n,s,"width",i),void 0!==l&&zu(e,n,l,"height",i),i}};function zu(e,t,n,i,r){const 
o=t.name,a=o+Su,s=o+Du,l=n.channel,c=Cc.defined(t),u=r.filter((e=>e.name===n.signals[c?"data":"visual"]))[0],f=e.getSizeSignalRef(i).signal,d=e.getScaleComponent(l),m=d&&d.get("type"),p=d&&d.get("reverse"),g=`${a}.extent_${l}`,h=`${c&&d?"log"===m?"panLog":"symlog"===m?"panSymlog":"pow"===m?"panPow":"panLinear":"panLinear"}(${g}, ${`${c?l===Z?p?"":"-":p?"-":"":""}${s}.${l} / ${c?`${f}`:`span(${g})`}`}${c?"pow"===m?`, ${d.get("exponent")??1}`:"symlog"===m?`, ${d.get("constant")??1}`:"":""})`;u.on.push({events:{signal:s},update:c?h:`clampRange(${h}, 0, ${f})`})}const Ou="_zoom_anchor",_u="_zoom_delta",Nu={defined:e=>"interval"===e.type&&e.zoom,signals:(e,n,i)=>{const r=n.name,o=Cc.defined(n),a=r+_u,{x:s,y:l}=n.project.hasChannel,c=t.stringValue(e.scaleName(Z)),u=t.stringValue(e.scaleName(ee));let f=t.parseSelector(n.zoom,"scope");return o||(f=f.map((e=>(e.markname=r+jc,e)))),i.push({name:r+Ou,on:[{events:f,update:o?"{"+[c?`x: invert(${c}, x(unit))`:"",u?`y: invert(${u}, y(unit))`:""].filter((e=>e)).join(", ")+"}":"{x: x(unit), y: y(unit)}"}]},{name:a,on:[{events:f,force:!0,update:"pow(1.001, event.deltaY * pow(16, event.deltaMode))"}]}),void 0!==s&&Cu(e,n,s,"width",i),void 0!==l&&Cu(e,n,l,"height",i),i}};function Cu(e,t,n,i,r){const o=t.name,a=n.channel,s=Cc.defined(t),l=r.filter((e=>e.name===n.signals[s?"data":"visual"]))[0],c=e.getSizeSignalRef(i).signal,u=e.getScaleComponent(a),f=u&&u.get("type"),d=s?Pc(e,a):l.name,m=o+_u,p=`${s&&u?"log"===f?"zoomLog":"symlog"===f?"zoomSymlog":"pow"===f?"zoomPow":"zoomLinear":"zoomLinear"}(${d}, ${`${o}${Ou}.${a}`}, ${m}${s?"pow"===f?`, ${u.get("exponent")??1}`:"symlog"===f?`, ${u.get("constant")??1}`:"":""})`;l.on.push({events:{signal:m},update:s?p:`clampRange(${p}, 0, ${c})`})}const Pu="_store",Au="_tuple",ju="_modify",Tu="vlSelectionResolve",Eu=[qc,Lc,Nc,$u,bu,Cc,ku,wu,Fu,Nu,vu];function Mu(e){let{escape:n}=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{escape:!0},i=n?t.stringValue(e.name):e.name;const r=function(e){let 
t=e.parent;for(;t&&!$m(t);)t=t.parent;return t}(e);if(r){const{facet:e}=r;for(const n of Re)e[n]&&(i+=` + '__facet_${n}_' + (facet[${t.stringValue(r.vgField(n))}])`)}return i}function Lu(e){return F(e.component.selection??{}).reduce(((e,t)=>e||t.project.hasSelectionId),!1)}function qu(e,n){!t.isString(n.select)&&n.select.on||delete e.events,!t.isString(n.select)&&n.select.clear||delete e.clear,!t.isString(n.select)&&n.select.toggle||delete e.toggle}function Uu(e){const t=[];return"Identifier"===e.type?[e.name]:"Literal"===e.type?[e.value]:("MemberExpression"===e.type&&(t.push(...Uu(e.object)),t.push(...Uu(e.property))),t)}function Ru(e){return"MemberExpression"===e.object.type?Ru(e.object):"datum"===e.object.name}function Wu(e){const n=t.parseExpression(e),i=new Set;return n.visit((e=>{"MemberExpression"===e.type&&Ru(e)&&i.add(Uu(e).slice(1).join("."))})),i}class Bu extends vc{clone(){return new Bu(null,this.model,l(this.filter))}constructor(e,t,n){super(e),this.model=t,this.filter=n,qn(this,"expr",void 0),qn(this,"_dependentFields",void 0),this.expr=Vu(this.model,this.filter,this),this._dependentFields=Wu(this.expr)}dependentFields(){return this._dependentFields}producedFields(){return new Set}assemble(){return{type:"filter",expr:this.expr}}hash(){return`Filter ${this.expr}`}}function Iu(e,n,i){let r=arguments.length>3&&void 0!==arguments[3]?arguments[3]:"datum";const o=t.isString(n)?n:n.param,a=_(o),s=t.stringValue(a+Pu);let l;try{l=e.getSelectionComponent(a,o)}catch(e){return`!!${a}`}if(l.project.timeUnit){const t=i??e.component.data.raw,n=l.project.timeUnit.clone();t.parent?n.insertAsParentOf(t):t.parent=n}const c=`${l.project.hasSelectionId?"vlSelectionIdTest(":"vlSelectionTest("}${s}, ${r}${"global"===l.resolve?")":`, ${t.stringValue(l.resolve)})`}`,u=`length(data(${s}))`;return!1===n.empty?`${u} && ${c}`:`!${u} || ${c}`}function Hu(e,n,i){const r=_(n),o=i.encoding;let a,s=i.field;try{a=e.getSelectionComponent(r,n)}catch(e){return r}if(o||s){if(o&&!s){const 
e=a.project.items.filter((e=>e.channel===o));!e.length||e.length>1?(s=a.project.items[0].field,$i((e.length?"Multiple ":"No ")+`matching ${t.stringValue(o)} encoding found for selection ${t.stringValue(i.param)}. `+`Using "field": ${t.stringValue(s)}.`)):s=e[0].field}}else s=a.project.items[0].field,a.project.items.length>1&&$i(`A "field" or "encoding" must be specified when using a selection as a scale domain. Using "field": ${t.stringValue(s)}.`);return`${a.name}[${t.stringValue(E(s))}]`}function Vu(e,n,i){return N(n,(n=>t.isString(n)?n:function(e){return e?.param}(n)?Iu(e,n,i):Zi(n)))}function Gu(e,t,n,i){e.encode??={},e.encode[t]??={},e.encode[t].update??={},e.encode[t].update[n]=i}function Yu(e,n,i){let r=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{header:!1};const{disable:o,orient:a,scale:s,labelExpr:l,title:c,zindex:u,...f}=e.combine();if(!o){for(const e in f){const i=Oa[e],r=f[e];if(i&&i!==n&&"both"!==i)delete f[e];else if(Fa(r)){const{condition:n,...i}=r,o=t.array(n),a=Da[e];if(a){const{vgProp:t,part:n}=a;Gu(f,n,t,[...o.map((e=>{const{test:t,...n}=e;return{test:Vu(null,t),...n}})),i]),delete f[e]}else if(null===a){const t={signal:o.map((e=>{const{test:t,...n}=e;return`${Vu(null,t)} ? 
${zn(n)} : `})).join("")+zn(i)};f[e]=t}}else if(yn(r)){const t=Da[e];if(t){const{vgProp:n,part:i}=t;Gu(f,i,n,r),delete f[e]}}p(["labelAlign","labelBaseline"],e)&&null===f[e]&&delete f[e]}if("grid"===n){if(!f.grid)return;if(f.encode){const{grid:e}=f.encode;f.encode={...e?{grid:e}:{}},S(f.encode)&&delete f.encode}return{scale:s,orient:a,...f,domain:!1,labels:!1,aria:!1,maxExtent:0,minExtent:0,ticks:!1,zindex:U(u,0)}}{if(!r.header&&e.mainExtracted)return;if(void 0!==l){let e=l;f.encode?.labels?.update&&yn(f.encode.labels.update.text)&&(e=M(l,"datum.label",f.encode.labels.update.text.signal)),Gu(f,"labels","text",{signal:e})}if(null===f.labelAlign&&delete f.labelAlign,f.encode){for(const t of za)e.hasAxisPart(t)||delete f.encode[t];S(f.encode)&&delete f.encode}const n=function(e,n){if(e)return t.isArray(e)&&!hn(e)?e.map((e=>da(e,n))).join(", "):e}(c,i);return{scale:s,orient:a,grid:!1,...n?{title:n}:{},...f,...!1===i.aria?{aria:!1}:{},zindex:U(u,0)}}}}function Xu(e){const{axes:t}=e.component,n=[];for(const i of Ft)if(t[i])for(const r of t[i])if(!r.get("disable")&&!r.get("gridScale")){const t="x"===i?"height":"width",r=e.getSizeSignalRef(t).signal;t!==r&&n.push({name:t,update:r})}return n}function Qu(e,t,n,i){return Object.assign.apply(null,[{},...e.map((e=>{if("axisOrient"===e){const e="x"===n?"bottom":"left",r=t["x"===n?"axisBottom":"axisLeft"]||{},o=t["x"===n?"axisTop":"axisRight"]||{},a=new Set([...D(r),...D(o)]),s={};for(const t of a.values())s[t]={signal:`${i.signal} === "${e}" ? 
${On(r[t])} : ${On(o[t])}`};return s}return t[e]}))])}function Ju(e,n){const i=[{}];for(const r of e){let e=n[r]?.style;if(e){e=t.array(e);for(const t of e)i.push(n.style[t])}}return Object.assign.apply(null,i)}function Ku(e,t,n){let i=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{};const r=jn(e,n,t);if(void 0!==r)return{configFrom:"style",configValue:r};for(const t of["vlOnlyAxisConfig","vgAxisConfig","axisConfigStyle"])if(void 0!==i[t]?.[e])return{configFrom:t,configValue:i[t][e]};return{}}const Zu={scale:e=>{let{model:t,channel:n}=e;return t.scaleName(n)},format:e=>{let{format:t}=e;return t},formatType:e=>{let{formatType:t}=e;return t},grid:e=>{let{fieldOrDatumDef:t,axis:n,scaleType:i}=e;return n.grid??function(e,t){return!xr(e)&&Ho(t)&&!ln(t?.bin)&&!cn(t?.bin)}(i,t)},gridScale:e=>{let{model:t,channel:n}=e;return function(e,t){const n="x"===t?"y":"x";if(e.getScaleComponent(n))return e.scaleName(n);return}(t,n)},labelAlign:e=>{let{axis:t,labelAngle:n,orient:i,channel:r}=e;return t.labelAlign||nf(n,i,r)},labelAngle:e=>{let{labelAngle:t}=e;return t},labelBaseline:e=>{let{axis:t,labelAngle:n,orient:i,channel:r}=e;return t.labelBaseline||tf(n,i,r)},labelFlush:e=>{let{axis:t,fieldOrDatumDef:n,channel:i}=e;return t.labelFlush??function(e,t){if("x"===t&&p(["quantitative","temporal"],e))return!0;return}(n.type,i)},labelOverlap:e=>{let{axis:n,fieldOrDatumDef:i,scaleType:r}=e;return n.labelOverlap??function(e,n,i,r){if(i&&!t.isObject(r)||"nominal"!==e&&"ordinal"!==e)return"log"!==n&&"symlog"!==n||"greedy";return}(i.type,r,Ho(i)&&!!i.timeUnit,Ho(i)?i.sort:void 0)},orient:e=>{let{orient:t}=e;return t},tickCount:e=>{let{channel:t,model:n,axis:i,fieldOrDatumDef:r,scaleType:o}=e;const a="x"===t?"width":"y"===t?"height":void 0,s=a?n.getSizeSignalRef(a):void 0;return 
i.tickCount??function(e){let{fieldOrDatumDef:t,scaleType:n,size:i,values:r}=e;if(!r&&!xr(n)&&"log"!==n){if(Ho(t)){if(ln(t.bin))return{signal:`ceil(${i.signal}/10)`};if(t.timeUnit&&p(["month","hours","day","quarter"],Ui(t.timeUnit)?.unit))return}return{signal:`ceil(${i.signal}/40)`}}return}({fieldOrDatumDef:r,scaleType:o,size:s,values:i.values})},tickMinStep:function(e){let{format:t,fieldOrDatumDef:n}=e;if("d"===t)return 1;if(Ho(n)){const{timeUnit:e}=n;if(e){const t=Ri(e);if(t)return{signal:t}}}return},title:e=>{let{axis:t,model:n,channel:i}=e;if(void 0!==t.title)return t.title;const r=rf(n,i);if(void 0!==r)return r;const o=n.typedFieldDef(i),a="x"===i?"x2":"y2",s=n.fieldDef(a);return En(o?[Eo(o)]:[],Ho(s)?[Eo(s)]:[])},values:e=>{let{axis:n,fieldOrDatumDef:i}=e;return function(e,n){const i=e.values;if(t.isArray(i))return ka(n,i);if(yn(i))return i;return}(n,i)},zindex:e=>{let{axis:t,fieldOrDatumDef:n,mark:i}=e;return t.zindex??function(e,t){if("rect"===e&&aa(t))return 1;return 0}(i,n)}};function ef(e){return`(((${e.signal} % 360) + 360) % 360)`}function tf(e,t,n,i){if(void 0!==e){if("x"===n){if(yn(e)){const n=ef(e);return{signal:`(45 < ${n} && ${n} < 135) || (225 < ${n} && ${n} < 315) ? "middle" :(${n} <= 45 || 315 <= ${n}) === ${yn(t)?`(${t.signal} === "top")`:"top"===t} ? "bottom" : "top"`}}if(45{if(ea(t)&&Po(t.sort)){const{field:i,timeUnit:r}=t,o=t.sort,a=o.map(((e,t)=>`${Zi({field:i,timeUnit:r,equal:e})} ? 
${t} : `)).join("")+o.length;e=new of(e,{calculate:a,as:af(t,n,{forAs:!0})})}})),e}producedFields(){return new Set([this.transform.as])}dependentFields(){return this._dependentFields}assemble(){return{type:"formula",expr:this.transform.calculate,as:this.transform.as}}hash(){return`Calculate ${d(this.transform)}`}}function af(e,t,n){return oa(e,{prefix:t,suffix:"sort_index",...n??{}})}function sf(e,t){return p(["top","bottom"],t)?"column":p(["left","right"],t)||"row"===e?"row":"column"}function lf(e,t,n,i){const r="row"===i?n.headerRow:"column"===i?n.headerColumn:n.headerFacet;return U((t||{})[e],r[e],n.header[e])}function cf(e,t,n,i){const r={};for(const o of e){const e=lf(o,t||{},n,i);void 0!==e&&(r[o]=e)}return r}const uf=["row","column"],ff=["header","footer"];function df(e,t){const n=e.component.layoutHeaders[t].title,i=e.config?e.config:void 0,r=e.component.layoutHeaders[t].facetFieldDef?e.component.layoutHeaders[t].facetFieldDef:void 0,{titleAnchor:o,titleAngle:a,titleOrient:s}=cf(["titleAnchor","titleAngle","titleOrient"],r.header,i,t),l=sf(t,s),c=H(a);return{name:`${t}-title`,type:"group",role:`${l}-title`,title:{text:n,..."row"===t?{orient:"left"}:{},style:"guide-title",...pf(c,l),...mf(l,c,o),...$f(i,r,t,hs,ps)}}}function mf(e,t){switch(arguments.length>2&&void 0!==arguments[2]?arguments[2]:"middle"){case"start":return{align:"left"};case"end":return{align:"right"}}const n=nf(t,"row"===e?"left":"top","row"===e?"y":"x");return n?{align:n}:{}}function pf(e,t){const n=tf(e,"row"===t?"left":"top","row"===t?"y":"x",!0);return n?{baseline:n}:{}}function gf(e,t){const n=e.component.layoutHeaders[t],i=[];for(const r of ff)if(n[r])for(const o of n[r]){const a=vf(e,t,r,n,o);null!=a&&i.push(a)}return i}function hf(e,n){const{sort:i}=e;return Co(i)?{field:oa(i,{expr:"datum"}),order:i.order??"ascending"}:t.isArray(i)?{field:af(e,n,{expr:"datum"}),order:"ascending"}:{field:oa(e,{expr:"datum"}),order:i??"ascending"}}function 
yf(e,t,n){const{format:i,formatType:r,labelAngle:o,labelAnchor:a,labelOrient:s,labelExpr:l}=cf(["format","formatType","labelAngle","labelAnchor","labelOrient","labelExpr"],e.header,n,t),c=vo({fieldOrDatumDef:e,format:i,formatType:r,expr:"parent",config:n}).signal,u=sf(t,s);return{text:{signal:l?M(M(l,"datum.label",c),"datum.value",oa(e,{expr:"parent"})):c},..."row"===t?{orient:"left"}:{},style:"guide-label",frame:"group",...pf(o,u),...mf(u,o,a),...$f(n,e,t,ys,gs)}}function vf(e,t,n,i,r){if(r){let o=null;const{facetFieldDef:a}=i,s=e.config?e.config:void 0;if(a&&r.labels){const{labelOrient:e}=cf(["labelOrient"],a.header,s,t);("row"===t&&!p(["top","bottom"],e)||"column"===t&&!p(["left","right"],e))&&(o=yf(a,t,s))}const l=$m(e)&&!Ao(e.facet),c=r.axes,u=c?.length>0;if(o||u){const s="row"===t?"height":"width";return{name:e.getName(`${t}_${n}`),type:"group",role:`${t}-${n}`,...i.facetFieldDef?{from:{data:e.getName(`${t}_domain`)},sort:hf(a,t)}:{},...u&&l?{from:{data:e.getName(`facet_domain_${t}`)}}:{},...o?{title:o}:{},...r.sizeSignal?{encode:{update:{[s]:r.sizeSignal}}}:{},...u?{axes:c}:{}}}}return null}const bf={column:{start:0,end:1},row:{start:1,end:0}};function xf(e,t){return bf[t][e]}function $f(e,t,n,i,r){const o={};for(const a of i){if(!r[a])continue;const i=lf(a,t?.header,e,n);void 0!==i&&(o[r[a]]=i)}return o}function wf(e){return[...kf(e,"width"),...kf(e,"height"),...kf(e,"childWidth"),...kf(e,"childHeight")]}function kf(e,t){const n="width"===t?"x":"y",i=e.component.layoutSize.get(t);if(!i||"merged"===i)return[];const r=e.getSizeSignalRef(t).signal;if("step"===i){const t=e.getScaleComponent(n);if(t){const i=t.get("type"),o=t.get("range");if(xr(i)&&vn(o)){const i=e.scaleName(n);if($m(e.parent)){if("independent"===e.parent.component.resolve.scale[n])return[Sf(i,o)]}return[Sf(i,o),{name:r,update:Df(i,t,`domain('${i}').length`)}]}}throw new Error("layout size is step although width/height is not step.")}if("container"==i){const 
t=r.endsWith("width"),n=t?"containerSize()[0]":"containerSize()[1]",i=`isFinite(${n}) ? ${n} : ${As(e.config.view,t?"width":"height")}`;return[{name:r,init:i,on:[{update:i,events:"window:resize"}]}]}return[{name:r,value:i}]}function Sf(e,t){const n=`${e}_step`;return yn(t.step)?{name:n,update:t.step.signal}:{name:n,value:t.step}}function Df(e,t,n){const i=t.get("type"),r=t.get("padding"),o=U(t.get("paddingOuter"),r);let a=t.get("paddingInner");return a="band"===i?void 0!==a?a:r:1,`bandspace(${n}, ${On(a)}, ${On(o)}) * ${e}_step`}function Ff(e){return"childWidth"===e?"width":"childHeight"===e?"height":e}function zf(e,t){return D(e).reduce(((n,i)=>{const r=e[i];return{...n,...Uc(t,r,i,(e=>Fn(e.value)))}}),{})}function Of(e,t){if($m(t))return"theta"===e?"independent":"shared";if(km(t))return"shared";if(wm(t))return zt(e)||"theta"===e||"radius"===e?"independent":"shared";throw new Error("invalid model type for resolve")}function _f(e,t){const n=e.scale[t],i=zt(t)?"axis":"legend";return"independent"===n?("shared"===e[i][t]&&$i(function(e){return`Setting the scale to be independent for "${e}" means we also have to set the guide (axis or legend) to be independent.`}(t)),"independent"):e[i][t]||"shared"}const 
Nf=D({aria:1,clipHeight:1,columnPadding:1,columns:1,cornerRadius:1,description:1,direction:1,fillColor:1,format:1,formatType:1,gradientLength:1,gradientOpacity:1,gradientStrokeColor:1,gradientStrokeWidth:1,gradientThickness:1,gridAlign:1,labelAlign:1,labelBaseline:1,labelColor:1,labelFont:1,labelFontSize:1,labelFontStyle:1,labelFontWeight:1,labelLimit:1,labelOffset:1,labelOpacity:1,labelOverlap:1,labelPadding:1,labelSeparation:1,legendX:1,legendY:1,offset:1,orient:1,padding:1,rowPadding:1,strokeColor:1,symbolDash:1,symbolDashOffset:1,symbolFillColor:1,symbolLimit:1,symbolOffset:1,symbolOpacity:1,symbolSize:1,symbolStrokeColor:1,symbolStrokeWidth:1,symbolType:1,tickCount:1,tickMinStep:1,title:1,titleAlign:1,titleAnchor:1,titleBaseline:1,titleColor:1,titleFont:1,titleFontSize:1,titleFontStyle:1,titleFontWeight:1,titleLimit:1,titleLineHeight:1,titleOpacity:1,titleOrient:1,titlePadding:1,type:1,values:1,zindex:1,disable:1,labelExpr:1,selections:1,opacity:1,shape:1,stroke:1,fill:1,size:1,strokeWidth:1,strokeDash:1,encode:1});class Cf extends Jl{}const Pf={symbols:function(e,n){let{fieldOrDatumDef:i,model:r,channel:o,legendCmpt:a,legendType:s}=n;if("symbol"!==s)return;const{markDef:l,encoding:c,config:u,mark:f}=r,d=l.filled&&"trail"!==f;let m={..._n({},r,eo),...Qc(r,{filled:d})};const p=a.get("symbolOpacity")??u.legend.symbolOpacity,g=a.get("symbolFillColor")??u.legend.symbolFillColor,h=a.get("symbolStrokeColor")??u.legend.symbolStrokeColor,y=void 0===p?Af(c.opacity)??l.opacity:void 0;if(m.fill)if("fill"===o||d&&o===me)delete m.fill;else if(m.fill.field)g?delete m.fill:(m.fill=Fn(u.legend.symbolBaseFillColor??"black"),m.fillOpacity=Fn(y??1));else if(t.isArray(m.fill)){const e=jf(c.fill??c.color)??l.fill??(d&&l.color);e&&(m.fill=Fn(e))}if(m.stroke)if("stroke"===o||!d&&o===me)delete m.stroke;else if(m.stroke.field||h)delete m.stroke;else if(t.isArray(m.stroke)){const e=U(jf(c.stroke||c.color),l.stroke,d?l.color:void 0);e&&(m.stroke={value:e})}if(o!==be){const 
e=Ho(i)&&Ef(r,a,i);e?m.opacity=[{test:e,...Fn(y??1)},Fn(u.legend.unselectedOpacity)]:y&&(m.opacity=Fn(y))}return m={...m,...e},S(m)?void 0:m},gradient:function(e,t){let{model:n,legendType:i,legendCmpt:r}=t;if("gradient"!==i)return;const{config:o,markDef:a,encoding:s}=n;let l={};const c=void 0===(r.get("gradientOpacity")??o.legend.gradientOpacity)?Af(s.opacity)||a.opacity:void 0;c&&(l.opacity=Fn(c));return l={...l,...e},S(l)?void 0:l},labels:function(e,t){let{fieldOrDatumDef:n,model:i,channel:r,legendCmpt:o}=t;const a=i.legend(r)||{},s=i.config,l=Ho(n)?Ef(i,o,n):void 0,c=l?[{test:l,value:1},{value:s.legend.unselectedOpacity}]:void 0,{format:u,formatType:f}=a;let d;go(f)?d=xo({fieldOrDatumDef:n,field:"datum.value",format:u,formatType:f,config:s}):void 0===u&&void 0===f&&s.customFormatTypes&&("quantitative"===n.type&&s.numberFormatType?d=xo({fieldOrDatumDef:n,field:"datum.value",format:s.numberFormat,formatType:s.numberFormatType,config:s}):"temporal"===n.type&&s.timeFormatType&&Ho(n)&&void 0===n.timeUnit&&(d=xo({fieldOrDatumDef:n,field:"datum.value",format:s.timeFormat,formatType:s.timeFormatType,config:s})));const m={...c?{opacity:c}:{},...d?{text:d}:{},...e};return S(m)?void 0:m},entries:function(e,t){let{legendCmpt:n}=t;const i=n.get("selections");return i?.length?{...e,fill:{value:"transparent"}}:e}};function Af(e){return Tf(e,((e,t)=>Math.max(e,t.value)))}function jf(e){return Tf(e,((e,t)=>U(e,t.value)))}function Tf(e,n){return function(e){const n=e?.condition;return!!n&&(t.isArray(n)||Zo(n))}(e)?t.array(e.condition).reduce(n,e.value):Zo(e)?e.value:void 0}function Ef(e,n,i){const r=n.get("selections");if(!r?.length)return;const o=t.stringValue(i.field);return r.map((e=>`(!length(data(${t.stringValue(_(e)+Pu)})) || (${e}[${o}] && indexof(${e}[${o}], datum.value) >= 0))`)).join(" || ")}const Mf={direction:e=>{let{direction:t}=e;return t},format:e=>{let{fieldOrDatumDef:t,legend:n,config:i}=e;const{format:r,formatType:o}=n;return 
$o(t,t.type,r,o,i,!1)},formatType:e=>{let{legend:t,fieldOrDatumDef:n,scaleType:i}=e;const{formatType:r}=t;return wo(r,n,i)},gradientLength:e=>{const{legend:t,legendConfig:n}=e;return t.gradientLength??n.gradientLength??function(e){let{legendConfig:t,model:n,direction:i,orient:r,scaleType:o}=e;const{gradientHorizontalMaxLength:a,gradientHorizontalMinLength:s,gradientVerticalMaxLength:l,gradientVerticalMinLength:c}=t;if(wr(o))return"horizontal"===i?"top"===r||"bottom"===r?Uf(n,"width",s,a):s:Uf(n,"height",c,l);return}(e)},labelOverlap:e=>{let{legend:t,legendConfig:n,scaleType:i}=e;return t.labelOverlap??n.labelOverlap??function(e){if(p(["quantile","threshold","log","symlog"],e))return"greedy";return}(i)},symbolType:e=>{let{legend:t,markDef:n,channel:i,encoding:r}=e;return t.symbolType??function(e,t,n,i){if("shape"!==t){const e=jf(n)??i;if(e)return e}switch(e){case"bar":case"rect":case"image":case"square":return"square";case"line":case"trail":case"rule":return"stroke";case"arc":case"point":case"circle":case"tick":case"geoshape":case"area":case"text":return"circle"}}(n.type,i,r.shape,n.shape)},title:e=>{let{fieldOrDatumDef:t,config:n}=e;return ua(t,n,{allowDisabling:!0})},type:e=>{let{legendType:t,scaleType:n,channel:i}=e;if(qe(i)&&wr(n)){if("gradient"===t)return}else if("symbol"===t)return;return t},values:e=>{let{fieldOrDatumDef:n,legend:i}=e;return function(e,n){const i=e.values;if(t.isArray(i))return ka(n,i);if(yn(i))return i;return}(i,n)}};function Lf(e){const{legend:t}=e;return U(t.type,function(e){let{channel:t,timeUnit:n,scaleType:i}=e;if(qe(t)){if(p(["quarter","month","day"],n))return"symbol";if(wr(i))return"gradient"}return"symbol"}(e))}function qf(e){let{legendConfig:t,legendType:n,orient:i,legend:r}=e;return r.direction??t[n?"gradientDirection":"symbolDirection"]??function(e,t){switch(e){case"top":case"bottom":return"horizontal";case"left":case"right":case"none":case void 0:return;default:return"gradient"===t?"horizontal":void 0}}(i,n)}function 
Uf(e,t,n,i){return{signal:`clamp(${e.getSizeSignalRef(t).signal}, ${n}, ${i})`}}function Rf(e){const t=xm(e)?function(e){const{encoding:t}=e,n={};for(const i of[me,...bs]){const r=ga(t[i]);r&&e.getScaleComponent(i)&&(i===he&&Ho(r)&&r.type===lr||(n[i]=Bf(e,i)))}return n}(e):function(e){const{legends:t,resolve:n}=e.component;for(const i of e.children){Rf(i);for(const r of D(i.component.legends))n.legend[r]=_f(e.component.resolve,r),"shared"===n.legend[r]&&(t[r]=If(t[r],i.component.legends[r]),t[r]||(n.legend[r]="independent",delete t[r]))}for(const i of D(t))for(const t of e.children)t.component.legends[i]&&"shared"===n.legend[i]&&delete t.component.legends[i];return t}(e);return e.component.legends=t,t}function Wf(e,t,n,i){switch(t){case"disable":return void 0!==n;case"values":return!!n?.values;case"title":if("title"===t&&e===i?.title)return!0}return e===(n||{})[t]}function Bf(e,t){let n=e.legend(t);const{markDef:i,encoding:r,config:o}=e,a=o.legend,s=new Cf({},function(e,t){const n=e.scaleName(t);if("trail"===e.mark){if("color"===t)return{stroke:n};if("size"===t)return{strokeWidth:n}}return"color"===t?e.markDef.filled?{fill:n}:{stroke:n}:{[t]:n}}(e,t));!function(e,t,n){const i=e.fieldDef(t)?.field;for(const r of F(e.component.selection??{})){const e=r.project.hasField[i]??r.project.hasChannel[t];if(e&&ku.defined(r)){const t=n.get("selections")??[];t.push(r.name),n.set("selections",t,!1),e.hasLegend=!0}}}(e,t,s);const l=void 0!==n?!n:a.disable;if(s.set("disable",l,void 0!==n),l)return s;n=n||{};const c=e.getScaleComponent(t).get("type"),u=ga(r[t]),f=Ho(u)?Ui(u.timeUnit)?.unit:void 0,d=n.orient||o.legend.orient||"right",m=Lf({legend:n,channel:t,timeUnit:f,scaleType:c}),p={legend:n,channel:t,model:e,markDef:i,encoding:r,fieldOrDatumDef:u,legendConfig:a,config:o,scaleType:c,orient:d,legendType:m,direction:qf({legend:n,legendType:m,orient:d,legendConfig:a})};for(const i of 
Nf){if("gradient"===m&&i.startsWith("symbol")||"symbol"===m&&i.startsWith("gradient"))continue;const r=i in Mf?Mf[i](p):n[i];if(void 0!==r){const a=Wf(r,i,n,e.fieldDef(t));(a||void 0===o.legend[i])&&s.set(i,r,a)}}const g=n?.encoding??{},h=s.get("selections"),y={},v={fieldOrDatumDef:u,model:e,channel:t,legendCmpt:s,legendType:m};for(const t of["labels","legend","title","symbols","gradient","entries"]){const n=zf(g[t]??{},e),i=t in Pf?Pf[t](n,v):n;void 0===i||S(i)||(y[t]={...h?.length&&Ho(u)?{name:`${_(u.field)}_legend_${t}`}:{},...h?.length?{interactive:!!h}:{},update:i})}return S(y)||s.set("encode",y,!!n?.encoding),s}function If(e,t){if(!e)return t.clone();const n=e.getWithExplicit("orient"),i=t.getWithExplicit("orient");if(n.explicit&&i.explicit&&n.value!==i.value)return;let r=!1;for(const n of Nf){const i=nc(e.getWithExplicit(n),t.getWithExplicit(n),n,"legend",((e,t)=>{switch(n){case"symbolType":return Hf(e,t);case"title":return Ln(e,t);case"type":return r=!0,Zl("symbol")}return tc(e,t,n,"legend")}));e.setWithExplicit(n,i)}return r&&(e.implicit?.encode?.gradient&&C(e.implicit,["encode","gradient"]),e.explicit?.encode?.gradient&&C(e.explicit,["encode","gradient"])),e}function Hf(e,t){return"circle"===t.value?t:e}function Vf(e){const t=e.component.legends,n={};for(const i of D(t)){const r=X(e.getScaleComponent(i).get("domains"));if(n[r])for(const e of n[r]){If(e,t[i])||n[r].push(t[i])}else n[r]=[t[i].clone()]}return F(n).flat().map((t=>function(e,t){const{disable:n,labelExpr:i,selections:r,...o}=e.combine();if(n)return;!1===t.aria&&null==o.aria&&(o.aria=!1);if(o.encode?.symbols){const e=o.encode.symbols.update;!e.fill||"transparent"===e.fill.value||e.stroke||o.stroke||(e.stroke={value:"transparent"});for(const t of bs)o[t]&&delete e[t]}o.title||delete o.title;if(void 0!==i){let 
e=i;o.encode?.labels?.update&&yn(o.encode.labels.update.text)&&(e=M(i,"datum.label",o.encode.labels.update.text.signal)),function(e,t,n,i){e.encode??={},e.encode[t]??={},e.encode[t].update??={},e.encode[t].update[n]=i}(o,"labels","text",{signal:e})}return o}(t,e.config))).filter((e=>void 0!==e))}function Gf(e){return km(e)||wm(e)?function(e){return e.children.reduce(((e,t)=>e.concat(t.assembleProjections())),Yf(e))}(e):Yf(e)}function Yf(e){const t=e.component.projection;if(!t||t.merged)return[];const n=t.combine(),{name:i}=n;if(t.data){const r={signal:`[${t.size.map((e=>e.signal)).join(", ")}]`},o=t.data.reduce(((t,n)=>{const i=yn(n)?n.signal:`data('${e.lookupDataSource(n)}')`;return p(t,i)||t.push(i),t}),[]);if(o.length<=0)throw new Error("Projection's fit didn't find any data sources");return[{name:i,size:r,fit:{signal:o.length>1?`[${o.join(", ")}]`:o[0]},...n}]}return[{name:i,translate:{signal:"[width / 2, height / 2]"},...n}]}const Xf=["type","clipAngle","clipExtent","center","rotate","precision","reflectX","reflectY","coefficient","distance","fraction","lobes","parallel","radius","ratio","spacing","tilt"];class Qf extends Jl{constructor(e,t,n,i){super({...t},{name:e}),this.specifiedProjection=t,this.size=n,this.data=i,qn(this,"merged",!1)}get isFit(){return!!this.data}}function Jf(e){e.component.projection=xm(e)?function(e){if(e.hasProjection){const t=pn(e.specifiedProjection),n=!(t&&(null!=t.scale||null!=t.translate)),i=n?[e.getSizeSignalRef("width"),e.getSizeSignalRef("height")]:void 0,r=n?function(e){const t=[],{encoding:n}=e;for(const i of[[ue,ce],[de,fe]])(ga(n[i[0]])||ga(n[i[1]]))&&t.push({signal:e.getName(`geojson_${t.length}`)});e.channelHasField(he)&&e.typedFieldDef(he).type===lr&&t.push({signal:e.getName(`geojson_${t.length}`)});0===t.length&&t.push(e.requestDataName(fc.Main));return t}(e):void 0,o=new Qf(e.projectionName(!0),{...pn(e.config.projection)??{},...t??{}},i,r);return 
o.get("type")||o.set("type","equalEarth",!1),o}return}(e):function(e){if(0===e.children.length)return;let n;for(const t of e.children)Jf(t);const i=h(e.children,(e=>{const i=e.component.projection;if(i){if(n){const e=function(e,n){const i=h(Xf,(i=>!t.hasOwnProperty(e.explicit,i)&&!t.hasOwnProperty(n.explicit,i)||!!(t.hasOwnProperty(e.explicit,i)&&t.hasOwnProperty(n.explicit,i)&&Y(e.get(i),n.get(i)))));if(Y(e.size,n.size)){if(i)return e;if(Y(e.explicit,{}))return n;if(Y(n.explicit,{}))return e}return null}(n,i);return e&&(n=e),!!e}return n=i,!0}return!0}));if(n&&i){const t=e.projectionName(!0),i=new Qf(t,n.specifiedProjection,n.size,l(n.data));for(const n of e.children){const e=n.component.projection;e&&(e.isFit&&i.data.push(...n.component.projection.data),n.renameProjection(e.get("name"),t),e.merged=!0)}return i}return}(e)}function Kf(e,t,n,i){if(Sa(t,n)){const r=xm(e)?e.axis(n)??e.legend(n)??{}:{},o=oa(t,{expr:"datum"}),a=oa(t,{expr:"datum",binSuffix:"end"});return{formulaAs:oa(t,{binSuffix:"range",forAs:!0}),formula:Fo(o,a,r.format,r.formatType,i)}}return{}}function Zf(e,t){return`${sn(e)}_${t}`}function ed(e,t,n){const i=Zf(ba(n,void 0)??{},t);return e.getName(`${i}_bins`)}function td(e,n,i){let r,o;r=function(e){return"as"in e}(e)?t.isString(e.as)?[e.as,`${e.as}_end`]:[e.as[0],e.as[1]]:[oa(e,{forAs:!0}),oa(e,{binSuffix:"end",forAs:!0})];const a={...ba(n,void 0)},s=Zf(a,e.field),{signal:l,extentSignal:c}=function(e,t){return{signal:e.getName(`${t}_bins`),extentSignal:e.getName(`${t}_extent`)}}(i,s);if(fn(a.extent)){const e=a.extent;o=Hu(i,e.param,e),delete a.extent}return{key:s,binComponent:{bin:a,field:e.field,as:[r],...l?{signal:l}:{},...c?{extentSignal:c}:{},...o?{span:o}:{}}}}class nd extends vc{clone(){return new nd(null,l(this.bins))}constructor(e,t){super(e),this.bins=t}static makeFromEncoding(e,t){const n=t.reduceFieldDef(((e,n,i)=>{if(Ko(n)&&ln(n.bin)){const{key:r,binComponent:o}=td(n,n.bin,t);e[r]={...o,...e[r],...Kf(t,n,i,t.config)}}return 
e}),{});return S(n)?null:new nd(e,n)}static makeFromTransform(e,t,n){const{key:i,binComponent:r}=td(t,t.bin,n);return new nd(e,{[i]:r})}merge(e,t){for(const n of D(e.bins))n in this.bins?(t(e.bins[n].signal,this.bins[n].signal),this.bins[n].as=b([...this.bins[n].as,...e.bins[n].as],d)):this.bins[n]=e.bins[n];for(const t of e.children)e.removeChild(t),t.parent=this;e.remove()}producedFields(){return new Set(F(this.bins).map((e=>e.as)).flat(2))}dependentFields(){return new Set(F(this.bins).map((e=>e.field)))}hash(){return`Bin ${d(this.bins)}`}assemble(){return F(this.bins).flatMap((e=>{const t=[],[n,...i]=e.as,{extent:r,...o}=e.bin,a={type:"bin",field:E(e.field),as:n,signal:e.signal,...fn(r)?{extent:null}:{extent:r},...e.span?{span:{signal:`span(${e.span})`}}:{},...o};!r&&e.extentSignal&&(t.push({type:"extent",field:E(e.field),signal:e.extentSignal}),a.extent={signal:e.extentSignal}),t.push(a);for(const e of i)for(let i=0;i<2;i++)t.push({type:"formula",expr:oa({field:n[i]},{expr:"datum"}),as:e[i]});return e.formula&&t.push({type:"formula",expr:e.formula,as:e.formulaAs}),t}))}}function id(e,n,i,r){const o=xm(r)?r.encoding[it(n)]:void 0;if(Ko(i)&&xm(r)&&Uo(i,o,r.markDef,r.config)){e.add(oa(i,{})),e.add(oa(i,{suffix:"end"}));const{mark:t,markDef:o,config:a}=r,s=Lo({fieldDef:i,markDef:o,config:a});Jr(t)&&.5!==s&&zt(n)&&(e.add(oa(i,{suffix:kc})),e.add(oa(i,{suffix:Sc}))),i.bin&&Sa(i,n)&&e.add(oa(i,{binSuffix:"range"}))}else if(Ee(n)){const t=Te(n);e.add(r.getName(t))}else e.add(oa(i));return ea(i)&&function(e){return t.isObject(e)&&"field"in e}(i.scale?.range)&&e.add(i.scale.range.field),e}class rd extends vc{clone(){return new rd(null,new Set(this.dimensions),l(this.measures))}constructor(e,t,n){super(e),this.dimensions=t,this.measures=n}get groupBy(){return this.dimensions}static makeFromEncoding(e,t){let n=!1;t.forEachFieldDef((e=>{e.aggregate&&(n=!0)}));const i={},r=new Set;return 
n?(t.forEachFieldDef(((e,n)=>{const{aggregate:o,field:a}=e;if(o)if("count"===o)i["*"]??={},i["*"].count=new Set([oa(e,{forAs:!0})]);else{if(Zt(o)||en(o)){const e=Zt(o)?"argmin":"argmax",t=o[e];i[t]??={},i[t][e]=new Set([oa({op:e,field:t},{forAs:!0})])}else i[a]??={},i[a][o]=new Set([oa(e,{forAs:!0})]);Ht(n)&&"unaggregated"===t.scaleDomain(n)&&(i[a]??={},i[a].min=new Set([oa({field:a,aggregate:"min"},{forAs:!0})]),i[a].max=new Set([oa({field:a,aggregate:"max"},{forAs:!0})]))}else id(r,n,e,t)})),r.size+D(i).length===0?null:new rd(e,r,i)):null}static makeFromTransform(e,t){const n=new Set,i={};for(const e of t.aggregate){const{op:t,field:n,as:r}=e;t&&("count"===t?(i["*"]??={},i["*"].count=new Set([r||oa(e,{forAs:!0})])):(i[n]??={},i[n][t]=new Set([r||oa(e,{forAs:!0})])))}for(const e of t.groupby??[])n.add(e);return n.size+D(i).length===0?null:new rd(e,n,i)}merge(e){return x(this.dimensions,e.dimensions)?(function(e,t){for(const n of D(t)){const i=t[n];for(const t of D(i))n in e?e[n][t]=new Set([...e[n][t]??[],...i[t]]):e[n]={[t]:i[t]}}}(this.measures,e.measures),!0):(function(){xi.debug(...arguments)}("different dimensions, cannot merge"),!1)}addDimensions(e){e.forEach(this.dimensions.add,this.dimensions)}dependentFields(){return new Set([...this.dimensions,...D(this.measures)])}producedFields(){const e=new Set;for(const t of D(this.measures))for(const n of D(this.measures[t])){const i=this.measures[t][n];0===i.size?e.add(`${n}_${t}`):i.forEach(e.add,e)}return e}hash(){return`Aggregate ${d({dimensions:this.dimensions,measures:this.measures})}`}assemble(){const e=[],t=[],n=[];for(const i of D(this.measures))for(const r of D(this.measures[i]))for(const o of this.measures[i][r])n.push(o),e.push(r),t.push("*"===i?null:E(i));return{type:"aggregate",groupby:[...this.dimensions].map(E),ops:e,fields:t,as:n}}}class od extends vc{constructor(e,n,i,r){super(e),this.model=n,this.name=i,this.data=r,qn(this,"column",void 0),qn(this,"row",void 0),qn(this,"facet",void 
0),qn(this,"childModel",void 0);for(const e of Re){const i=n.facet[e];if(i){const{bin:r,sort:o}=i;this[e]={name:n.getName(`${e}_domain`),fields:[oa(i),...ln(r)?[oa(i,{binSuffix:"end"})]:[]],...Co(o)?{sortField:o}:t.isArray(o)?{sortIndexField:af(i,e)}:{}}}}this.childModel=n.child}hash(){let e="Facet";for(const t of Re)this[t]&&(e+=` ${t.charAt(0)}:${d(this[t])}`);return e}get fields(){const e=[];for(const t of Re)this[t]?.fields&&e.push(...this[t].fields);return e}dependentFields(){const e=new Set(this.fields);for(const t of Re)this[t]&&(this[t].sortField&&e.add(this[t].sortField.field),this[t].sortIndexField&&e.add(this[t].sortIndexField));return e}producedFields(){return new Set}getSource(){return this.name}getChildIndependentFieldsWithStep(){const e={};for(const t of Ft){const n=this.childModel.component.scales[t];if(n&&!n.merged){const i=n.get("type"),r=n.get("range");if(xr(i)&&vn(r)){const n=Qd(Jd(this.childModel,t));n?e[t]=n:$i(Yn(t))}}}return e}assembleRowColumnHeaderData(e,t,n){const i={row:"y",column:"x",facet:void 0}[e],r=[],o=[],a=[];i&&n&&n[i]&&(t?(r.push(`distinct_${n[i]}`),o.push("max")):(r.push(n[i]),o.push("distinct")),a.push(`distinct_${n[i]}`));const{sortField:s,sortIndexField:l}=this[e];if(s){const{op:e=zo,field:t}=s;r.push(t),o.push(e),a.push(oa(s,{forAs:!0}))}else l&&(r.push(l),o.push("max"),a.push(l));return{name:this[e].name,source:t??this.data,transform:[{type:"aggregate",groupby:this[e].fields,...r.length?{fields:r,ops:o,as:a}:{}}]}}assembleFacetHeaderData(e){const{columns:t}=this.model.layout,{layoutHeaders:n}=this.model.component,i=[],r={};for(const e of uf){for(const t of ff){const i=(n[e]&&n[e][t])??[];for(const t of i)if(t.axes?.length>0){r[e]=!0;break}}if(r[e]){const n=`length(data("${this.facet.name}"))`,r="row"===e?t?{signal:`ceil(${n} / ${t})`}:1:t?{signal:`min(${n}, 
${t})`}:{signal:n};i.push({name:`${this.facet.name}_${e}`,transform:[{type:"sequence",start:0,stop:r}]})}}const{row:o,column:a}=r;return(o||a)&&i.unshift(this.assembleRowColumnHeaderData("facet",null,e)),i}assemble(){const e=[];let t=null;const n=this.getChildIndependentFieldsWithStep(),{column:i,row:r,facet:o}=this;if(i&&r&&(n.x||n.y)){t=`cross_${this.column.name}_${this.row.name}`;const i=[].concat(n.x??[],n.y??[]),r=i.map((()=>"distinct"));e.push({name:t,source:this.data,transform:[{type:"aggregate",groupby:this.fields,fields:i,ops:r}]})}for(const i of[J,Q])this[i]&&e.push(this.assembleRowColumnHeaderData(i,t,n));if(o){const t=this.assembleFacetHeaderData(n);t&&e.push(...t)}return e}}function ad(e){return e.startsWith("'")&&e.endsWith("'")||e.startsWith('"')&&e.endsWith('"')?e.slice(1,-1):e}function sd(e){const n={};return a(e.filter,(e=>{if(Ji(e)){let i=null;Ii(e)?i=Sn(e.equal):Vi(e)?i=Sn(e.lte):Hi(e)?i=Sn(e.lt):Gi(e)?i=Sn(e.gt):Yi(e)?i=Sn(e.gte):Xi(e)?i=e.range[0]:Qi(e)&&(i=(e.oneOf??e.in)[0]),i&&(wi(i)?n[e.field]="date":t.isNumber(i)?n[e.field]="number":t.isString(i)&&(n[e.field]="string")),e.timeUnit&&(n[e.field]="date")}})),n}function ld(e){const n={};function i(e){var i;$a(e)?n[e.field]="date":"quantitative"===e.type&&(i=e.aggregate,t.isString(i)&&p(["min","max"],i))?n[e.field]="number":q(e.field)>1?e.field in n||(n[e.field]="flatten"):ea(e)&&Co(e.sort)&&q(e.sort.field)>1&&(e.sort.field in n||(n[e.sort.field]="flatten"))}if((xm(e)||$m(e))&&e.forEachFieldDef(((t,n)=>{if(Ko(t))i(t);else{const r=tt(n),o=e.fieldDef(r);i({...t,type:o.type})}})),xm(e)){const{mark:t,markDef:i,encoding:r}=e;if(Qr(t)&&!e.encoding.order){const e=r["horizontal"===i.orient?"y":"x"];Ho(e)&&"quantitative"===e.type&&!(e.field in n)&&(n[e.field]="number")}}return n}class cd extends vc{clone(){return new cd(null,l(this._parse))}constructor(e,t){super(e),qn(this,"_parse",void 0),this._parse=t}hash(){return`Parse ${d(this._parse)}`}static makeExplicit(e,t,n){let i={};const 
r=t.data;return!sc(r)&&r?.format?.parse&&(i=r.format.parse),this.makeWithAncestors(e,i,{},n)}static makeWithAncestors(e,t,n,i){for(const e of D(n)){const t=i.getWithExplicit(e);void 0!==t.value&&(t.explicit||t.value===n[e]||"derived"===t.value||"flatten"===n[e]?delete n[e]:$i(ei(e,n[e],t.value)))}for(const e of D(t)){const n=i.get(e);void 0!==n&&(n===t[e]?delete t[e]:$i(ei(e,t[e],n)))}const r=new Jl(t,n);i.copyAll(r);const o={};for(const e of D(r.combine())){const t=r.get(e);null!==t&&(o[e]=t)}return 0===D(o).length||i.parseNothing?null:new cd(e,o)}get parse(){return this._parse}merge(e){this._parse={...this._parse,...e.parse},e.remove()}assembleFormatParse(){const e={};for(const t of D(this._parse)){const n=this._parse[t];1===q(t)&&(e[t]=n)}return e}producedFields(){return new Set(D(this._parse))}dependentFields(){return new Set(D(this._parse))}assembleTransforms(){let e=arguments.length>0&&void 0!==arguments[0]&&arguments[0];return D(this._parse).filter((t=>!e||q(t)>1)).map((e=>{const t=function(e,t){const n=A(e);if("number"===t)return`toNumber(${n})`;if("boolean"===t)return`toBoolean(${n})`;if("string"===t)return`toString(${n})`;if("date"===t)return`toDate(${n})`;if("flatten"===t)return n;if(t.startsWith("date:"))return`timeParse(${n},'${ad(t.slice(5,t.length))}')`;if(t.startsWith("utc:"))return`utcParse(${n},'${ad(t.slice(4,t.length))}')`;return $i(`Unrecognized parse "${t}".`),null}(e,this._parse[e]);if(!t)return null;return{type:"formula",expr:t,as:L(e)}})).filter((e=>null!==e))}}class ud extends vc{clone(){return new ud(null)}constructor(e){super(e)}dependentFields(){return new Set}producedFields(){return new Set([xs])}hash(){return"Identifier"}assemble(){return{type:"identifier",as:xs}}}class fd extends vc{clone(){return new fd(null,this.params)}constructor(e,t){super(e),this.params=t}dependentFields(){return new Set}producedFields(){}hash(){return`Graticule ${d(this.params)}`}assemble(){return{type:"graticule",...!0===this.params?{}:this.params}}}class dd 
extends vc{clone(){return new dd(null,this.params)}constructor(e,t){super(e),this.params=t}dependentFields(){return new Set}producedFields(){return new Set([this.params.as??"data"])}hash(){return`Hash ${d(this.params)}`}assemble(){return{type:"sequence",...this.params}}}class md extends vc{constructor(e){let t;if(super(null),qn(this,"_data",void 0),qn(this,"_name",void 0),qn(this,"_generator",void 0),e??={name:"source"},sc(e)||(t=e.format?{...f(e.format,["parse"])}:{}),oc(e))this._data={values:e.values};else if(rc(e)){if(this._data={url:e.url},!t.type){let n=/(?:\.([^.]+))?$/.exec(e.url)[1];p(["json","csv","tsv","dsv","topojson"],n)||(n="json"),t.type=n}}else cc(e)?this._data={values:[{type:"Sphere"}]}:(ac(e)||sc(e))&&(this._data={});this._generator=sc(e),e.name&&(this._name=e.name),t&&!S(t)&&(this._data.format=t)}dependentFields(){return new Set}producedFields(){}get data(){return this._data}hasName(){return!!this._name}get isGenerator(){return this._generator}get dataName(){return this._name}set dataName(e){this._name=e}set parent(e){throw new Error("Source nodes have to be roots.")}remove(){throw new Error("Source nodes are roots and cannot be removed.")}hash(){throw new Error("Cannot hash sources")}assemble(){return{name:this._name,...this._data,transform:[]}}}function pd(e){return e instanceof md||e instanceof fd||e instanceof dd}var gd=new WeakMap;class hd{constructor(){Wn(this,gd,{writable:!0,value:void 0}),Un(this,gd,!1)}setModified(){Un(this,gd,!0)}get modifiedFlag(){return function(e,t){return t.get?t.get.call(e):t.value}(e=this,Rn(e,gd,"get"));var e}}class yd extends hd{getNodeDepths(e,t,n){n.set(e,t);for(const i of e.children)this.getNodeDepths(i,t+1,n);return n}optimize(e){const t=[...this.getNodeDepths(e,0,new Map).entries()].sort(((e,t)=>t[1]-e[1]));for(const e of t)this.run(e[0]);return this.modifiedFlag}}class vd extends hd{optimize(e){this.run(e);for(const t of e.children)this.optimize(t);return this.modifiedFlag}}class bd extends 
vd{mergeNodes(e,t){const n=t.shift();for(const i of t)e.removeChild(i),i.parent=n,i.remove()}run(e){const t=e.children.map((e=>e.hash())),n={};for(let i=0;i1&&(this.setModified(),this.mergeNodes(e,n[t]))}}class xd extends vd{constructor(e){super(),qn(this,"requiresSelectionId",void 0),this.requiresSelectionId=e&&Lu(e)}run(e){e instanceof ud&&(this.requiresSelectionId&&(pd(e.parent)||e.parent instanceof rd||e.parent instanceof cd)||(this.setModified(),e.remove()))}}class $d extends hd{optimize(e){return this.run(e,new Set),this.modifiedFlag}run(e,t){let n=new Set;e instanceof wc&&(n=e.producedFields(),$(n,t)&&(this.setModified(),e.removeFormulas(t),0===e.producedFields.length&&e.remove()));for(const i of e.children)this.run(i,new Set([...t,...n]))}}class wd extends vd{constructor(){super()}run(e){e instanceof bc&&!e.isRequired()&&(this.setModified(),e.remove())}}class kd extends yd{run(e){if(!(pd(e)||e.numChildren()>1))for(const t of e.children)if(t instanceof cd)if(e instanceof cd)this.setModified(),e.merge(t);else{if(k(e.producedFields(),t.dependentFields()))continue;this.setModified(),t.swapWithParent()}}}class Sd extends yd{run(e){const t=[...e.children],n=e.children.filter((e=>e instanceof cd));if(e.numChildren()>1&&n.length>=1){const i={},r=new Set;for(const e of n){const t=e.parse;for(const e of D(t))e in i?i[e]!==t[e]&&r.add(e):i[e]=t[e]}for(const e of r)delete i[e];if(!S(i)){this.setModified();const n=new cd(e,i);for(const r of t){if(r instanceof cd)for(const e of D(i))delete r.parse[e];e.removeChild(r),r.parent=n,r instanceof cd&&0===D(r.parse).length&&r.remove()}}}}}class Dd extends yd{run(e){e instanceof bc||e.numChildren()>0||e instanceof od||e instanceof md||(this.setModified(),e.remove())}}class Fd extends yd{run(e){const t=e.children.filter((e=>e instanceof wc)),n=t.pop();for(const e of t)this.setModified(),n.merge(e)}}class zd extends yd{run(e){const t=e.children.filter((e=>e instanceof rd)),n={};for(const e of t){const t=d(e.groupBy);t in 
n||(n[t]=[]),n[t].push(e)}for(const t of D(n)){const i=n[t];if(i.length>1){const t=i.pop();for(const n of i)t.merge(n)&&(e.removeChild(n),n.parent=t,n.remove(),this.setModified())}}}}class Od extends yd{constructor(e){super(),this.model=e}run(e){const t=!(pd(e)||e instanceof Bu||e instanceof cd||e instanceof ud),n=[],i=[];for(const r of e.children)r instanceof nd&&(t&&!k(e.producedFields(),r.dependentFields())?n.push(r):i.push(r));if(n.length>0){const t=n.pop();for(const e of n)t.merge(e,this.model.renameSignal.bind(this.model));this.setModified(),e instanceof nd?e.merge(t,this.model.renameSignal.bind(this.model)):t.swapWithParent()}if(i.length>1){const e=i.pop();for(const t of i)e.merge(t,this.model.renameSignal.bind(this.model));this.setModified()}}}class _d extends yd{run(e){const t=[...e.children];if(!g(t,(e=>e instanceof bc))||e.numChildren()<=1)return;const n=[];let i;for(const r of t)if(r instanceof bc){let t=r;for(;1===t.numChildren();){const[e]=t.children;if(!(e instanceof bc))break;t=e}n.push(...t.children),i?(e.removeChild(r),r.parent=i.parent,i.parent.removeChild(i),i.parent=t,this.setModified()):i=t}else n.push(r);if(n.length){this.setModified();for(const e of n)e.parent.removeChild(e),e.parent=i}}}class Nd extends vc{clone(){return new Nd(null,l(this.transform))}constructor(e,t){super(e),this.transform=t}addDimensions(e){this.transform.groupby=b(this.transform.groupby.concat(e),(e=>e))}dependentFields(){const e=new Set;return this.transform.groupby&&this.transform.groupby.forEach(e.add,e),this.transform.joinaggregate.map((e=>e.field)).filter((e=>void 0!==e)).forEach(e.add,e),e}producedFields(){return new Set(this.transform.joinaggregate.map(this.getDefaultName))}getDefaultName(e){return e.as??oa(e)}hash(){return`JoinAggregateTransform ${d(this.transform)}`}assemble(){const e=[],t=[],n=[];for(const i of this.transform.joinaggregate)t.push(i.op),n.push(this.getDefaultName(i)),e.push(void 0===i.field?null:i.field);const 
i=this.transform.groupby;return{type:"joinaggregate",as:n,ops:t,fields:e,...void 0!==i?{groupby:i}:{}}}}class Cd extends vc{clone(){return new Cd(null,l(this._stack))}constructor(e,t){super(e),qn(this,"_stack",void 0),this._stack=t}static makeFromTransform(e,n){const{stack:i,groupby:r,as:o,offset:a="zero"}=n,s=[],l=[];if(void 0!==n.sort)for(const e of n.sort)s.push(e.field),l.push(U(e.order,"ascending"));const c={field:s,order:l};let u;return u=function(e){return t.isArray(e)&&e.every((e=>t.isString(e)))&&e.length>1}(o)?o:t.isString(o)?[o,`${o}_end`]:[`${n.stack}_start`,`${n.stack}_end`],new Cd(e,{dimensionFieldDefs:[],stackField:i,groupby:r,offset:a,sort:c,facetby:[],as:u})}static makeFromEncoding(e,n){const i=n.stack,{encoding:r}=n;if(!i)return null;const{groupbyChannels:o,fieldChannel:a,offset:s,impute:l}=i,c=o.map((e=>pa(r[e]))).filter((e=>!!e)),u=function(e){return e.stack.stackBy.reduce(((e,t)=>{const n=oa(t.fieldDef);return n&&e.push(n),e}),[])}(n),f=n.encoding.order;let d;if(t.isArray(f)||Ho(f))d=Tn(f);else{const e=Ro(f)?f.sort:"y"===a?"descending":"ascending";d=u.reduce(((t,n)=>(t.field.push(n),t.order.push(e),t)),{field:[],order:[]})}return new Cd(e,{dimensionFieldDefs:c,stackField:n.vgField(a),facetby:[],stackby:u,sort:d,offset:s,impute:l,as:[n.vgField(a,{suffix:"start",forAs:!0}),n.vgField(a,{suffix:"end",forAs:!0})]})}get stack(){return this._stack}addDimensions(e){this._stack.facetby.push(...e)}dependentFields(){const e=new Set;return e.add(this._stack.stackField),this.getGroupbyFields().forEach(e.add,e),this._stack.facetby.forEach(e.add,e),this._stack.sort.field.forEach(e.add,e),e}producedFields(){return new Set(this._stack.as)}hash(){return`Stack ${d(this._stack)}`}getGroupbyFields(){const{dimensionFieldDefs:e,impute:t,groupby:n}=this._stack;return e.length>0?e.map((e=>e.bin?t?[oa(e,{binSuffix:"mid"})]:[oa(e,{}),oa(e,{binSuffix:"end"})]:[oa(e)])).flat():n??[]}assemble(){const 
e=[],{facetby:t,dimensionFieldDefs:n,stackField:i,stackby:r,sort:o,offset:a,impute:s,as:l}=this._stack;if(s)for(const o of n){const{bandPosition:n=.5,bin:a}=o;if(a){const t=oa(o,{expr:"datum"}),i=oa(o,{expr:"datum",binSuffix:"end"});e.push({type:"formula",expr:`${n}*${t}+${1-n}*${i}`,as:oa(o,{binSuffix:"mid",forAs:!0})})}e.push({type:"impute",field:i,groupby:[...r,...t],key:oa(o,{binSuffix:"mid"}),method:"value",value:0})}return e.push({type:"stack",groupby:[...this.getGroupbyFields(),...t],field:i,sort:o,as:l,offset:a}),e}}class Pd extends vc{clone(){return new Pd(null,l(this.transform))}constructor(e,t){super(e),this.transform=t}addDimensions(e){this.transform.groupby=b(this.transform.groupby.concat(e),(e=>e))}dependentFields(){const e=new Set;return(this.transform.groupby??[]).forEach(e.add,e),(this.transform.sort??[]).forEach((t=>e.add(t.field))),this.transform.window.map((e=>e.field)).filter((e=>void 0!==e)).forEach(e.add,e),e}producedFields(){return new Set(this.transform.window.map(this.getDefaultName))}getDefaultName(e){return e.as??oa(e)}hash(){return`WindowTransform ${d(this.transform)}`}assemble(){const e=[],t=[],n=[],i=[];for(const r of this.transform.window)t.push(r.op),n.push(this.getDefaultName(r)),i.push(void 0===r.param?null:r.param),e.push(void 0===r.field?null:r.field);const r=this.transform.frame,o=this.transform.groupby;if(r&&null===r[0]&&null===r[1]&&t.every((e=>tn(e))))return{type:"joinaggregate",as:n,ops:t,fields:e,...void 0!==o?{groupby:o}:{}};const a=[],s=[];if(void 0!==this.transform.sort)for(const e of this.transform.sort)a.push(e.field),s.push(e.order??"ascending");const l={field:a,order:s},c=this.transform.ignorePeers;return{type:"window",params:i,as:n,ops:t,fields:e,sort:l,...void 0!==c?{ignorePeers:c}:{},...void 0!==o?{groupby:o}:{},...void 0!==r?{frame:r}:{}}}}function Ad(e){if(e instanceof od)if(1!==e.numChildren()||e.children[0]instanceof bc){const n=e.model.component.data.main;jd(n);const i=(t=e,function e(n){if(!(n instanceof 
od)){const i=n.clone();if(i instanceof bc){const e=Td+i.getSource();i.setSource(e),t.model.component.data.outputNodes[e]=i}else(i instanceof rd||i instanceof Cd||i instanceof Pd||i instanceof Nd)&&i.addDimensions(t.fields);for(const t of n.children.flatMap(e))t.parent=i;return[i]}return n.children.flatMap(e)}),r=e.children.map(i).flat();for(const e of r)e.parent=n}else{const t=e.children[0];(t instanceof rd||t instanceof Cd||t instanceof Pd||t instanceof Nd)&&t.addDimensions(e.fields),t.swapWithParent(),Ad(e)}else e.children.map(Ad);var t}function jd(e){if(e instanceof bc&&e.type===fc.Main&&1===e.numChildren()){const t=e.children[0];t instanceof od||(t.swapWithParent(),jd(e))}}const Td="scale_",Ed=5;function Md(e){for(const t of e){for(const e of t.children)if(e.parent!==t)return!1;if(!Md(t.children))return!1}return!0}function Ld(e,t){let n=!1;for(const i of t)n=e.optimize(i)||n;return n}function qd(e,t,n){let i=e.sources,r=!1;return r=Ld(new wd,i)||r,r=Ld(new xd(t),i)||r,i=i.filter((e=>e.numChildren()>0)),r=Ld(new Dd,i)||r,i=i.filter((e=>e.numChildren()>0)),n||(r=Ld(new kd,i)||r,r=Ld(new Od(t),i)||r,r=Ld(new $d,i)||r,r=Ld(new Sd,i)||r,r=Ld(new zd,i)||r,r=Ld(new Fd,i)||r,r=Ld(new bd,i)||r,r=Ld(new _d,i)||r),e.sources=i,r}class Ud{constructor(e){qn(this,"signal",void 0),Object.defineProperty(this,"signal",{enumerable:!0,get:e})}static fromName(e,t){return new Ud((()=>e(t)))}}function Rd(e){xm(e)?function(e){const t=e.component.scales;for(const n of D(t)){const i=Wd(e,n);if(t[n].setWithExplicit("domains",i),Vd(e,n),e.component.data.isFaceted){let t=e;for(;!$m(t)&&t.parent;)t=t.parent;if("shared"===t.component.resolve.scale[n])for(const e of i.value)bn(e)&&(e.data=Td+e.data.replace(Td,""))}}}(e):function(e){for(const t of e.children)Rd(t);const t=e.component.scales;for(const n of D(t)){let i,r=null;for(const t of e.children){const e=t.component.scales[n];if(e){i=void 0===i?e.getWithExplicit("domains"):nc(i,e.getWithExplicit("domains"),"domains","scale",Yd);const 
t=e.get("selectionExtent");r&&t&&r.param!==t.param&&$i(Kn),r=t}}t[n].setWithExplicit("domains",i),r&&t[n].set("selectionExtent",r,!0)}}(e)}function Wd(e,t){const n=e.getScaleComponent(t).get("type"),{encoding:i}=e,r=function(e,t,n,i){if("unaggregated"===e){const{valid:e,reason:i}=Gd(t,n);if(!e)return void $i(i)}else if(void 0===e&&i.useUnaggregatedDomain){const{valid:e}=Gd(t,n);if(e)return"unaggregated"}return e}(e.scaleDomain(t),e.typedFieldDef(t),n,e.config.scale);return r!==e.scaleDomain(t)&&(e.specifiedScales[t]={...e.specifiedScales[t],domain:r}),"x"===t&&ga(i.x2)?ga(i.x)?nc(Id(n,r,e,"x"),Id(n,r,e,"x2"),"domain","scale",Yd):Id(n,r,e,"x2"):"y"===t&&ga(i.y2)?ga(i.y)?nc(Id(n,r,e,"y"),Id(n,r,e,"y2"),"domain","scale",Yd):Id(n,r,e,"y2"):Id(n,r,e,t)}function Bd(e,t,n){const i=Ui(n)?.unit;return"temporal"===t||i?function(e,t,n){return e.map((e=>({signal:`{data: ${wa(e,{timeUnit:n,type:t})}}`})))}(e,t,i):[e]}function Id(e,n,i,r){const{encoding:o,markDef:a,mark:s,config:l,stack:c}=i,u=ga(o[r]),{type:f}=u,d=u.timeUnit;if(function(e){return e?.unionWith}(n)){const t=Id(e,void 0,i,r);return Kl([...Bd(n.unionWith,f,d),...t.value])}if(yn(n))return Kl([n]);if(n&&"unaggregated"!==n&&!Sr(n))return Kl(Bd(n,f,d));if(c&&r===c.fieldChannel){if("normalize"===c.offset)return Zl([[0,1]]);const e=i.requestDataName(fc.Main);return Zl([{data:e,field:i.vgField(r,{suffix:"start"})},{data:e,field:i.vgField(r,{suffix:"end"})}])}const m=Ht(r)&&Ho(u)?function(e,t,n){if(!xr(n))return;const i=e.fieldDef(t),r=i.sort;if(Po(r))return{op:"min",field:af(i,t),order:"ascending"};const{stack:o}=e,a=o?new Set([...o.groupbyFields,...o.stackBy.map((e=>e.fieldDef.field))]):void 0;if(Co(r)){return Hd(r,o&&!a.has(r.field))}if(No(r)){const{encoding:t,order:n}=r,i=e.fieldDef(t),{aggregate:s,field:l}=i,c=o&&!a.has(l);if(Zt(s)||en(s))return Hd({field:oa(i),order:n},c);if(tn(s)||!s)return 
Hd({op:s,field:l,order:n},c)}else{if("descending"===r)return{op:"min",field:e.vgField(t),order:"descending"};if(p(["ascending",void 0],r))return!0}return}(i,r,e):void 0;if(Go(u)){return Zl(Bd([u.datum],f,d))}const g=u;if("unaggregated"===n){const e=i.requestDataName(fc.Main),{field:t}=u;return Zl([{data:e,field:oa({field:t,aggregate:"min"})},{data:e,field:oa({field:t,aggregate:"max"})}])}if(ln(g.bin)){if(xr(e))return Zl("bin-ordinal"===e?[]:[{data:O(m)?i.requestDataName(fc.Main):i.requestDataName(fc.Raw),field:i.vgField(r,Sa(g,r)?{binSuffix:"range"}:{}),sort:!0!==m&&t.isObject(m)?m:{field:i.vgField(r,{}),op:"min"}}]);{const{bin:e}=g;if(ln(e)){const t=ed(i,g.field,e);return Zl([new Ud((()=>{const e=i.getSignalName(t);return`[${e}.start, ${e}.stop]`}))])}return Zl([{data:i.requestDataName(fc.Main),field:i.vgField(r,{})}])}}if(g.timeUnit&&p(["time","utc"],e)){const e=o[it(r)];if(Uo(g,e,a,l)){const t=i.requestDataName(fc.Main),n=Lo({fieldDef:g,fieldDef2:e,markDef:a,config:l}),o=Jr(s)&&.5!==n&&zt(r);return Zl([{data:t,field:i.vgField(r,o?{suffix:kc}:{})},{data:t,field:i.vgField(r,{suffix:o?Sc:"end"})}])}}return Zl(m?[{data:O(m)?i.requestDataName(fc.Main):i.requestDataName(fc.Raw),field:i.vgField(r),sort:m}]:[{data:i.requestDataName(fc.Main),field:i.vgField(r)}])}function Hd(e,t){const{op:n,field:i,order:r}=e;return{op:n??(t?"sum":zo),...i?{field:E(i)}:{},...r?{order:r}:{}}}function Vd(e,t){const n=e.component.scales[t],i=e.specifiedScales[t].domain,r=e.fieldDef(t)?.bin,o=Sr(i)&&i,a=un(r)&&fn(r.extent)&&r.extent;(o||a)&&n.set("selectionExtent",o??a,!0)}function Gd(e,n){const{aggregate:i,type:r}=e;return i?t.isString(i)&&!an.has(i)?{valid:!1,reason:fi(i)}:"quantitative"===r&&"log"===n?{valid:!1,reason:di(e)}:{valid:!0}:{valid:!1,reason:ui(e)}}function Yd(e,t,n,i){return e.explicit&&t.explicit&&$i(function(e,t,n,i){return`Conflicting ${t.toString()} property "${e.toString()}" (${X(n)} and ${X(i)}). 
Using the union of the two domains.`}(n,i,e.value,t.value)),{explicit:e.explicit,value:[...e.value,...t.value]}}function Xd(e){const n=b(e.map((e=>{if(bn(e)){const{sort:t,...n}=e;return n}return e})),d),i=b(e.map((e=>{if(bn(e)){const t=e.sort;return void 0===t||O(t)||("op"in t&&"count"===t.op&&delete t.field,"ascending"===t.order&&delete t.order),t}})).filter((e=>void 0!==e)),d);if(0===n.length)return;if(1===n.length){const n=e[0];if(bn(n)&&i.length>0){let e=i[0];if(i.length>1){$i(gi);const n=i.filter((e=>t.isObject(e)&&"op"in e&&"min"!==e.op));e=!i.every((e=>t.isObject(e)&&"op"in e))||1!==n.length||n[0]}else if(t.isObject(e)&&"field"in e){const t=e.field;n.field===t&&(e=!e.order||{order:e.order})}return{...n,sort:e}}return n}const r=b(i.map((e=>O(e)||!("op"in e)||t.isString(e.op)&&e.op in Kt?e:($i(function(e){return`Dropping sort property ${X(e)} as unioned domains only support boolean or op "count", "min", and "max".`}(e)),!0))),d);let o;1===r.length?o=r[0]:r.length>1&&($i(gi),o=!0);const a=b(e.map((e=>bn(e)?e.data:null)),(e=>e));if(1===a.length&&null!==a[0]){return{data:a[0],fields:n.map((e=>e.field)),...o?{sort:o}:{}}}return{fields:n,...o?{sort:o}:{}}}function Qd(e){if(bn(e)&&t.isString(e.field))return e.field;if(function(e){return!t.isArray(e)&&"fields"in e&&!("data"in e)}(e)){let n;for(const i of e.fields)if(bn(i)&&t.isString(i.field))if(n){if(n!==i.field)return $i("Detected faceted independent scales that union domain of multiple fields from different data sources. We will use the first field. The result view size may be incorrect."),n}else n=i.field;return $i("Detected faceted independent scales that union domain of the same fields from different source. We will assume that this is the same field from a different fork of the same data source. 
However, if this is not the case, the result view size may be incorrect."),n}if(function(e){return!t.isArray(e)&&"fields"in e&&"data"in e}(e)){$i("Detected faceted independent scales that union domain of multiple fields from the same data source. We will use the first field. The result view size may be incorrect.");const n=e.fields[0];return t.isString(n)?n:void 0}}function Jd(e,t){const n=e.component.scales[t].get("domains").map((t=>(bn(t)&&(t.data=e.lookupDataSource(t.data)),t)));return Xd(n)}function Kd(e){return km(e)||wm(e)?e.children.reduce(((e,t)=>e.concat(Kd(t))),Zd(e)):Zd(e)}function Zd(e){return D(e.component.scales).reduce(((n,i)=>{const r=e.component.scales[i];if(r.merged)return n;const o=r.combine(),{name:a,type:s,selectionExtent:l,domains:c,range:u,reverse:f,...d}=o,m=function(e,n,i,r){if(zt(i)){if(vn(e))return{step:{signal:`${n}_step`}}}else if(t.isObject(e)&&bn(e))return{...e,data:r.lookupDataSource(e.data)};return e}(o.range,a,i,e),p=Jd(e,i),g=l?function(e,n,i,r){const o=Hu(e,n.param,n);return{signal:$r(i.get("type"))&&t.isArray(r)&&r[0]>r[1]?`isValid(${o}) && reverse(${o})`:o}}(e,l,r,p):null;return n.push({name:a,type:s,...p?{domain:p}:{},...g?{domainRaw:g}:{},range:m,...void 0!==f?{reverse:f}:{},...d}),n}),[])}class em extends Jl{constructor(e,t){super({},{name:e}),qn(this,"merged",!1),this.setWithExplicit("type",t)}domainDefinitelyIncludesZero(){return!1!==this.get("zero")||g(this.get("domains"),(e=>t.isArray(e)&&2===e.length&&t.isNumber(e[0])&&e[0]<=0&&t.isNumber(e[1])&&e[1]>=0))}}const tm=["range","scheme"];function nm(e,n){const i=e.fieldDef(n);if(i?.bin){const{bin:r,field:o}=i,a=rt(n),s=e.getName(a);if(t.isObject(r)&&r.binned&&void 0!==r.step)return new Ud((()=>{const t=e.scaleName(n),i=`(domain("${t}")[1] - domain("${t}")[0]) / ${r.step}`;return`${e.getSignalName(s)} / (${i})`}));if(ln(r)){const t=ed(e,o,r);return new Ud((()=>{const n=e.getSignalName(t),i=`(${n}.stop - ${n}.start) / ${n}.step`;return`${e.getSignalName(s)} / 
(${i})`}))}}}function im(e,n){const i=n.specifiedScales[e],{size:r}=n,o=n.getScaleComponent(e).get("type");for(const r of tm)if(void 0!==i[r]){const a=Ar(o,r),s=jr(e,r);if(a)if(s)$i(s);else switch(r){case"range":{const r=i.range;if(t.isArray(r)){if(zt(e))return Kl(r.map((e=>{if("width"===e||"height"===e){const t=n.getName(e),i=n.getSignalName.bind(n);return Ud.fromName(i,t)}return e})))}else if(t.isObject(r))return Kl({data:n.requestDataName(fc.Main),field:r.field,sort:{op:"min",field:n.vgField(e)}});return Kl(r)}case"scheme":return Kl(rm(i[r]))}else $i(mi(o,r,e))}const a=e===Z||"xOffset"===e?"width":"height",s=r[a];if(Ns(s))if(zt(e))if(xr(o)){const t=am(s,n,e);if(t)return Kl({step:t})}else $i(pi(a));else if(Pt(e)){const t=e===ie?"x":"y";if("band"===n.getScaleComponent(t).get("type")){const e=sm(s,o);if(e)return Kl(e)}}const{rangeMin:l,rangeMax:u}=i,f=function(e,n){const{size:i,config:r,mark:o,encoding:a}=n,{type:s}=ga(a[e]),l=n.getScaleComponent(e),u=l.get("type"),{domain:f,domainMid:d}=n.specifiedScales[e];switch(e){case Z:case ee:if(p(["point","band"],u)){const t=lm(e,i,r.view);if(Ns(t)){return{step:am(t,n,e)}}}return om(e,n,u);case ie:case re:return function(e,t,n){const i=e===ie?"x":"y",r=t.getScaleComponent(i);if(!r)return om(i,t,n,{center:!0});const o=r.get("type"),a=t.scaleName(i),{markDef:s,config:l}=t;if("band"===o){const e=lm(i,t.size,t.config.view);if(Ns(e)){const t=sm(e,n);if(t)return t}return[0,{signal:`bandwidth('${a}')`}]}{const n=t.encoding[i];if(Ho(n)&&n.timeUnit){const e=Ri(n.timeUnit,(e=>`scale('${a}', ${e})`)),i=t.config.scale.bandWithNestedOffsetPaddingInner,r=Lo({fieldDef:n,markDef:s,config:l})-.5,o=0!==r?` + ${r}`:"";if(i){return[{signal:`${yn(i)?`${i.signal}/2`+o:`${i/2+r}`} * (${e})`},{signal:`${yn(i)?`(1 - ${i.signal}/2)`+o:`${1-i/2+r}`} * (${e})`}]}return[0,{signal:e}]}return c(`Cannot use ${e} scale if ${i} scale is not discrete.`)}}(e,n,u);case ye:{const a=cm(o,n.component.scales[e].get("zero"),r),s=function(e,n,i,r){const 
o={x:nm(i,"x"),y:nm(i,"y")};switch(e){case"bar":case"tick":{if(void 0!==r.scale.maxBandSize)return r.scale.maxBandSize;const e=fm(n,o,r.view);return t.isNumber(e)?e-1:new Ud((()=>`${e.signal} - 1`))}case"line":case"trail":case"rule":return r.scale.maxStrokeWidth;case"text":return r.scale.maxFontSize;case"point":case"square":case"circle":{if(r.scale.maxSize)return r.scale.maxSize;const e=fm(n,o,r.view);return t.isNumber(e)?Math.pow(um*e,2):new Ud((()=>`pow(${um} * ${e.signal}, 2)`))}}throw new Error(ai("size",e))}(o,i,n,r);return kr(u)?function(e,t,n){const i=()=>{const i=On(t),r=On(e),o=`(${i} - ${r}) / (${n} - 1)`;return`sequence(${r}, ${i} + ${o}, ${o})`};return yn(t)?new Ud(i):{signal:i()}}(a,s,function(e,n,i,r){switch(e){case"quantile":return n.scale.quantileCount;case"quantize":return n.scale.quantizeCount;case"threshold":return void 0!==i&&t.isArray(i)?i.length+1:($i(function(e){return`Domain for ${e} is required for threshold scale.`}(r)),3)}}(u,r,f,e)):[a,s]}case se:return[0,2*Math.PI];case ve:return[0,360];case oe:return[0,new Ud((()=>`min(${n.getSignalName("width")},${n.getSignalName("height")})/2`))];case we:return[r.scale.minStrokeWidth,r.scale.maxStrokeWidth];case ke:return[[1,0],[4,2],[2,1],[1,1],[1,2,4,2]];case he:return"symbol";case me:case pe:case ge:return"ordinal"===u?"nominal"===s?"category":"ordinal":void 0!==d?"diverging":"rect"===o||"geoshape"===o?"heatmap":"ramp";case be:case xe:case $e:return[r.scale.minOpacity,r.scale.maxOpacity]}}(e,n);return(void 0!==l||void 0!==u)&&Ar(o,"rangeMin")&&t.isArray(f)&&2===f.length?Kl([l??f[0],u??f[1]]):Zl(f)}function rm(e){return function(e){return!t.isString(e)&&!!e.name}(e)?{scheme:e.name,...f(e,["name"])}:{scheme:e}}function om(e,t,n){let{center:i}=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{};const r=rt(e),o=t.getName(r),a=t.getSignalName.bind(t);return 
e===ee&&$r(n)?i?[Ud.fromName((e=>`${a(e)}/2`),o),Ud.fromName((e=>`-${a(e)}/2`),o)]:[Ud.fromName(a,o),0]:i?[Ud.fromName((e=>`-${a(e)}/2`),o),Ud.fromName((e=>`${a(e)}/2`),o)]:[0,Ud.fromName(a,o)]}function am(e,n,i){const{encoding:r}=n,o=n.getScaleComponent(i),a=at(i),s=r[a];if("offset"===_s({step:e,offsetIsDiscrete:Jo(s)&&ir(s.type)})&&Ea(r,a)){const i=n.getScaleComponent(a);let r=`domain('${n.scaleName(a)}').length`;if("band"===i.get("type")){r=`bandspace(${r}, ${i.get("paddingInner")??i.get("padding")??0}, ${i.get("paddingOuter")??i.get("padding")??0})`}const s=o.get("paddingInner")??o.get("padding");return{signal:`${e.step} * ${r} / (1-${l=s,yn(l)?l.signal:t.stringValue(l)})`}}return e.step;var l}function sm(e,t){if("offset"===_s({step:e,offsetIsDiscrete:xr(t)}))return{step:e.step}}function lm(e,t,n){const i=e===Z?"width":"height",r=t[i];return r||Ts(n,i)}function cm(e,t,n){if(t)return yn(t)?{signal:`${t.signal} ? 0 : ${cm(e,!1,n)}`}:0;switch(e){case"bar":case"tick":return n.scale.minBandSize;case"line":case"trail":case"rule":return n.scale.minStrokeWidth;case"text":return n.scale.minFontSize;case"point":case"square":case"circle":return n.scale.minSize}throw new Error(ai("size",e))}const um=.95;function fm(e,t,n){const i=Ns(e.width)?e.width.step:js(n,"width"),r=Ns(e.height)?e.height.step:js(n,"height");return t.x||t.y?new Ud((()=>`min(${[t.x?t.x.signal:i,t.y?t.y.signal:r].join(", ")})`)):Math.min(i,r)}function dm(e,t){xm(e)?function(e,t){const n=e.component.scales,{config:i,encoding:r,markDef:o,specifiedScales:a}=e;for(const s of D(n)){const l=a[s],c=n[s],u=e.getScaleComponent(s),f=ga(r[s]),d=l[t],m=u.get("type"),p=u.get("padding"),g=u.get("paddingInner"),h=Ar(m,t),y=jr(s,t);if(void 0!==d&&(h?y&&$i(y):$i(mi(m,t,s))),h&&void 0===y)if(void 0!==d){const e=f.timeUnit,n=f.type;switch(t){case"domainMax":case"domainMin":wi(l[t])||"temporal"===n||e?c.set(t,{signal:wa(l[t],{type:n,timeUnit:e})},!0):c.set(t,l[t],!0);break;default:c.copyKeyFromObject(t,l)}}else{const n=t in 
mm?mm[t]({model:e,channel:s,fieldOrDatumDef:f,scaleType:m,scalePadding:p,scalePaddingInner:g,domain:l.domain,domainMin:l.domainMin,domainMax:l.domainMax,markDef:o,config:i,hasNestedOffsetScale:Ma(r,s),hasSecondaryRangeChannel:!!r[it(s)]}):i.scale[t];void 0!==n&&c.set(t,n,!1)}}}(e,t):gm(e,t)}const mm={bins:e=>{let{model:t,fieldOrDatumDef:n}=e;return Ho(n)?function(e,t){const n=t.bin;if(ln(n)){const i=ed(e,t.field,n);return new Ud((()=>e.getSignalName(i)))}if(cn(n)&&un(n)&&void 0!==n.step)return{step:n.step};return}(t,n):void 0},interpolate:e=>{let{channel:t,fieldOrDatumDef:n}=e;return function(e,t){if(p([me,pe,ge],e)&&"nominal"!==t)return"hcl";return}(t,n.type)},nice:e=>{let{scaleType:n,channel:i,domain:r,domainMin:o,domainMax:a,fieldOrDatumDef:s}=e;return function(e,n,i,r,o,a){if(pa(a)?.bin||t.isArray(i)||null!=o||null!=r||p([cr.TIME,cr.UTC],e))return;return!!zt(n)||void 0}(n,i,r,o,a,s)},padding:e=>{let{channel:t,scaleType:n,fieldOrDatumDef:i,markDef:r,config:o}=e;return function(e,t,n,i,r,o){if(zt(e)){if(wr(t)){if(void 0!==n.continuousPadding)return n.continuousPadding;const{type:t,orient:a}=r;if("bar"===t&&(!Ho(i)||!i.bin&&!i.timeUnit)&&("vertical"===a&&"x"===e||"horizontal"===a&&"y"===e))return o.continuousBandSize}if(t===cr.POINT)return n.pointPadding}return}(t,n,o.scale,i,r,o.bar)},paddingInner:e=>{let{scalePadding:t,channel:n,markDef:i,scaleType:r,config:o,hasNestedOffsetScale:a}=e;return function(e,t,n,i,r){let o=arguments.length>5&&void 0!==arguments[5]&&arguments[5];if(void 0!==e)return;if(zt(t)){const{bandPaddingInner:e,barBandPaddingInner:t,rectBandPaddingInner:i,bandWithNestedOffsetPaddingInner:a}=r;return o?a:U(e,"bar"===n?t:i)}if(Pt(t)&&i===cr.BAND)return r.offsetBandPaddingInner;return}(t,n,i.type,r,o.scale,a)},paddingOuter:e=>{let{scalePadding:t,channel:n,scaleType:i,scalePaddingInner:r,config:o,hasNestedOffsetScale:a}=e;return function(e,t,n,i,r){let o=arguments.length>5&&void 0!==arguments[5]&&arguments[5];if(void 
0!==e)return;if(zt(t)){const{bandPaddingOuter:e,bandWithNestedOffsetPaddingOuter:t}=r;if(o)return t;if(n===cr.BAND)return U(e,yn(i)?{signal:`${i.signal}/2`}:i/2)}else if(Pt(t)){if(n===cr.POINT)return.5;if(n===cr.BAND)return r.offsetBandPaddingOuter}return}(t,n,i,r,o.scale,a)},reverse:e=>{let{fieldOrDatumDef:t,scaleType:n,channel:i,config:r}=e;return function(e,t,n,i){if("x"===n&&void 0!==i.xReverse)return $r(e)&&"descending"===t?yn(i.xReverse)?{signal:`!${i.xReverse.signal}`}:!i.xReverse:i.xReverse;if($r(e)&&"descending"===t)return!0;return}(n,Ho(t)?t.sort:void 0,i,r.scale)},zero:e=>{let{channel:n,fieldOrDatumDef:i,domain:r,markDef:o,scaleType:a,config:s,hasSecondaryRangeChannel:l}=e;return function(e,n,i,r,o,a,s){if(i&&"unaggregated"!==i&&$r(o)){if(t.isArray(i)){const e=i[0],n=i[i.length-1];if(t.isNumber(e)&&e<=0&&t.isNumber(n)&&n>=0)return!0}return!1}if("size"===e&&"quantitative"===n.type&&!kr(o))return!0;if((!Ho(n)||!n.bin)&&p([...Ft,..._t],e)){const{orient:t,type:n}=r;return(!p(["bar","area","line","trail"],n)||!("horizontal"===t&&"y"===e||"vertical"===t&&"x"===e))&&(!(!p(["bar","area"],n)||s)||a?.zero)}return!1}(n,i,r,o,a,s.scale,l)}};function pm(e){xm(e)?function(e){const t=e.component.scales;for(const n of It){const i=t[n];if(!i)continue;const r=im(n,e);i.setWithExplicit("range",r)}}(e):gm(e,"range")}function gm(e,t){const n=e.component.scales;for(const n of e.children)"range"===t?pm(n):dm(n,t);for(const i of D(n)){let r;for(const n of e.children){const e=n.component.scales[i];if(e){r=nc(r,e.getWithExplicit(t),t,"scale",ec(((e,n)=>"range"===t&&e.step&&n.step?e.step-n.step:0)))}}n[i].setWithExplicit(t,r)}}function hm(e,t,n,i){const r=function(e,t,n,i){switch(t.type){case"nominal":case"ordinal":if(qe(e)||"discrete"===Qt(e))return"shape"===e&&"ordinal"===t.type&&$i(ci(e,"ordinal")),"ordinal";if(zt(e)||Pt(e)){if(p(["rect","bar","image","rule"],n.type))return"band";if(i)return"band"}else if("arc"===n.type&&e in Ot)return"band";return 
io(n[rt(e)])||ta(t)&&t.axis?.tickBand?"band":"point";case"temporal":return qe(e)?"time":"discrete"===Qt(e)?($i(ci(e,"temporal")),"ordinal"):Ho(t)&&t.timeUnit&&Ui(t.timeUnit).utc?"utc":"time";case"quantitative":return qe(e)?Ho(t)&&ln(t.bin)?"bin-ordinal":"linear":"discrete"===Qt(e)?($i(ci(e,"quantitative")),"ordinal"):"linear";case"geojson":return}throw new Error(ii(t.type))}(t,n,i,arguments.length>4&&void 0!==arguments[4]&&arguments[4]),{type:o}=e;return Ht(t)?void 0!==o?function(e,t){let n=arguments.length>2&&void 0!==arguments[2]&&arguments[2];if(!Ht(e))return!1;switch(e){case Z:case ee:case ie:case re:case se:case oe:return!!wr(t)||"band"===t||"point"===t&&!n;case ye:case we:case be:case xe:case $e:case ve:return wr(t)||kr(t)||p(["band","point","ordinal"],t);case me:case pe:case ge:return"band"!==t;case ke:case he:return"ordinal"===t||kr(t)}}(t,o)?Ho(n)&&(a=o,s=n.type,!(p([or,sr],s)?void 0===a||xr(a):s===ar?p([cr.TIME,cr.UTC,void 0],a):s!==rr||hr(a)||kr(a)||void 0===a))?($i(function(e,t){return`FieldDef does not work with "${e}" scale. We are using "${t}" scale instead.`}(o,r)),r):o:($i(function(e,t,n){return`Channel "${e}" does not work with "${t}" scale. 
We are using "${n}" scale instead.`}(t,o,r)),r):r:null;var a,s}function ym(e){xm(e)?e.component.scales=function(e){const{encoding:t,mark:n,markDef:i}=e,r={};for(const o of It){const a=ga(t[o]);if(a&&n===Xr&&o===he&&a.type===lr)continue;let s=a&&a.scale;if(a&&null!==s&&!1!==s){s??={};const n=hm(s,o,a,i,Ma(t,o));r[o]=new em(e.scaleName(`${o}`,!0),{value:n,explicit:s.type===n})}}return r}(e):e.component.scales=function(e){const t=e.component.scales={},n={},i=e.component.resolve;for(const t of e.children){ym(t);for(const r of D(t.component.scales))if(i.scale[r]??=Of(r,e),"shared"===i.scale[r]){const e=n[r],o=t.component.scales[r].getWithExplicit("type");e?fr(e.value,o.value)?n[r]=nc(e,o,"type","scale",vm):(i.scale[r]="independent",delete n[r]):n[r]=o}}for(const i of D(n)){const r=e.scaleName(i,!0),o=n[i];t[i]=new em(r,o);for(const t of e.children){const e=t.component.scales[i];e&&(t.renameScale(e.get("name"),r),e.merged=!0)}}return t}(e)}const vm=ec(((e,t)=>mr(e)-mr(t)));class bm{constructor(){qn(this,"nameMap",void 0),this.nameMap={}}rename(e,t){this.nameMap[e]=t}has(e){return void 0!==this.nameMap[e]}get(e){for(;this.nameMap[e]&&e!==this.nameMap[e];)e=this.nameMap[e];return e}}function xm(e){return"unit"===e?.type}function $m(e){return"facet"===e?.type}function wm(e){return"concat"===e?.type}function km(e){return"layer"===e?.type}class Sm{constructor(e,n,i,r,o,a,c){this.type=n,this.parent=i,this.config=o,qn(this,"name",void 0),qn(this,"size",void 0),qn(this,"title",void 0),qn(this,"description",void 0),qn(this,"data",void 0),qn(this,"transforms",void 0),qn(this,"layout",void 0),qn(this,"scaleNameMap",void 0),qn(this,"projectionNameMap",void 0),qn(this,"signalNameMap",void 0),qn(this,"component",void 0),qn(this,"view",void 0),qn(this,"children",void 
0),qn(this,"correctDataNames",(e=>(e.from?.data&&(e.from.data=this.lookupDataSource(e.from.data)),e.from?.facet?.data&&(e.from.facet.data=this.lookupDataSource(e.from.facet.data)),e))),this.parent=i,this.config=o,this.view=pn(c),this.name=e.name??r,this.title=hn(e.title)?{text:e.title}:e.title?pn(e.title):void 0,this.scaleNameMap=i?i.scaleNameMap:new bm,this.projectionNameMap=i?i.projectionNameMap:new bm,this.signalNameMap=i?i.signalNameMap:new bm,this.data=e.data,this.description=e.description,this.transforms=(e.transform??[]).map((e=>bl(e)?{filter:s(e.filter,tr)}:e)),this.layout="layer"===n||"unit"===n?{}:function(e,n,i){const r=i[n],o={},{spacing:a,columns:s}=r;void 0!==a&&(o.spacing=a),void 0!==s&&(To(e)&&!Ao(e.facet)||Fs(e))&&(o.columns=s),zs(e)&&(o.columns=1);for(const n of Ps)if(void 0!==e[n])if("spacing"===n){const i=e[n];o[n]=t.isNumber(i)?i:{row:i.row??a,column:i.column??a}}else o[n]=e[n];return o}(e,n,o),this.component={data:{sources:i?i.component.data.sources:[],outputNodes:i?i.component.data.outputNodes:{},outputNodeRefCounts:i?i.component.data.outputNodeRefCounts:{},isFaceted:To(e)||i?.component.data.isFaceted&&void 0===e.data},layoutSize:new Jl,layoutHeaders:{row:{},column:{},facet:{}},mark:null,resolve:{scale:{},axis:{},legend:{},...a?l(a):{}},selection:null,scales:null,projection:null,axes:{},legends:{}}}get width(){return this.getSizeSignalRef("width")}get height(){return this.getSizeSignalRef("height")}parse(){this.parseScale(),this.parseLayoutSize(),this.renameTopLevelLayoutSizeSignal(),this.parseSelections(),this.parseProjection(),this.parseData(),this.parseAxesAndHeaders(),this.parseLegends(),this.parseMarkGroup()}parseScale(){!function(e){let{ignoreRange:t}=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{};ym(e),Rd(e);for(const t of 
Pr)dm(e,t);t||pm(e)}(this)}parseProjection(){Jf(this)}renameTopLevelLayoutSizeSignal(){"width"!==this.getName("width")&&this.renameSignal(this.getName("width"),"width"),"height"!==this.getName("height")&&this.renameSignal(this.getName("height"),"height")}parseLegends(){Rf(this)}assembleEncodeFromView(e){const{style:t,...n}=e,i={};for(const e of D(n)){const t=n[e];void 0!==t&&(i[e]=Fn(t))}return i}assembleGroupEncodeEntry(e){let t={};return this.view&&(t=this.assembleEncodeFromView(this.view)),e||(this.description&&(t.description=Fn(this.description)),"unit"!==this.type&&"layer"!==this.type)?S(t)?void 0:t:{width:this.getSizeSignalRef("width"),height:this.getSizeSignalRef("height"),...t??{}}}assembleLayout(){if(!this.layout)return;const{spacing:e,...t}=this.layout,{component:n,config:i}=this,r=function(e,t){const n={};for(const i of Re){const r=e[i];if(r?.facetFieldDef){const{titleAnchor:e,titleOrient:o}=cf(["titleAnchor","titleOrient"],r.facetFieldDef.header,t,i),a=sf(i,o),s=xf(e,a);void 0!==s&&(n[a]=s)}}return S(n)?void 0:n}(n.layoutHeaders,i);return{padding:e,...this.assembleDefaultLayout(),...t,...r?{titleBand:r}:{}}}assembleDefaultLayout(){return{}}assembleHeaderMarks(){const{layoutHeaders:e}=this.component;let t=[];for(const n of Re)e[n].title&&t.push(df(this,n));for(const e of uf)t=t.concat(gf(this,e));return t}assembleAxes(){return function(e,t){const{x:n=[],y:i=[]}=e;return[...n.map((e=>Yu(e,"grid",t))),...i.map((e=>Yu(e,"grid",t))),...n.map((e=>Yu(e,"main",t))),...i.map((e=>Yu(e,"main",t)))].filter((e=>e))}(this.component.axes,this.config)}assembleLegends(){return Vf(this)}assembleProjections(){return Gf(this)}assembleTitle(){const{encoding:e,...t}=this.title??{},n={...gn(this.config.title).nonMarkTitleProperties,...t,...e?{encode:{update:e}}:{}};if(n.text)return p(["unit","layer"],this.type)?p(["middle",void 0],n.anchor)&&(n.frame??="group"):n.anchor??="start",S(n)?void 0:n}assembleGroup(){let e=arguments.length>0&&void 
0!==arguments[0]?arguments[0]:[];const t={};e=e.concat(this.assembleSignals()),e.length>0&&(t.signals=e);const n=this.assembleLayout();n&&(t.layout=n),t.marks=[].concat(this.assembleHeaderMarks(),this.assembleMarks());const i=!this.parent||$m(this.parent)?Kd(this):[];i.length>0&&(t.scales=i);const r=this.assembleAxes();r.length>0&&(t.axes=r);const o=this.assembleLegends();return o.length>0&&(t.legends=o),t}getName(e){return _((this.name?`${this.name}_`:"")+e)}getDataName(e){return this.getName(fc[e].toLowerCase())}requestDataName(e){const t=this.getDataName(e),n=this.component.data.outputNodeRefCounts;return n[t]=(n[t]||0)+1,t}getSizeSignalRef(e){if($m(this.parent)){const t=Nt(Ff(e)),n=this.component.scales[t];if(n&&!n.merged){const e=n.get("type"),i=n.get("range");if(xr(e)&&vn(i)){const e=n.get("name"),i=Qd(Jd(this,t));if(i){return{signal:Df(e,n,oa({aggregate:"distinct",field:i},{expr:"datum"}))}}return $i(Yn(t)),null}}}return{signal:this.signalNameMap.get(this.getName(e))}}lookupDataSource(e){const t=this.component.data.outputNodes[e];return t?t.getSource():e}getSignalName(e){return this.signalNameMap.get(e)}renameSignal(e,t){this.signalNameMap.rename(e,t)}renameScale(e,t){this.scaleNameMap.rename(e,t)}renameProjection(e,t){this.projectionNameMap.rename(e,t)}scaleName(e,t){return t?this.getName(e):Ke(e)&&Ht(e)&&this.component.scales[e]||this.scaleNameMap.has(this.getName(e))?this.scaleNameMap.get(this.getName(e)):void 0}projectionName(e){return e?this.getName("projection"):this.component.projection&&!this.component.projection.merged||this.projectionNameMap.has(this.getName("projection"))?this.projectionNameMap.get(this.getName("projection")):void 0}getScaleComponent(e){if(!this.component.scales)throw new Error("getScaleComponent cannot be called before parseScale(). 
Make sure you have called parseScale or use parseUnitModelWithScale().");const t=this.component.scales[e];return t&&!t.merged?t:this.parent?this.parent.getScaleComponent(e):void 0}getSelectionComponent(e,t){let n=this.component.selection[e];if(!n&&this.parent&&(n=this.parent.getSelectionComponent(e,t)),!n)throw new Error(function(e){return`Cannot find a selection named "${e}".`}(t));return n}hasAxisOrientSignalRef(){return this.component.axes.x?.some((e=>e.hasOrientSignalRef()))||this.component.axes.y?.some((e=>e.hasOrientSignalRef()))}}class Dm extends Sm{vgField(e){let t=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{};const n=this.fieldDef(e);if(n)return oa(n,t)}reduceFieldDef(e,n){return function(e,n,i,r){return e?D(e).reduce(((i,o)=>{const a=e[o];return t.isArray(a)?a.reduce(((e,t)=>n.call(r,e,t,o)),i):n.call(r,i,a,o)}),i):i}(this.getMapping(),((t,n,i)=>{const r=pa(n);return r?e(t,r,i):t}),n)}forEachFieldDef(e,t){Wa(this.getMapping(),((t,n)=>{const i=pa(t);i&&e(i,n)}),t)}}class Fm extends vc{clone(){return new Fm(null,l(this.transform))}constructor(e,t){super(e),this.transform=t,this.transform=l(t);const n=this.transform.as??[void 0,void 0];this.transform.as=[n[0]??"value",n[1]??"density"]}dependentFields(){return new Set([this.transform.density,...this.transform.groupby??[]])}producedFields(){return new Set(this.transform.as)}hash(){return`DensityTransform ${d(this.transform)}`}assemble(){const{density:e,...t}=this.transform,n={type:"kde",field:e,...t};return this.transform.groupby&&(n.resolve="shared"),n}}class zm extends vc{clone(){return new zm(null,l(this.transform))}constructor(e,t){super(e),this.transform=t,this.transform=l(t)}dependentFields(){return new Set([this.transform.extent])}producedFields(){return new Set([])}hash(){return`ExtentTransform ${d(this.transform)}`}assemble(){const{extent:e,param:t}=this.transform;return{type:"extent",field:e,signal:t}}}class Om extends vc{clone(){return new 
Om(null,{...this.filter})}constructor(e,t){super(e),this.filter=t}static make(e,t){const{config:n,mark:i,markDef:r}=t;if("filter"!==Cn("invalid",r,n))return null;const o=t.reduceFieldDef(((e,n,r)=>{const o=Ht(r)&&t.getScaleComponent(r);if(o){$r(o.get("type"))&&"count"!==n.aggregate&&!Qr(i)&&(e[n.field]=n)}return e}),{});return D(o).length?new Om(e,o):null}dependentFields(){return new Set(D(this.filter))}producedFields(){return new Set}hash(){return`FilterInvalid ${d(this.filter)}`}assemble(){const e=D(this.filter).reduce(((e,t)=>{const n=this.filter[t],i=oa(n,{expr:"datum"});return null!==n&&("temporal"===n.type?e.push(`(isDate(${i}) || (isValid(${i}) && isFinite(+${i})))`):"quantitative"===n.type&&(e.push(`isValid(${i})`),e.push(`isFinite(+${i})`))),e}),[]);return e.length>0?{type:"filter",expr:e.join(" && ")}:null}}class _m extends vc{clone(){return new _m(this.parent,l(this.transform))}constructor(e,t){super(e),this.transform=t,this.transform=l(t);const{flatten:n,as:i=[]}=this.transform;this.transform.as=n.map(((e,t)=>i[t]??e))}dependentFields(){return new Set(this.transform.flatten)}producedFields(){return new Set(this.transform.as)}hash(){return`FlattenTransform ${d(this.transform)}`}assemble(){const{flatten:e,as:t}=this.transform;return{type:"flatten",fields:e,as:t}}}class Nm extends vc{clone(){return new Nm(null,l(this.transform))}constructor(e,t){super(e),this.transform=t,this.transform=l(t);const n=this.transform.as??[void 0,void 0];this.transform.as=[n[0]??"key",n[1]??"value"]}dependentFields(){return new Set(this.transform.fold)}producedFields(){return new Set(this.transform.as)}hash(){return`FoldTransform ${d(this.transform)}`}assemble(){const{fold:e,as:t}=this.transform;return{type:"fold",fields:e,as:t}}}class Cm extends vc{clone(){return new Cm(null,l(this.fields),this.geojson,this.signal)}static parseAll(e,t){if(t.component.projection&&!t.component.projection.isFit)return e;let n=0;for(const i of[[ue,ce],[de,fe]]){const r=i.map((e=>{const 
n=ga(t.encoding[e]);return Ho(n)?n.field:Go(n)?{expr:`${n.datum}`}:Zo(n)?{expr:`${n.value}`}:void 0}));(r[0]||r[1])&&(e=new Cm(e,r,null,t.getName("geojson_"+n++)))}if(t.channelHasField(he)){const i=t.typedFieldDef(he);i.type===lr&&(e=new Cm(e,null,i.field,t.getName("geojson_"+n++)))}return e}constructor(e,t,n,i){super(e),this.fields=t,this.geojson=n,this.signal=i}dependentFields(){const e=(this.fields??[]).filter(t.isString);return new Set([...this.geojson?[this.geojson]:[],...e])}producedFields(){return new Set}hash(){return`GeoJSON ${this.geojson} ${this.signal} ${d(this.fields)}`}assemble(){return[...this.geojson?[{type:"filter",expr:`isValid(datum["${this.geojson}"])`}]:[],{type:"geojson",...this.fields?{fields:this.fields}:{},...this.geojson?{geojson:this.geojson}:{},signal:this.signal}]}}class Pm extends vc{clone(){return new Pm(null,this.projection,l(this.fields),l(this.as))}constructor(e,t,n,i){super(e),this.projection=t,this.fields=n,this.as=i}static parseAll(e,t){if(!t.projectionName())return e;for(const n of[[ue,ce],[de,fe]]){const i=n.map((e=>{const n=ga(t.encoding[e]);return Ho(n)?n.field:Go(n)?{expr:`${n.datum}`}:Zo(n)?{expr:`${n.value}`}:void 0})),r=n[0]===de?"2":"";(i[0]||i[1])&&(e=new Pm(e,t.projectionName(),i,[t.getName(`x${r}`),t.getName(`y${r}`)]))}return e}dependentFields(){return new Set(this.fields.filter(t.isString))}producedFields(){return new Set(this.as)}hash(){return`Geopoint ${this.projection} ${d(this.fields)} ${d(this.as)}`}assemble(){return{type:"geopoint",projection:this.projection,fields:this.fields,as:this.as}}}class Am extends vc{clone(){return new Am(null,l(this.transform))}constructor(e,t){super(e),this.transform=t}dependentFields(){return new Set([this.transform.impute,this.transform.key,...this.transform.groupby??[]])}producedFields(){return new Set([this.transform.impute])}processSequence(e){const{start:t=0,stop:n,step:i}=e;return{signal:`sequence(${[t,n,...i?[i]:[]].join(",")})`}}static makeFromTransform(e,t){return new 
Am(e,t)}static makeFromEncoding(e,t){const n=t.encoding,i=n.x,r=n.y;if(Ho(i)&&Ho(r)){const o=i.impute?i:r.impute?r:void 0;if(void 0===o)return;const a=i.impute?r:r.impute?i:void 0,{method:s,value:l,frame:c,keyvals:u}=o.impute,f=Ba(t.mark,n);return new Am(e,{impute:o.field,key:a.field,...s?{method:s}:{},...void 0!==l?{value:l}:{},...c?{frame:c}:{},...void 0!==u?{keyvals:u}:{},...f.length?{groupby:f}:{}})}return null}hash(){return`Impute ${d(this.transform)}`}assemble(){const{impute:e,key:t,keyvals:n,method:i,groupby:r,value:o,frame:a=[null,null]}=this.transform,s={type:"impute",field:e,key:t,...n?{keyvals:(l=n,void 0!==l?.stop?this.processSequence(n):n)}:{},method:"value",...r?{groupby:r}:{},value:i&&"value"!==i?null:o};var l;if(i&&"value"!==i){return[s,{type:"window",as:[`imputed_${e}_value`],ops:[i],fields:[e],frame:a,ignorePeers:!1,...r?{groupby:r}:{}},{type:"formula",expr:`datum.${e} === null ? datum.imputed_${e}_value : datum.${e}`,as:e}]}return[s]}}class jm extends vc{clone(){return new jm(null,l(this.transform))}constructor(e,t){super(e),this.transform=t,this.transform=l(t);const n=this.transform.as??[void 0,void 0];this.transform.as=[n[0]??t.on,n[1]??t.loess]}dependentFields(){return new Set([this.transform.loess,this.transform.on,...this.transform.groupby??[]])}producedFields(){return new Set(this.transform.as)}hash(){return`LoessTransform ${d(this.transform)}`}assemble(){const{loess:e,on:t,...n}=this.transform;return{type:"loess",x:t,y:e,...n}}}class Tm extends vc{clone(){return new Tm(null,l(this.transform),this.secondary)}constructor(e,t,n){super(e),this.transform=t,this.secondary=n}static make(e,t,n,i){const r=t.component.data.sources,{from:o}=n;let a=null;if(function(e){return"data"in e}(o)){let e=Qm(o.data,r);e||(e=new md(o.data),r.push(e));const n=t.getName(`lookup_${i}`);a=new bc(e,n,fc.Lookup,t.component.data.outputNodeRefCounts),t.component.data.outputNodes[n]=a}else if(function(e){return"param"in e}(o)){const e=o.param;let 
i;n={as:e,...n};try{i=t.getSelectionComponent(_(e),e)}catch(t){throw new Error(function(e){return`Lookups can only be performed on selection parameters. "${e}" is a variable parameter.`}(e))}if(a=i.materialized,!a)throw new Error(function(e){return`Cannot define and lookup the "${e}" selection in the same view. Try moving the lookup into a second, layered view?`}(e))}return new Tm(e,n,a.getSource())}dependentFields(){return new Set([this.transform.lookup])}producedFields(){return new Set(this.transform.as?t.array(this.transform.as):this.transform.from.fields)}hash(){return`Lookup ${d({transform:this.transform,secondary:this.secondary})}`}assemble(){let e;if(this.transform.from.fields)e={values:this.transform.from.fields,...this.transform.as?{as:t.array(this.transform.as)}:{}};else{let n=this.transform.as;t.isString(n)||($i('If "from.fields" is not specified, "as" has to be a string that specifies the key to be used for the data from the secondary source.'),n="_lookup"),e={as:[n]}}return{type:"lookup",from:this.secondary,key:this.transform.from.key,fields:[this.transform.lookup],...e,...this.transform.default?{default:this.transform.default}:{}}}}class Em extends vc{clone(){return new Em(null,l(this.transform))}constructor(e,t){super(e),this.transform=t,this.transform=l(t);const n=this.transform.as??[void 0,void 0];this.transform.as=[n[0]??"prob",n[1]??"value"]}dependentFields(){return new Set([this.transform.quantile,...this.transform.groupby??[]])}producedFields(){return new Set(this.transform.as)}hash(){return`QuantileTransform ${d(this.transform)}`}assemble(){const{quantile:e,...t}=this.transform;return{type:"quantile",field:e,...t}}}class Mm extends vc{clone(){return new Mm(null,l(this.transform))}constructor(e,t){super(e),this.transform=t,this.transform=l(t);const n=this.transform.as??[void 0,void 0];this.transform.as=[n[0]??t.on,n[1]??t.regression]}dependentFields(){return new 
Set([this.transform.regression,this.transform.on,...this.transform.groupby??[]])}producedFields(){return new Set(this.transform.as)}hash(){return`RegressionTransform ${d(this.transform)}`}assemble(){const{regression:e,on:t,...n}=this.transform;return{type:"regression",x:t,y:e,...n}}}class Lm extends vc{clone(){return new Lm(null,l(this.transform))}constructor(e,t){super(e),this.transform=t}addDimensions(e){this.transform.groupby=b((this.transform.groupby??[]).concat(e),(e=>e))}producedFields(){}dependentFields(){return new Set([this.transform.pivot,this.transform.value,...this.transform.groupby??[]])}hash(){return`PivotTransform ${d(this.transform)}`}assemble(){const{pivot:e,value:t,groupby:n,limit:i,op:r}=this.transform;return{type:"pivot",field:e,value:t,...void 0!==i?{limit:i}:{},...void 0!==r?{op:r}:{},...void 0!==n?{groupby:n}:{}}}}class qm extends vc{clone(){return new qm(null,l(this.transform))}constructor(e,t){super(e),this.transform=t}dependentFields(){return new Set}producedFields(){return new Set}hash(){return`SampleTransform ${d(this.transform)}`}assemble(){return{type:"sample",size:this.transform.sample}}}function Um(e){let t=0;return function n(i,r){if(i instanceof md&&!i.isGenerator&&!rc(i.data)){e.push(r);r={name:null,source:r.name,transform:[]}}if(i instanceof cd&&(i.parent instanceof md&&!r.source?(r.format={...r.format??{},parse:i.assembleFormatParse()},r.transform.push(...i.assembleTransforms(!0))):r.transform.push(...i.assembleTransforms())),i instanceof od)return r.name||(r.name="data_"+t++),!r.source||r.transform.length>0?(e.push(r),i.data=r.name):i.data=r.source,void e.push(...i.assemble());if((i instanceof fd||i instanceof dd||i instanceof Om||i instanceof Bu||i instanceof of||i instanceof Pm||i instanceof rd||i instanceof Tm||i instanceof Pd||i instanceof Nd||i instanceof Nm||i instanceof _m||i instanceof Fm||i instanceof jm||i instanceof Em||i instanceof Mm||i instanceof ud||i instanceof qm||i instanceof Lm||i instanceof 
zm)&&r.transform.push(i.assemble()),(i instanceof nd||i instanceof wc||i instanceof Am||i instanceof Cd||i instanceof Cm)&&r.transform.push(...i.assemble()),i instanceof bc)if(r.source&&0===r.transform.length)i.setSource(r.source);else if(i.parent instanceof bc)i.setSource(r.name);else if(r.name||(r.name="data_"+t++),i.setSource(r.name),1===i.numChildren()){e.push(r);r={name:null,source:r.name,transform:[]}}switch(i.numChildren()){case 0:i instanceof bc&&(!r.source||r.transform.length>0)&&e.push(r);break;case 1:n(i.children[0],r);break;default:{r.name||(r.name="data_"+t++);let o=r.name;!r.source||r.transform.length>0?e.push(r):o=r.source;for(const e of i.children){n(e,{name:null,source:o,transform:[]})}break}}}}function Rm(e){return"top"===e||"left"===e||yn(e)?"header":"footer"}function Wm(e,n){const{facet:i,config:r,child:o,component:a}=e;if(e.channelHasField(n)){const s=i[n],l=lf("title",null,r,n);let c=ua(s,r,{allowDisabling:!0,includeDefault:void 0===l||!!l});o.component.layoutHeaders[n].title&&(c=t.isArray(c)?c.join(", "):c,c+=` / ${o.component.layoutHeaders[n].title}`,o.component.layoutHeaders[n].title=null);const u=lf("labelOrient",s.header,r,n),f=null!==s.header&&U(s.header?.labels,r.header.labels,!0),d=p(["bottom","right"],u)?"footer":"header";a.layoutHeaders[n]={title:null!==s.header?c:null,facetFieldDef:s,[d]:"facet"===n?[]:[Bm(e,n,f)]}}}function Bm(e,t,n){const i="row"===t?"height":"width";return{labels:n,sizeSignal:e.child.component.layoutSize.get(i)?e.child.getSizeSignalRef(i):void 0,axes:[]}}function Im(e,t){const{child:n}=e;if(n.component.axes[t]){const{layoutHeaders:i,resolve:r}=e.component;if(r.axis[t]=_f(r,t),"shared"===r.axis[t]){const r="x"===t?"column":"row",o=i[r];for(const i of n.component.axes[t]){const t=Rm(i.get("orient"));o[t]??=[Bm(e,r,!1)];const n=Yu(i,"main",e.config,{header:!0});n&&o[t][0].axes.push(n),i.mainExtracted=!0}}}}function Hm(e){for(const t of e.children)t.parseLayoutSize()}function Vm(e,t){const 
n=Ff(t),i=Nt(n),r=e.component.resolve,o=e.component.layoutSize;let a;for(const t of e.children){const o=t.component.layoutSize.getWithExplicit(n),s=r.scale[i]??Of(i,e);if("independent"===s&&"step"===o.value){a=void 0;break}if(a){if("independent"===s&&a.value!==o.value){a=void 0;break}a=nc(a,o,n,"")}else a=o}if(a){for(const i of e.children)e.renameSignal(i.getName(n),e.getName(t)),i.component.layoutSize.set(n,"merged",!1);o.setWithExplicit(t,a)}else o.setWithExplicit(t,{explicit:!1,value:void 0})}function Gm(e,t){const n="width"===t?"x":"y",i=e.config,r=e.getScaleComponent(n);if(r){const e=r.get("type"),n=r.get("range");if(xr(e)){const e=Ts(i.view,t);return vn(n)||Ns(e)?"step":e}return As(i.view,t)}if(e.hasProjection||"arc"===e.mark)return As(i.view,t);{const e=Ts(i.view,t);return Ns(e)?e.step:e}}function Ym(e,t,n){return oa(t,{suffix:`by_${oa(e)}`,...n??{}})}class Xm extends Dm{constructor(e,t,n,i){super(e,"facet",t,n,i,e.resolve),qn(this,"facet",void 0),qn(this,"child",void 0),qn(this,"children",void 0),this.child=wp(e.spec,this,this.getName("child"),void 0,i),this.children=[this.child],this.facet=this.initFacet(e.facet)}initFacet(e){if(!Ao(e))return{facet:this.initFacetFieldDef(e,"facet")};const t=D(e),n={};for(const i of t){if(![Q,J].includes(i)){$i(ai(i,"facet"));break}const t=e[i];if(void 0===t.field){$i(oi(t,i));break}n[i]=this.initFacetFieldDef(t,i)}return n}initFacetFieldDef(e,t){const n=va(e,t);return n.header?n.header=pn(n.header):null===n.header&&(n.header=null),n}channelHasField(e){return!!this.facet[e]}fieldDef(e){return this.facet[e]}parseData(){this.component.data=Jm(this),this.child.parseData()}parseLayoutSize(){Hm(this)}parseSelections(){this.child.parseSelections(),this.component.selection=this.child.component.selection}parseMarkGroup(){this.child.parseMarkGroup()}parseAxesAndHeaders(){this.child.parseAxesAndHeaders(),function(e){for(const t of Re)Wm(e,t);Im(e,"x"),Im(e,"y")}(this)}assembleSelectionTopLevelSignals(e){return 
this.child.assembleSelectionTopLevelSignals(e)}assembleSignals(){return this.child.assembleSignals(),[]}assembleSelectionData(e){return this.child.assembleSelectionData(e)}getHeaderLayoutMixins(){const e={};for(const t of Re)for(const n of ff){const i=this.component.layoutHeaders[t],r=i[n],{facetFieldDef:o}=i;if(o){const n=lf("titleOrient",o.header,this.config,t);if(["right","bottom"].includes(n)){const i=sf(t,n);e.titleAnchor??={},e.titleAnchor[i]="end"}}if(r?.[0]){const r="row"===t?"height":"width",o="header"===n?"headerBand":"footerBand";"facet"===t||this.child.component.layoutSize.get(r)||(e[o]??={},e[o][t]=.5),i.title&&(e.offset??={},e.offset["row"===t?"rowTitle":"columnTitle"]=10)}}return e}assembleDefaultLayout(){const{column:e,row:t}=this.facet,n=e?this.columnDistinctSignal():t?1:void 0;let i="all";return(t||"independent"!==this.component.resolve.scale.x)&&(e||"independent"!==this.component.resolve.scale.y)||(i="none"),{...this.getHeaderLayoutMixins(),...n?{columns:n}:{},bounds:"full",align:i}}assembleLayoutSignals(){return this.child.assembleLayoutSignals()}columnDistinctSignal(){if(!(this.parent&&this.parent instanceof Xm)){return{signal:`length(data('${this.getName("column_domain")}'))`}}}assembleGroupStyle(){}assembleGroup(e){return this.parent&&this.parent instanceof Xm?{...this.channelHasField("column")?{encode:{update:{columns:{field:oa(this.facet.column,{prefix:"distinct"})}}}}:{},...super.assembleGroup(e)}:super.assembleGroup(e)}getCardinalityAggregateForChild(){const e=[],t=[],n=[];if(this.child instanceof Xm){if(this.child.channelHasField("column")){const i=oa(this.child.facet.column);e.push(i),t.push("distinct"),n.push(`distinct_${i}`)}}else for(const i of Ft){const r=this.child.component.scales[i];if(r&&!r.merged){const o=r.get("type"),a=r.get("range");if(xr(o)&&vn(a)){const 
r=Qd(Jd(this.child,i));r?(e.push(r),t.push("distinct"),n.push(`distinct_${r}`)):$i(Yn(i))}}}return{fields:e,ops:t,as:n}}assembleFacet(){const{name:e,data:n}=this.component.data.facetRoot,{row:i,column:r}=this.facet,{fields:o,ops:a,as:s}=this.getCardinalityAggregateForChild(),l=[];for(const e of Re){const n=this.facet[e];if(n){l.push(oa(n));const{bin:c,sort:u}=n;if(ln(c)&&l.push(oa(n,{binSuffix:"end"})),Co(u)){const{field:e,op:t=zo}=u,l=Ym(n,u);i&&r?(o.push(l),a.push("max"),s.push(l)):(o.push(e),a.push(t),s.push(l))}else if(t.isArray(u)){const t=af(n,e);o.push(t),a.push("max"),s.push(t)}}}const c=!!i&&!!r;return{name:e,data:n,groupby:l,...c||o.length>0?{aggregate:{...c?{cross:c}:{},...o.length?{fields:o,ops:a,as:s}:{}}}:{}}}facetSortFields(e){const{facet:n}=this,i=n[e];return i?Co(i.sort)?[Ym(i,i.sort,{expr:"datum"})]:t.isArray(i.sort)?[af(i,e,{expr:"datum"})]:[oa(i,{expr:"datum"})]:[]}facetSortOrder(e){const{facet:n}=this,i=n[e];if(i){const{sort:e}=i;return[(Co(e)?e.order:!t.isArray(e)&&e)||"ascending"]}return[]}assembleLabelTitle(){const{facet:e,config:t}=this;if(e.facet)return yf(e.facet,"facet",t);const n={row:["top","bottom"],column:["left","right"]};for(const i of uf)if(e[i]){const r=lf("labelOrient",e[i]?.header,t,i);if(n[i].includes(r))return yf(e[i],i,t)}}assembleMarks(){const{child:e}=this,t=function(e){const t=[],n=Um(t);for(const t of e.children)n(t,{source:e.name,name:null,transform:[]});return t}(this.component.data.facetRoot),n=e.assembleGroupEncodeEntry(!1),i=this.assembleLabelTitle()||e.assembleTitle(),r=e.assembleGroupStyle();return[{name:this.getName("cell"),type:"group",...i?{title:i}:{},...r?{style:r}:{},from:{facet:this.assembleFacet()},sort:{field:Re.map((e=>this.facetSortFields(e))).flat(),order:Re.map((e=>this.facetSortOrder(e))).flat()},...t.length>0?{data:t}:{},...n?{encode:{update:n}}:{},...e.assembleGroup(gc(this,[]))}]}getMapping(){return this.facet}}function Qm(e,t){for(const n of t){const 
t=n.data;if(e.name&&n.hasName()&&e.name!==n.dataName)continue;const i=e.format?.mesh,r=t.format?.feature;if(i&&r)continue;const o=e.format?.feature;if((o||r)&&o!==r)continue;const a=t.format?.mesh;if(!i&&!a||i===a)if(oc(e)&&oc(t)){if(Y(e.values,t.values))return n}else if(rc(e)&&rc(t)){if(e.url===t.url)return n}else if(ac(e)&&e.name===n.dataName)return n}return null}function Jm(e){let t=function(e,t){if(e.data||!e.parent){if(null===e.data){const e=new md({values:[]});return t.push(e),e}const n=Qm(e.data,t);if(n)return sc(e.data)||(n.data.format=y({},e.data.format,n.data.format)),!n.hasName()&&e.data.name&&(n.dataName=e.data.name),n;{const n=new md(e.data);return t.push(n),n}}return e.parent.component.data.facetRoot?e.parent.component.data.facetRoot:e.parent.component.data.main}(e,e.component.data.sources);const{outputNodes:n,outputNodeRefCounts:i}=e.component.data,r=e.data,o=!(r&&(sc(r)||rc(r)||oc(r)))&&e.parent?e.parent.component.data.ancestorParse.clone():new ic;sc(r)?(lc(r)?t=new dd(t,r.sequence):uc(r)&&(t=new fd(t,r.graticule)),o.parseNothing=!0):null===r?.format?.parse&&(o.parseNothing=!0),t=cd.makeExplicit(t,e,o)??t,t=new ud(t);const a=e.parent&&km(e.parent);(xm(e)||$m(e))&&a&&(t=nd.makeFromEncoding(t,e)??t),e.transforms.length>0&&(t=function(e,t,n){let i=0;for(const r of t.transforms){let o,a;if(Nl(r))a=e=new of(e,r),o="derived";else if(bl(r)){const i=sd(r);a=e=cd.makeWithAncestors(e,{},i,n)??e,e=new Bu(e,t,r.filter)}else if(Cl(r))a=e=nd.makeFromTransform(e,r,t),o="number";else if(Al(r))o="date",void 0===n.getWithExplicit(r.field).value&&(e=new cd(e,{[r.field]:o}),n.set(r.field,o,!1)),a=e=wc.makeFromTransform(e,r);else if(jl(r))a=e=rd.makeFromTransform(e,r),o="number",Lu(t)&&(e=new ud(e));else if(xl(r))a=e=Tm.make(e,t,r,i++),o="derived";else if(zl(r))a=e=new Pd(e,r),o="number";else if(Ol(r))a=e=new Nd(e,r),o="number";else if(Tl(r))a=e=Cd.makeFromTransform(e,r),o="derived";else if(El(r))a=e=new Nm(e,r),o="derived";else if(Ml(r))a=e=new zm(e,r),o="derived";else 
if(_l(r))a=e=new _m(e,r),o="derived";else if($l(r))a=e=new Lm(e,r),o="derived";else if(Fl(r))e=new qm(e,r);else if(Pl(r))a=e=Am.makeFromTransform(e,r),o="derived";else if(wl(r))a=e=new Fm(e,r),o="derived";else if(kl(r))a=e=new Em(e,r),o="derived";else if(Sl(r))a=e=new Mm(e,r),o="derived";else{if(!Dl(r)){$i(`Ignoring an invalid transform: ${X(r)}.`);continue}a=e=new jm(e,r),o="derived"}if(a&&void 0!==o)for(const e of a.producedFields()??[])n.set(e,o,!1)}return e}(t,e,o));const s=function(e){const t={};if(xm(e)&&e.component.selection)for(const n of D(e.component.selection)){const i=e.component.selection[n];for(const e of i.project.items)!e.channel&&q(e.field)>1&&(t[e.field]="flatten")}return t}(e),l=ld(e);t=cd.makeWithAncestors(t,{},{...s,...l},o)??t,xm(e)&&(t=Cm.parseAll(t,e),t=Pm.parseAll(t,e)),(xm(e)||$m(e))&&(a||(t=nd.makeFromEncoding(t,e)??t),t=wc.makeFromEncoding(t,e)??t,t=of.parseAllForSortIndex(t,e));const c=e.getDataName(fc.Raw),u=new bc(t,c,fc.Raw,i);if(n[c]=u,t=u,xm(e)){const n=rd.makeFromEncoding(t,e);n&&(t=n,Lu(e)&&(t=new ud(t))),t=Am.makeFromEncoding(t,e)??t,t=Cd.makeFromEncoding(t,e)??t}xm(e)&&(t=Om.make(t,e)??t);const f=e.getDataName(fc.Main),d=new bc(t,f,fc.Main,i);n[f]=d,t=d,xm(e)&&function(e,t){for(const[n,i]of z(e.component.selection??{})){const r=e.getName(`lookup_${n}`);e.component.data.outputNodes[r]=i.materialized=new bc(new Bu(t,e,{param:n}),r,fc.Lookup,e.component.data.outputNodeRefCounts)}}(e,d);let m=null;if($m(e)){const i=e.getName("facet");t=function(e,t){const{row:n,column:i}=t;if(n&&i){let t=null;for(const r of[n,i])if(Co(r.sort)){const{field:n,op:i=zo}=r.sort;e=t=new Nd(e,{joinaggregate:[{op:i,field:n,as:Ym(r,r.sort,{forAs:!0})}],groupby:[oa(r)]})}return t}return null}(t,e.facet)??t,m=new od(t,e,i,d.getSource()),n[i]=m}return{...e.component.data,outputNodes:n,outputNodeRefCounts:i,raw:u,main:d,facetRoot:m,ancestorParse:o}}class Km extends Sm{constructor(e,t,n,i){super(e,"concat",t,n,i,e.resolve),qn(this,"children",void 
0),"shared"!==e.resolve?.axis?.x&&"shared"!==e.resolve?.axis?.y||$i("Axes cannot be shared in concatenated or repeated views yet (https://github.com/vega/vega-lite/issues/2415)."),this.children=this.getChildren(e).map(((e,t)=>wp(e,this,this.getName(`concat_${t}`),void 0,i)))}parseData(){this.component.data=Jm(this);for(const e of this.children)e.parseData()}parseSelections(){this.component.selection={};for(const e of this.children){e.parseSelections();for(const t of D(e.component.selection))this.component.selection[t]=e.component.selection[t]}}parseMarkGroup(){for(const e of this.children)e.parseMarkGroup()}parseAxesAndHeaders(){for(const e of this.children)e.parseAxesAndHeaders()}getChildren(e){return zs(e)?e.vconcat:Os(e)?e.hconcat:e.concat}parseLayoutSize(){!function(e){Hm(e);const t=1===e.layout.columns?"width":"childWidth",n=void 0===e.layout.columns?"height":"childHeight";Vm(e,t),Vm(e,n)}(this)}parseAxisGroup(){return null}assembleSelectionTopLevelSignals(e){return this.children.reduce(((e,t)=>t.assembleSelectionTopLevelSignals(e)),e)}assembleSignals(){return this.children.forEach((e=>e.assembleSignals())),[]}assembleLayoutSignals(){const e=wf(this);for(const t of this.children)e.push(...t.assembleLayoutSignals());return e}assembleSelectionData(e){return this.children.reduce(((e,t)=>t.assembleSelectionData(e)),e)}assembleMarks(){return this.children.map((e=>{const t=e.assembleTitle(),n=e.assembleGroupStyle(),i=e.assembleGroupEncodeEntry(!1);return{type:"group",name:e.getName("group"),...t?{title:t}:{},...n?{style:n}:{},...i?{encode:{update:i}}:{},...e.assembleGroup()}}))}assembleGroupStyle(){}assembleDefaultLayout(){const e=this.layout.columns;return{...null!=e?{columns:e}:{},bounds:"full",align:"each"}}}const Zm={disable:1,gridScale:1,scale:1,..._a,labelExpr:1,encode:1},ep=D(Zm);class tp extends Jl{constructor(){let e=arguments.length>0&&void 0!==arguments[0]?arguments[0]:{},t=arguments.length>1&&void 
0!==arguments[1]?arguments[1]:{},n=arguments.length>2&&void 0!==arguments[2]&&arguments[2];super(),this.explicit=e,this.implicit=t,this.mainExtracted=n}clone(){return new tp(l(this.explicit),l(this.implicit),this.mainExtracted)}hasAxisPart(e){return"axis"===e||("grid"===e||"title"===e?!!this.get(e):!(!1===(t=this.get(e))||null===t));var t}hasOrientSignalRef(){return yn(this.explicit.orient)}}const np={bottom:"top",top:"bottom",left:"right",right:"left"};function ip(e,t){if(!e)return t.map((e=>e.clone()));{if(e.length!==t.length)return;const n=e.length;for(let i=0;i{switch(n){case"title":return Ln(e,t);case"gridScale":return{explicit:e.explicit,value:U(e.value,t.value)}}return tc(e,t,n,"axis")}));e.setWithExplicit(n,i)}return e}function op(e,t,n,i,r){if("disable"===t)return void 0!==n;switch(n=n||{},t){case"titleAngle":case"labelAngle":return e===(yn(n.labelAngle)?n.labelAngle:H(n.labelAngle));case"values":return!!n.values;case"encode":return!!n.encoding||!!n.labelAngle;case"title":if(e===rf(i,r))return!0}return e===n[t]}const ap=new Set(["grid","translate","format","formatType","orient","labelExpr","tickCount","position","tickMinStep"]);function sp(e,t){let n=t.axis(e);const i=new tp,r=ga(t.encoding[e]),{mark:o,config:a}=t,s=n?.orient||a["x"===e?"axisX":"axisY"]?.orient||a.axis?.orient||function(e){return"x"===e?"bottom":"left"}(e),l=t.getScaleComponent(e).get("type"),c=function(e,t,n,i){const r="band"===t?["axisDiscrete","axisBand"]:"point"===t?["axisDiscrete","axisPoint"]:hr(t)?["axisQuantitative"]:"time"===t||"utc"===t?["axisTemporal"]:[],o="x"===e?"axisX":"axisY",a=yn(n)?"axisOrient":`axis${P(n)}`,s=[...r,...r.map((e=>o+e.substr(4)))],l=["axis",a,o];return{vlOnlyAxisConfig:Qu(s,i,e,n),vgAxisConfig:Qu(l,i,e,n),axisConfigStyle:Ju([...l,...s],i)}}(e,l,s,t.config),u=void 0!==n?!n:Ku("disable",a.style,n?.style,c).configValue;if(i.set("disable",u,void 0!==n),u)return i;n=n||{};const f=function(e,t,n,i,r){const o=t?.labelAngle;if(void 0!==o)return 
yn(o)?o:H(o);{const{configValue:o}=Ku("labelAngle",i,t?.style,r);return void 0!==o?H(o):n!==Z||!p([sr,or],e.type)||Ho(e)&&e.timeUnit?void 0:270}}(r,n,e,a.style,c),d=wo(n.formatType,r,l),m=$o(r,r.type,n.format,n.formatType,a,!0),g={fieldOrDatumDef:r,axis:n,channel:e,model:t,scaleType:l,orient:s,labelAngle:f,format:m,formatType:d,mark:o,config:a};for(const r of ep){const o=r in Zu?Zu[r](g):Ca(r)?n[r]:void 0,s=void 0!==o,l=op(o,r,n,t,e);if(s&&l)i.set(r,o,l);else{const{configValue:e,configFrom:t}=Ca(r)&&"values"!==r?Ku(r,a.style,n.style,c):{},u=void 0!==e;s&&!u?i.set(r,o,l):("vgAxisConfig"!==t||ap.has(r)&&u||Fa(e)||yn(e))&&i.set(r,e,!1)}}const h=n.encoding??{},y=za.reduce(((n,r)=>{if(!i.hasAxisPart(r))return n;const o=zf(h[r]??{},t),a="labels"===r?function(e,t,n){const{encoding:i,config:r}=e,o=ga(i[t])??ga(i[it(t)]),a=e.axis(t)||{},{format:s,formatType:l}=a;if(go(l))return{text:xo({fieldOrDatumDef:o,field:"datum.value",format:s,formatType:l,config:r}),...n};if(void 0===s&&void 0===l&&r.customFormatTypes){if("quantitative"===Vo(o)){if(ta(o)&&"normalize"===o.stack&&r.normalizedNumberFormatType)return{text:xo({fieldOrDatumDef:o,field:"datum.value",format:r.normalizedNumberFormat,formatType:r.normalizedNumberFormatType,config:r}),...n};if(r.numberFormatType)return{text:xo({fieldOrDatumDef:o,field:"datum.value",format:r.numberFormat,formatType:r.numberFormatType,config:r}),...n}}if("temporal"===Vo(o)&&r.timeFormatType&&Ho(o)&&!o.timeUnit)return{text:xo({fieldOrDatumDef:o,field:"datum.value",format:r.timeFormat,formatType:r.timeFormatType,config:r}),...n}}return n}(t,e,o):o;return void 0===a||S(a)||(n[r]={update:a}),n}),{});return S(y)||i.set("encode",y,!!n.encoding||void 0!==n.labelAngle),i}function lp(e,t){const{config:n}=e;return{...du(e,{align:"ignore",baseline:"ignore",color:"include",size:"include",orient:"ignore",theta:"ignore"}),...Zc("x",e,{defaultPos:"mid"}),...Zc("y",e,{defaultPos:"mid"}),...Xc("size",e),...Xc("angle",e),...cp(e,n,t)}}function cp(e,t,n){return 
n?{shape:{value:n}}:Xc("shape",e)}const up={vgMark:"rule",encodeEntry:e=>{const{markDef:t}=e,n=t.orient;return e.encoding.x||e.encoding.y||e.encoding.latitude||e.encoding.longitude?{...du(e,{align:"ignore",baseline:"ignore",color:"include",orient:"ignore",size:"ignore",theta:"ignore"}),...ru("x",e,{defaultPos:"horizontal"===n?"zeroOrMax":"mid",defaultPos2:"zeroOrMin",range:"vertical"!==n}),...ru("y",e,{defaultPos:"vertical"===n?"zeroOrMax":"mid",defaultPos2:"zeroOrMin",range:"horizontal"!==n}),...Xc("size",e,{vgChannel:"strokeWidth"})}:{}}};function fp(e,t,n){if(void 0===Cn("align",e,n))return"center"}function dp(e,t,n){if(void 0===Cn("baseline",e,n))return"middle"}const mp={vgMark:"rect",encodeEntry:e=>{const{config:t,markDef:n}=e,i=n.orient,r="horizontal"===i?"width":"height",o="horizontal"===i?"height":"width";return{...du(e,{align:"ignore",baseline:"ignore",color:"include",orient:"ignore",size:"ignore",theta:"ignore"}),...Zc("x",e,{defaultPos:"mid",vgChannel:"xc"}),...Zc("y",e,{defaultPos:"mid",vgChannel:"yc"}),...Xc("size",e,{defaultValue:pp(e),vgChannel:r}),[o]:Fn(Cn("thickness",n,t))}}};function pp(e){const{config:n,markDef:i}=e,{orient:r}=i,o="horizontal"===r?"width":"height",a=e.getScaleComponent("horizontal"===r?"x":"y"),s=Cn("size",i,n,{vgChannel:o})??n.tick.bandSize;if(void 0!==s)return s;{const e=a?a.get("range"):void 0;if(e&&vn(e)&&t.isNumber(e.step))return 3*e.step/4;return 3*js(n.view,o)/4}}const 
gp={arc:{vgMark:"arc",encodeEntry:e=>({...du(e,{align:"ignore",baseline:"ignore",color:"include",size:"ignore",orient:"ignore",theta:"ignore"}),...Zc("x",e,{defaultPos:"mid"}),...Zc("y",e,{defaultPos:"mid"}),...su(e,"radius"),...su(e,"theta")})},area:{vgMark:"area",encodeEntry:e=>({...du(e,{align:"ignore",baseline:"ignore",color:"include",orient:"include",size:"ignore",theta:"ignore"}),...ru("x",e,{defaultPos:"zeroOrMin",defaultPos2:"zeroOrMin",range:"horizontal"===e.markDef.orient}),...ru("y",e,{defaultPos:"zeroOrMin",defaultPos2:"zeroOrMin",range:"vertical"===e.markDef.orient}),...gu(e)})},bar:{vgMark:"rect",encodeEntry:e=>({...du(e,{align:"ignore",baseline:"ignore",color:"include",orient:"ignore",size:"ignore",theta:"ignore"}),...su(e,"x"),...su(e,"y")})},circle:{vgMark:"symbol",encodeEntry:e=>lp(e,"circle")},geoshape:{vgMark:"shape",encodeEntry:e=>({...du(e,{align:"ignore",baseline:"ignore",color:"include",size:"ignore",orient:"ignore",theta:"ignore"})}),postEncodingTransform:e=>{const{encoding:t}=e,n=t.shape;return[{type:"geoshape",projection:e.projectionName(),...n&&Ho(n)&&n.type===lr?{field:oa(n,{expr:"datum"})}:{}}]}},image:{vgMark:"image",encodeEntry:e=>({...du(e,{align:"ignore",baseline:"ignore",color:"ignore",orient:"ignore",size:"ignore",theta:"ignore"}),...su(e,"x"),...su(e,"y"),...Rc(e,"url")})},line:{vgMark:"line",encodeEntry:e=>({...du(e,{align:"ignore",baseline:"ignore",color:"include",size:"ignore",orient:"ignore",theta:"ignore"}),...Zc("x",e,{defaultPos:"mid"}),...Zc("y",e,{defaultPos:"mid"}),...Xc("size",e,{vgChannel:"strokeWidth"}),...gu(e)})},point:{vgMark:"symbol",encodeEntry:e=>lp(e)},rect:{vgMark:"rect",encodeEntry:e=>({...du(e,{align:"ignore",baseline:"ignore",color:"include",orient:"ignore",size:"ignore",theta:"ignore"}),...su(e,"x"),...su(e,"y")})},rule:up,square:{vgMark:"symbol",encodeEntry:e=>lp(e,"square")},text:{vgMark:"text",encodeEntry:e=>{const{config:t,encoding:n}=e;return{...du(e,{align:"include",baseline:"include",color:"include
",size:"ignore",orient:"ignore",theta:"include"}),...Zc("x",e,{defaultPos:"mid"}),...Zc("y",e,{defaultPos:"mid"}),...Rc(e),...Xc("size",e,{vgChannel:"fontSize"}),...Xc("angle",e),...hu("align",fp(e.markDef,n,t)),...hu("baseline",dp(e.markDef,n,t)),...Zc("radius",e,{defaultPos:null}),...Zc("theta",e,{defaultPos:null})}}},tick:mp,trail:{vgMark:"trail",encodeEntry:e=>({...du(e,{align:"ignore",baseline:"ignore",color:"include",size:"include",orient:"ignore",theta:"ignore"}),...Zc("x",e,{defaultPos:"mid"}),...Zc("y",e,{defaultPos:"mid"}),...Xc("size",e),...gu(e)})}};function hp(e){if(p([Ur,Mr,Vr],e.mark)){const t=Ba(e.mark,e.encoding);if(t.length>0)return function(e,t){return[{name:e.getName("pathgroup"),type:"group",from:{facet:{name:yp+e.requestDataName(fc.Main),data:e.requestDataName(fc.Main),groupby:t}},encode:{update:{width:{field:{group:"width"}},height:{field:{group:"height"}}}},marks:bp(e,{fromPrefix:yp})}]}(e,t)}else if(e.mark===Lr){const t=wn.some((t=>Cn(t,e.markDef,e.config)));if(e.stack&&!e.fieldDef("size")&&t)return function(e){const[t]=bp(e,{fromPrefix:vp}),n=e.scaleName(e.stack.fieldChannel),i=function(){let t=arguments.length>0&&void 0!==arguments[0]?arguments[0]:{};return e.vgField(e.stack.fieldChannel,t)},r=(e,t)=>`${e}(${[i({prefix:"min",suffix:"start",expr:t}),i({prefix:"max",suffix:"start",expr:t}),i({prefix:"min",suffix:"end",expr:t}),i({prefix:"max",suffix:"end",expr:t})].map((e=>`scale('${n}',${e})`)).join(",")})`;let 
o,a;"x"===e.stack.fieldChannel?(o={...u(t.encode.update,["y","yc","y2","height",...wn]),x:{signal:r("min","datum")},x2:{signal:r("max","datum")},clip:{value:!0}},a={x:{field:{group:"x"},mult:-1},height:{field:{group:"height"}}},t.encode.update={...f(t.encode.update,["y","yc","y2"]),height:{field:{group:"height"}}}):(o={...u(t.encode.update,["x","xc","x2","width"]),y:{signal:r("min","datum")},y2:{signal:r("max","datum")},clip:{value:!0}},a={y:{field:{group:"y"},mult:-1},width:{field:{group:"width"}}},t.encode.update={...f(t.encode.update,["x","xc","x2"]),width:{field:{group:"width"}}});for(const n of wn){const i=Pn(n,e.markDef,e.config);t.encode.update[n]?(o[n]=t.encode.update[n],delete t.encode.update[n]):i&&(o[n]=Fn(i)),i&&(t.encode.update[n]={value:0})}const s=[];if(e.stack.groupbyChannels?.length>0)for(const t of e.stack.groupbyChannels){const n=e.fieldDef(t),i=oa(n);i&&s.push(i),(n?.bin||n?.timeUnit)&&s.push(oa(n,{binSuffix:"end"}))}o=["stroke","strokeWidth","strokeJoin","strokeCap","strokeDash","strokeDashOffset","strokeMiterLimit","strokeOpacity"].reduce(((n,i)=>{if(t.encode.update[i])return{...n,[i]:t.encode.update[i]};{const t=Pn(i,e.markDef,e.config);return void 0!==t?{...n,[i]:Fn(t)}:n}}),o),o.stroke&&(o.strokeForeground={value:!0},o.strokeOffset={value:0});return[{type:"group",from:{facet:{data:e.requestDataName(fc.Main),name:vp+e.requestDataName(fc.Main),groupby:s,aggregate:{fields:[i({suffix:"start"}),i({suffix:"start"}),i({suffix:"end"}),i({suffix:"end"})],ops:["min","max","min","max"]}}},encode:{update:o},marks:[{type:"group",encode:{update:a},marks:[t]}]}]}(e)}return bp(e)}const yp="faceted_path_";const vp="stack_group_";function bp(e){let n=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{fromPrefix:""};const{mark:i,markDef:r,encoding:o,config:a}=e,s=U(r.clip,function(e){const t=e.getScaleComponent("x"),n=e.getScaleComponent("y");return!(!t?.get("selectionExtent")&&!n?.get("selectionExtent"))||void 0}(e),function(e){const 
t=e.component.projection;return!(!t||t.isFit)||void 0}(e)),l=Nn(r),c=o.key,u=function(e){const{encoding:n,stack:i,mark:r,markDef:o,config:a}=e,s=n.order;if(!(!t.isArray(s)&&Zo(s)&&m(s.value)||!s&&m(Cn("order",o,a)))){if((t.isArray(s)||Ho(s))&&!i)return Tn(s,{expr:"datum"});if(Qr(r)){const i="horizontal"===o.orient?"y":"x",r=n[i];if(Ho(r)){const n=r.sort;return t.isArray(n)?{field:oa(r,{prefix:i,suffix:"sort_index",expr:"datum"})}:Co(n)?{field:oa({aggregate:La(e.encoding)?n.op:void 0,field:n.field},{expr:"datum"})}:No(n)?{field:oa(e.fieldDef(n.encoding),{expr:"datum"}),order:n.order}:null===n?void 0:{field:oa(r,{binSuffix:e.stack?.impute?"mid":void 0,expr:"datum"})}}}}}(e),f=function(e){if(!e.component.selection)return null;const t=D(e.component.selection).length;let n=t,i=e.parent;for(;i&&0===n;)n=D(i.component.selection).length,i=i.parent;return n?{interactive:t>0||"geoshape"===e.mark||!!e.encoding.tooltip}:null}(e),d=Cn("aria",r,a),p=gp[i].postEncodingTransform?gp[i].postEncodingTransform(e):null;return[{name:e.getName("marks"),type:gp[i].vgMark,...s?{clip:!0}:{},...l?{style:l}:{},...c?{key:c.field}:{},...u?{sort:u}:{},...f||{},...!1===d?{aria:d}:{},from:{data:n.fromPrefix+e.requestDataName(fc.Main)},encode:{update:gp[i].encodeEntry(e)},...p?{transform:p}:{}}]}class xp extends Dm{constructor(e,n,i){let r=arguments.length>3&&void 0!==arguments[3]?arguments[3]:{},o=arguments.length>4?arguments[4]:void 0;super(e,"unit",n,i,o,void 0,Cs(e)?e.view:void 0),qn(this,"markDef",void 0),qn(this,"encoding",void 0),qn(this,"specifiedScales",{}),qn(this,"stack",void 0),qn(this,"specifiedAxes",{}),qn(this,"specifiedLegends",{}),qn(this,"specifiedProjection",{}),qn(this,"selection",[]),qn(this,"children",[]);const a=Zr(e.mark)?{...e.mark}:{type:e.mark},s=a.type;void 0===a.filled&&(a.filled=function(e,t,n){let{graticule:i}=n;if(i)return!1;const r=Pn("filled",e,t),o=e.type;return U(r,o!==Rr&&o!==Ur&&o!==Br)}(a,o,{graticule:e.data&&uc(e.data)}));const 
l=this.encoding=function(e,n,i,r){const o={};for(const t of D(e))Ke(t)||$i(`${a=t}-encoding is dropped as ${a} is not a valid encoding channel.`);var a;for(let a of lt){if(!e[a])continue;const s=e[a];if(Pt(a)){const e=st(a),t=o[e];if(Ho(t)&&nr(t.type)&&Ho(s)&&!t.timeUnit){$i(ni(e));continue}}if("angle"!==a||"arc"!==n||e.theta||($i("Arc marks uses theta channel rather than angle, replacing angle with theta."),a=se),Ua(e,a,n)){if(a===ye&&"line"===n){const t=pa(e[a]);if(t?.aggregate){$i("Line marks cannot encode size with a non-groupby field. You may want to use trail marks instead.");continue}}if(a===me&&(i?"fill"in e:"stroke"in e))$i(ri("encoding",{fill:"fill"in e,stroke:"stroke"in e}));else if(a===Fe||a===De&&!t.isArray(s)&&!Zo(s)||a===Oe&&t.isArray(s)){if(s){if(a===De){const t=e[a];if(Ro(t)){o[a]=t;continue}}o[a]=t.array(s).reduce(((e,t)=>(Ho(t)?e.push(va(t,a)):$i(oi(t,a)),e)),[])}}else{if(a===Oe&&null===s)o[a]=null;else if(!(Ho(s)||Go(s)||Zo(s)||Wo(s)||yn(s))){$i(oi(s,a));continue}o[a]=ha(s,a,r)}}else $i(ai(a,n))}return o}(e.encoding||{},s,a.filled,o);this.markDef=il(a,l,o),this.size=function(e){let{encoding:t,size:n}=e;for(const e of Ft){const i=rt(e);Ns(n[i])&&Yo(t[e])&&(delete n[i],$i(pi(i)))}return n}({encoding:l,size:Cs(e)?{...r,...e.width?{width:e.width}:{},...e.height?{height:e.height}:{}}:r}),this.stack=nl(this.markDef,l),this.specifiedScales=this.initScales(s,l),this.specifiedAxes=this.initAxes(l),this.specifiedLegends=this.initLegends(l),this.specifiedProjection=e.projection,this.selection=(e.params??[]).filter((e=>Ss(e)))}get hasProjection(){const{encoding:e}=this,t=this.mark===Xr,n=e&&Me.some((t=>Jo(e[t])));return t||n}scaleDomain(e){const t=this.specifiedScales[e];return t?t.domain:void 0}axis(e){return this.specifiedAxes[e]}legend(e){return this.specifiedLegends[e]}initScales(e,t){return It.reduce(((e,n)=>{const i=ga(t[n]);return i&&(e[n]=this.initScale(i.scale??{})),e}),{})}initScale(e){const{domain:n,range:i}=e,r=pn(e);return 
t.isArray(n)&&(r.domain=n.map(Sn)),t.isArray(i)&&(r.range=i.map(Sn)),r}initAxes(e){return Ft.reduce(((t,n)=>{const i=e[n];if(Jo(i)||n===Z&&Jo(e.x2)||n===ee&&Jo(e.y2)){const e=Jo(i)?i.axis:void 0;t[n]=e?this.initAxis({...e}):e}return t}),{})}initAxis(e){const t=D(e),n={};for(const i of t){const t=e[i];n[i]=Fa(t)?kn(t):Sn(t)}return n}initLegends(e){return Wt.reduce(((t,n)=>{const i=ga(e[n]);if(i&&function(e){switch(e){case me:case pe:case ge:case ye:case he:case be:case we:case ke:return!0;case xe:case $e:case ve:return!1}}(n)){const e=i.legend;t[n]=e?pn(e):e}return t}),{})}parseData(){this.component.data=Jm(this)}parseLayoutSize(){!function(e){const{size:t,component:n}=e;for(const i of Ft){const r=rt(i);if(t[r]){const e=t[r];n.layoutSize.set(r,Ns(e)?"step":e,!0)}else{const t=Gm(e,r);n.layoutSize.set(r,t,!1)}}}(this)}parseSelections(){this.component.selection=function(e,n){const i={},r=e.config.selection;if(!n||!n.length)return i;for(const o of n){const n=_(o.name),a=o.select,s=t.isString(a)?a:a.type,c=t.isObject(a)?l(a):{type:s},u=r[s];for(const e in u)"fields"!==e&&"encodings"!==e&&("mark"===e&&(c[e]={...u[e],...c[e]}),void 0!==c[e]&&!0!==c[e]||(c[e]=l(u[e]??c[e])));const f=i[n]={...c,name:n,type:s,init:o.value,bind:o.bind,events:t.isString(c.on)?t.parseSelector(c.on,"scope"):t.array(l(c.on))},d=l(o);for(const t of Eu)t.defined(f)&&t.parse&&t.parse(e,f,d)}return i}(this,this.selection)}parseMarkGroup(){this.component.mark=hp(this)}parseAxesAndHeaders(){var e;this.component.axes=(e=this,Ft.reduce(((t,n)=>(e.component.scales[n]&&(t[n]=[sp(n,e)]),t)),{}))}assembleSelectionTopLevelSignals(e){return function(e,n){let i=!1;for(const r of F(e.component.selection??{})){const o=r.name,a=t.stringValue(o+Pu);if(0===n.filter((e=>e.name===o)).length){const e="global"===r.resolve?"union":r.resolve,i="point"===r.type?", true, true)":")";n.push({name:r.name,update:`${Tu}(${a}, ${t.stringValue(e)}${i}`})}i=!0;for(const t of 
Eu)t.defined(r)&&t.topLevelSignals&&(n=t.topLevelSignals(e,r,n))}i&&0===n.filter((e=>"unit"===e.name)).length&&n.unshift({name:"unit",value:{},on:[{events:"pointermove",update:"isTuple(group()) ? group() : unit"}]});return yc(n)}(this,e)}assembleSignals(){return[...Xu(this),...pc(this,[])]}assembleSelectionData(e){return function(e,t){const n=[...t],i=Mu(e,{escape:!1});for(const t of F(e.component.selection??{})){const e={name:t.name+Pu};if(t.project.hasSelectionId&&(e.transform=[{type:"collect",sort:{field:xs}}]),t.init){const n=t.project.items.map(dc);e.values=t.project.hasSelectionId?t.init.map((e=>({unit:i,[xs]:mc(e,!1)[0]}))):t.init.map((e=>({unit:i,fields:n,values:mc(e,!1)})))}n.filter((e=>e.name===t.name+Pu)).length||n.push(e)}return n}(this,e)}assembleLayout(){return null}assembleLayoutSignals(){return wf(this)}assembleMarks(){let e=this.component.mark??[];return this.parent&&km(this.parent)||(e=hc(this,e)),e.map(this.correctDataNames)}assembleGroupStyle(){const{style:e}=this.view||{};return void 0!==e?e:this.encoding.x||this.encoding.y?"cell":"view"}getMapping(){return this.encoding}get mark(){return this.markDef.type}channelHasField(e){return Ta(this.encoding,e)}fieldDef(e){return pa(this.encoding[e])}typedFieldDef(e){const t=this.fieldDef(e);return Ko(t)?t:null}}class $p extends Sm{constructor(e,t,n,i,r){super(e,"layer",t,n,r,e.resolve,e.view),qn(this,"children",void 0);const o={...i,...e.width?{width:e.width}:{},...e.height?{height:e.height}:{}};this.children=e.layer.map(((e,t)=>{if(Xs(e))return new $p(e,this,this.getName(`layer_${t}`),o,r);if(Aa(e))return new xp(e,this,this.getName(`layer_${t}`),o,r);throw new Error(Bn(e))}))}parseData(){this.component.data=Jm(this);for(const e of this.children)e.parseData()}parseLayoutSize(){var e;Hm(e=this),Vm(e,"width"),Vm(e,"height")}parseSelections(){this.component.selection={};for(const e of this.children){e.parseSelections();for(const t of 
D(e.component.selection))this.component.selection[t]=e.component.selection[t]}}parseMarkGroup(){for(const e of this.children)e.parseMarkGroup()}parseAxesAndHeaders(){!function(e){const{axes:t,resolve:n}=e.component,i={top:0,bottom:0,right:0,left:0};for(const i of e.children){i.parseAxesAndHeaders();for(const r of D(i.component.axes))n.axis[r]=_f(e.component.resolve,r),"shared"===n.axis[r]&&(t[r]=ip(t[r],i.component.axes[r]),t[r]||(n.axis[r]="independent",delete t[r]))}for(const r of Ft){for(const o of e.children)if(o.component.axes[r]){if("independent"===n.axis[r]){t[r]=(t[r]??[]).concat(o.component.axes[r]);for(const e of o.component.axes[r]){const{value:t,explicit:n}=e.getWithExplicit("orient");if(!yn(t)){if(i[t]>0&&!n){const n=np[t];i[t]>i[n]&&e.set("orient",n,!1)}i[t]++}}}delete o.component.axes[r]}if("independent"===n.axis[r]&&t[r]&&t[r].length>1)for(const[e,n]of(t[r]||[]).entries())e>0&&n.get("grid")&&!n.explicit.grid&&(n.implicit.grid=!1)}}(this)}assembleSelectionTopLevelSignals(e){return this.children.reduce(((e,t)=>t.assembleSelectionTopLevelSignals(e)),e)}assembleSignals(){return this.children.reduce(((e,t)=>e.concat(t.assembleSignals())),Xu(this))}assembleLayoutSignals(){return this.children.reduce(((e,t)=>e.concat(t.assembleLayoutSignals())),wf(this))}assembleSelectionData(e){return this.children.reduce(((e,t)=>t.assembleSelectionData(e)),e)}assembleGroupStyle(){const e=new Set;for(const n of this.children)for(const i of t.array(n.assembleGroupStyle()))e.add(i);const n=Array.from(e);return n.length>1?n:1===n.length?n[0]:void 0}assembleTitle(){let e=super.assembleTitle();if(e)return e;for(const t of this.children)if(e=t.assembleTitle(),e)return e}assembleLayout(){return null}assembleMarks(){return function(e,t){for(const n of e.children)xm(n)&&(t=hc(n,t));return t}(this,this.children.flatMap((e=>e.assembleMarks())))}assembleLegends(){return this.children.reduce(((e,t)=>e.concat(t.assembleLegends())),Vf(this))}}function wp(e,t,n,i,r){if(To(e))return new 
Xm(e,t,n,r);if(Xs(e))return new $p(e,t,n,i,r);if(Aa(e))return new xp(e,t,n,i,r);if(function(e){return zs(e)||Os(e)||Fs(e)}(e))return new Km(e,t,n,r);throw new Error(Bn(e))}const kp=n;e.accessPathDepth=q,e.accessPathWithDatum=A,e.compile=function(e){let n=arguments.length>1&&void 0!==arguments[1]?arguments[1]:{};var i;n.logger&&(i=n.logger,xi=i),n.fieldTitle&&ca(n.fieldTitle);try{const i=Bs(t.mergeConfig(n.config,e.config)),r=Il(e,i),o=wp(r,null,"",void 0,i);o.parse(),function(e,t){Md(e.sources);let n=0,i=0;for(let i=0;i2&&void 0!==arguments[2]?arguments[2]:{},i=arguments.length>3?arguments[3]:void 0;const r=e.config?Gs(e.config):void 0,o=[].concat(e.assembleSelectionData([]),function(e,t){const n=[],i=Um(n);let r=0;for(const t of e.sources){t.hasName()||(t.dataName="source_"+r++);const e=t.assemble();i(t,e)}for(const e of n)0===e.transform.length&&delete e.transform;let o=0;for(const[e,t]of n.entries())0!==(t.transform??[]).length||t.source||n.splice(o++,0,n.splice(e,1)[0]);for(const t of n)for(const n of t.transform??[])"lookup"===n.type&&(n.from=e.outputNodes[n.from].getSource());for(const e of n)e.name in t&&(e.values=t[e.name]);return n}(e.component.data,n)),a=e.assembleProjections(),s=e.assembleTitle(),l=e.assembleGroupStyle(),c=e.assembleGroupEncodeEntry(!0);let u=e.assembleLayoutSignals();u=u.filter((e=>"width"!==e.name&&"height"!==e.name||void 0===e.value||(t[e.name]=+e.value,!1)));const{params:f,...d}=t;return{$schema:"https://vega.github.io/schema/vega/v5.json",...e.description?{description:e.description}:{},...d,...s?{title:s}:{},...l?{style:l}:{},...c?{encode:{update:c}}:{},data:o,...a.length>0?{projections:a}:{},...e.assembleGroup([...u,...e.assembleSelectionTopLevelSignals([]),...Ds(f)]),...r?{config:r}:{},...i?{usermeta:i}:{}}}(o,function(e,n,i,r){const o=r.component.layoutSize.get("width"),a=r.component.layoutSize.get("height");void 
0===n?(n={type:"pad"},r.hasAxisOrientSignalRef()&&(n.resize=!0)):t.isString(n)&&(n={type:n});if(o&&a&&(s=n.type,"fit"===s||"fit-x"===s||"fit-y"===s))if("step"===o&&"step"===a)$i(Gn()),n.type="pad";else if("step"===o||"step"===a){const e="step"===o?"width":"height";$i(Gn(Nt(e)));const t="width"===e?"height":"width";n.type=function(e){return e?`fit-${Nt(e)}`:"fit"}(t)}var s;return{...1===D(n).length&&n.type?"pad"===n.type?{}:{autosize:n.type}:{autosize:n},...Ql(i,!1),...Ql(e,!0)}}(e,r.autosize,i,o),e.datasets,e.usermeta);return{spec:a,normalized:r}}finally{n.logger&&(xi=bi),n.fieldTitle&&ca(sa)}},e.contains=p,e.deepEqual=Y,e.deleteNestedProperty=C,e.duplicate=l,e.entries=z,e.every=h,e.fieldIntersection=k,e.flatAccessWithDatum=j,e.getFirstDefined=U,e.hasIntersection=$,e.hash=d,e.internalField=B,e.isBoolean=O,e.isEmpty=S,e.isEqual=function(e,t){const n=D(e),i=D(t);if(n.length!==i.length)return!1;for(const i of n)if(e[i]!==t[i])return!1;return!0},e.isInternalField=I,e.isNullOrFalse=m,e.isNumeric=V,e.keys=D,e.logicalExpr=N,e.mergeDeep=y,e.never=c,e.normalize=Il,e.normalizeAngle=H,e.omit=f,e.pick=u,e.prefixGenerator=w,e.removePathFromField=L,e.replaceAll=M,e.replacePathInField=E,e.resetIdCounter=function(){R=42},e.setEqual=x,e.some=g,e.stringify=X,e.titleCase=P,e.unique=b,e.uniqueId=W,e.vals=F,e.varName=_,e.version=kp})); //# sourceMappingURL=vega-lite.min.js.map ================================================ FILE: docs/_static/js/vega@5.js ================================================ !function(t,e){"object"==typeof exports&&"undefined"!=typeof module?e(exports):"function"==typeof define&&define.amd?define(["exports"],e):e((t="undefined"!=typeof globalThis?globalThis:t||self).vega={})}(this,(function(t){"use strict";function e(t,e,n){return t.fields=e||[],t.fname=n,t}function n(t){return null==t?null:t.fname}function r(t){return null==t?null:t.fields}function i(t){return 1===t.length?o(t[0]):a(t)}const o=t=>function(e){return e[t]},a=t=>{const e=t.length;return 
function(n){for(let r=0;rr&&c(),u=r=i+1):"]"===o&&(u||s("Access path missing open bracket: "+t),u>0&&c(),u=0,r=i+1):i>r?c():r=i+1}return u&&s("Access path missing closing bracket: "+t),a&&s("Access path missing closing quote: "+t),i>r&&(i++,c()),e}function l(t,n,r){const o=u(t);return t=1===o.length?o[0]:t,e((r&&r.get||i)(o),[t],n||t)}const c=l("id"),f=e((t=>t),[],"identity"),h=e((()=>0),[],"zero"),d=e((()=>1),[],"one"),p=e((()=>!0),[],"true"),g=e((()=>!1),[],"false");function m(t,e,n){const r=[e].concat([].slice.call(n));console[t].apply(console,r)}const y=0,v=1,_=2,x=3,b=4;function w(t,e){let n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:m,r=t||y;return{level(t){return arguments.length?(r=+t,this):r},error(){return r>=v&&n(e||"error","ERROR",arguments),this},warn(){return r>=_&&n(e||"warn","WARN",arguments),this},info(){return r>=x&&n(e||"log","INFO",arguments),this},debug(){return r>=b&&n(e||"log","DEBUG",arguments),this}}}var k=Array.isArray;function A(t){return t===Object(t)}const M=t=>"__proto__"!==t;function E(){for(var t=arguments.length,e=new Array(t),n=0;n{for(const n in e)if("signals"===n)t.signals=C(t.signals,e.signals);else{const r="legend"===n?{layout:1}:"style"===n||null;D(t,n,e[n],r)}return t}),{})}function D(t,e,n,r){if(!M(e))return;let i,o;if(A(n)&&!k(n))for(i in o=A(t[e])?t[e]:t[e]={},n)r&&(!0===r||r[i])?D(o,i,n[i]):M(i)&&(o[i]=n[i]);else t[e]=n}function C(t,e){if(null==t)return e;const n={},r=[];function i(t){n[t.name]||(n[t.name]=1,r.push(t))}return e.forEach(i),t.forEach(i),r}function F(t){return t[t.length-1]}function S(t){return null==t||""===t?null:+t}const $=t=>e=>t*Math.exp(e),T=t=>e=>Math.log(t*e),B=t=>e=>Math.sign(e)*Math.log1p(Math.abs(e/t)),z=t=>e=>Math.sign(e)*Math.expm1(Math.abs(e))*t,N=t=>e=>e<0?-Math.pow(-e,t):Math.pow(e,t);function O(t,e,n,r){const i=n(t[0]),o=n(F(t)),a=(o-i)*e;return[r(i-a),r(o-a)]}function R(t,e){return O(t,e,S,f)}function U(t,e){var n=Math.sign(t[0]);return O(t,e,T(n),$(n))}function L(t,e,n){return 
O(t,e,N(n),N(1/n))}function q(t,e,n){return O(t,e,B(n),z(n))}function P(t,e,n,r,i){const o=r(t[0]),a=r(F(t)),s=null!=e?r(e):(o+a)/2;return[i(s+(o-s)*n),i(s+(a-s)*n)]}function j(t,e,n){return P(t,e,n,S,f)}function I(t,e,n){const r=Math.sign(t[0]);return P(t,e,n,T(r),$(r))}function W(t,e,n,r){return P(t,e,n,N(r),N(1/r))}function H(t,e,n,r){return P(t,e,n,B(r),z(r))}function Y(t){return 1+~~(new Date(t).getMonth()/3)}function G(t){return 1+~~(new Date(t).getUTCMonth()/3)}function V(t){return null!=t?k(t)?t:[t]:[]}function X(t,e,n){let r,i=t[0],o=t[1];return o=n-e?[e,n]:[i=Math.min(Math.max(i,e),n-r),i+r]}function J(t){return"function"==typeof t}const Z="descending";function Q(t,n,i){i=i||{},n=V(n)||[];const o=[],a=[],s={},u=i.comparator||tt;return V(t).forEach(((t,e)=>{null!=t&&(o.push(n[e]===Z?-1:1),a.push(t=J(t)?t:l(t,null,i)),(r(t)||[]).forEach((t=>s[t]=1)))})),0===a.length?null:e(u(a,o),Object.keys(s))}const K=(t,e)=>(te||null==e)&&null!=t?1:(e=e instanceof Date?+e:e,(t=t instanceof Date?+t:t)!==t&&e==e?-1:e!=e&&t==t?1:0),tt=(t,e)=>1===t.length?et(t[0],e[0]):nt(t,e,t.length),et=(t,e)=>function(n,r){return K(t(n),t(r))*e},nt=(t,e,n)=>(e.push(0),function(r,i){let o,a=0,s=-1;for(;0===a&&++st}function it(t,e){let n;return r=>{n&&clearTimeout(n),n=setTimeout((()=>(e(r),n=null)),t)}}function ot(t){for(let e,n,r=1,i=arguments.length;ro&&(o=r))}else{for(r=e(t[a]);ao&&(o=r))}return[i,o]}function st(t,e){const n=t.length;let r,i,o,a,s,u=-1;if(null==e){for(;++u=i){r=o=i;break}if(u===n)return[-1,-1];for(a=s=u;++ui&&(r=i,a=u),o=i){r=o=i;break}if(u===n)return[-1,-1];for(a=s=u;++ui&&(r=i,a=u),or(t)?n[t]:void 0,set(t,e){return r(t)||(++i.size,n[t]===ct&&--i.empty),n[t]=e,this},delete(t){return r(t)&&(--i.size,++i.empty,n[t]=ct),this},clear(){i.size=i.empty=0,i.object=n={}},test(t){return arguments.length?(e=t,i):e},clean(){const t={};let r=0;for(const i in n){const o=n[i];o===ct||e&&e(o)||(t[i]=o,++r)}i.size=r,i.empty=0,i.object=n=t}};return 
t&&Object.keys(t).forEach((e=>{i.set(e,t[e])})),i}function ht(t,e,n,r,i,o){if(!n&&0!==n)return o;const a=+n;let s,u=t[0],l=F(t);la&&(i=o,o=a,a=i),r=void 0===r||r,((n=void 0===n||n)?o<=t:ot.replace(/\\(.)/g,"$1"))):V(t));const o=t&&t.length,a=r&&r.get||i,s=t=>a(n?[t]:u(t));let l;if(o)if(1===o){const e=s(t[0]);l=function(t){return""+e(t)}}else{const e=t.map(s);l=function(t){let n=""+e[0](t),r=0;for(;++r{e={},n={},r=0},o=(i,o)=>(++r>t&&(n=e,e={},r=1),e[i]=o);return i(),{clear:i,has:t=>lt(e,t)||lt(n,t),get:t=>lt(e,t)?e[t]:lt(n,t)?o(t,n[t]):void 0,set:(t,n)=>lt(e,t)?e[t]=n:o(t,n)}}function At(t,e,n,r){const i=e.length,o=n.length;if(!o)return e;if(!i)return n;const a=r||new e.constructor(i+o);let s=0,u=0,l=0;for(;s0?n[u++]:e[s++];for(;s=0;)n+=t;return n}function Et(t,e,n,r){const i=n||" ",o=t+"",a=e-o.length;return a<=0?o:"left"===r?Mt(i,a)+o:"center"===r?Mt(i,~~(a/2))+o+Mt(i,Math.ceil(a/2)):o+Mt(i,a)}function Dt(t){return t&&F(t)-t[0]||0}function Ct(t){return k(t)?"["+t.map(Ct)+"]":A(t)||xt(t)?JSON.stringify(t).replace("\u2028","\\u2028").replace("\u2029","\\u2029"):t}function Ft(t){return null==t||""===t?null:!(!t||"false"===t||"0"===t)&&!!t}const St=t=>vt(t)||mt(t)?t:Date.parse(t);function $t(t,e){return e=e||St,null==t||""===t?null:e(t)}function Tt(t){return null==t||""===t?null:t+""}function Bt(t){const e={},n=t.length;for(let r=0;r9999?"+"+It(e,6):It(e,4))+"-"+It(t.getUTCMonth()+1,2)+"-"+It(t.getUTCDate(),2)+(o?"T"+It(n,2)+":"+It(r,2)+":"+It(i,2)+"."+It(o,3)+"Z":i?"T"+It(n,2)+":"+It(r,2)+":"+It(i,2)+"Z":r||n?"T"+It(n,2)+":"+It(r,2)+"Z":"")}function Ht(t){var e=new RegExp('["'+t+"\n\r]"),n=t.charCodeAt(0);function r(t,e){var r,i=[],o=t.length,a=0,s=0,u=o<=0,l=!1;function c(){if(u)return Rt;if(l)return l=!1,Ot;var e,r,i=a;if(t.charCodeAt(i)===Ut){for(;a++=o?u=!0:(r=t.charCodeAt(a++))===Lt?l=!0:r===qt&&(l=!0,t.charCodeAt(a)===Lt&&++a),t.slice(i+1,e-1).replace(/""/g,'"')}for(;a1)r=function(t,e,n){var r,i=[],o=[];function a(t){var 
e=t<0?~t:t;(o[e]||(o[e]=[])).push({i:t,g:r})}function s(t){t.forEach(a)}function u(t){t.forEach(s)}function l(t){t.forEach(u)}function c(t){switch(r=t,t.type){case"GeometryCollection":t.geometries.forEach(c);break;case"LineString":s(t.arcs);break;case"MultiLineString":case"Polygon":u(t.arcs);break;case"MultiPolygon":l(t.arcs)}}return c(e),o.forEach(null==n?function(t){i.push(t[0].i)}:function(t){n(t[0].g,t[t.length-1].g)&&i.push(t[0].i)}),i}(0,e,n);else for(i=0,r=new Array(o=t.arcs.length);ie?1:t>=e?0:NaN}function te(t,e){return null==t||null==e?NaN:et?1:e>=t?0:NaN}function ee(t){let e,n,r;function i(t,r){let i=arguments.length>2&&void 0!==arguments[2]?arguments[2]:0,o=arguments.length>3&&void 0!==arguments[3]?arguments[3]:t.length;if(i>>1;n(t[e],r)<0?i=e+1:o=e}while(iKt(t(e),n),r=(e,n)=>t(e)-n):(e=t===Kt||t===te?t:ne,n=t,r=t),{left:i,center:function(t,e){let n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:0;const o=i(t,e,n,(arguments.length>3&&void 0!==arguments[3]?arguments[3]:t.length)-1);return o>n&&r(t[o-1],e)>-r(t[o],e)?o-1:o},right:function(t,r){let i=arguments.length>2&&void 0!==arguments[2]?arguments[2]:0,o=arguments.length>3&&void 0!==arguments[3]?arguments[3]:t.length;if(i>>1;n(t[e],r)<=0?i=e+1:o=e}while(i0){for(o=t[--i];i>0&&(e=o,n=t[--i],o=e+n,r=n-(o-e),!r););i>0&&(r<0&&t[i-1]<0||r>0&&t[i-1]>0)&&(n=2*r,e=o+n,n==e-o&&(o=e))}return o}}class ue extends Map{constructor(t){let e=arguments.length>1&&void 0!==arguments[1]?arguments[1]:de;if(super(),Object.defineProperties(this,{_intern:{value:new Map},_key:{value:e}}),null!=t)for(const[e,n]of t)this.set(e,n)}get(t){return super.get(ce(this,t))}has(t){return super.has(ce(this,t))}set(t,e){return super.set(fe(this,t),e)}delete(t){return super.delete(he(this,t))}}class le extends Set{constructor(t){let e=arguments.length>1&&void 0!==arguments[1]?arguments[1]:de;if(super(),Object.defineProperties(this,{_intern:{value:new Map},_key:{value:e}}),null!=t)for(const e of t)this.add(e)}has(t){return 
super.has(ce(this,t))}add(t){return super.add(fe(this,t))}delete(t){return super.delete(he(this,t))}}function ce(t,e){let{_intern:n,_key:r}=t;const i=r(e);return n.has(i)?n.get(i):e}function fe(t,e){let{_intern:n,_key:r}=t;const i=r(e);return n.has(i)?n.get(i):(n.set(i,e),e)}function he(t,e){let{_intern:n,_key:r}=t;const i=r(e);return n.has(i)&&(e=n.get(i),n.delete(i)),e}function de(t){return null!==t&&"object"==typeof t?t.valueOf():t}function pe(t,e){return(null==t||!(t>=t))-(null==e||!(e>=e))||(te?1:0)}const ge=Math.sqrt(50),me=Math.sqrt(10),ye=Math.sqrt(2);function ve(t,e,n){const r=(e-t)/Math.max(0,n),i=Math.floor(Math.log10(r)),o=r/Math.pow(10,i),a=o>=ge?10:o>=me?5:o>=ye?2:1;let s,u,l;return i<0?(l=Math.pow(10,-i)/a,s=Math.round(t*l),u=Math.round(e*l),s/le&&--u,l=-l):(l=Math.pow(10,i)*a,s=Math.round(t/l),u=Math.round(e/l),s*le&&--u),u0))return[];if((t=+t)===(e=+e))return[t];const r=e=i))return[];const s=o-i+1,u=new Array(s);if(r)if(a<0)for(let t=0;t=e)&&(n=e);else{let r=-1;for(let i of t)null!=(i=e(i,++r,t))&&(n=i)&&(n=i)}return n}function ke(t,e){let n;if(void 0===e)for(const e of t)null!=e&&(n>e||void 0===n&&e>=e)&&(n=e);else{let r=-1;for(let i of t)null!=(i=e(i,++r,t))&&(n>i||void 0===n&&i>=i)&&(n=i)}return n}function Ae(t,e){let n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:0,r=arguments.length>3&&void 0!==arguments[3]?arguments[3]:1/0,i=arguments.length>4?arguments[4]:void 0;if(e=Math.floor(e),n=Math.floor(Math.max(0,n)),r=Math.floor(Math.min(t.length-1,r)),!(n<=e&&e<=r))return t;for(i=void 0===i?pe:function(){let t=arguments.length>0&&void 0!==arguments[0]?arguments[0]:Kt;if(t===Kt)return pe;if("function"!=typeof t)throw new TypeError("compare is not a function");return(e,n)=>{const r=t(e,n);return r||0===r?r:(0===t(n,n))-(0===t(e,e))}}(i);r>n;){if(r-n>600){const o=r-n+1,a=e-n+1,s=Math.log(o),u=.5*Math.exp(2*s/3),l=.5*Math.sqrt(s*u*(o-u)/o)*(a-o/2<0?-1:1);Ae(t,e,Math.max(n,Math.floor(e-a*u/o+l)),Math.min(r,Math.floor(e+(o-a)*u/o+l)),i)}const 
o=t[e];let a=n,s=r;for(Me(t,n,e),i(t[r],o)>0&&Me(t,n,r);a0;)--s}0===i(t[n],o)?Me(t,n,s):(++s,Me(t,s,r)),s<=e&&(n=s+1),e<=s&&(r=s-1)}return t}function Me(t,e,n){const r=t[e];t[e]=t[n],t[n]=r}function Ee(t,e,n){if(t=Float64Array.from(function*(t,e){if(void 0===e)for(let e of t)null!=e&&(e=+e)>=e&&(yield e);else{let n=-1;for(let r of t)null!=(r=e(r,++n,t))&&(r=+r)>=r&&(yield r)}}(t,n)),(r=t.length)&&!isNaN(e=+e)){if(e<=0||r<2)return ke(t);if(e>=1)return we(t);var r,i=(r-1)*e,o=Math.floor(i),a=we(Ae(t,o).subarray(0,o+1));return a+(ke(t.subarray(o+1))-a)*(i-o)}}function De(t,e){let n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:re;if((r=t.length)&&!isNaN(e=+e)){if(e<=0||r<2)return+n(t[0],0,t);if(e>=1)return+n(t[r-1],r-1,t);var r,i=(r-1)*e,o=Math.floor(i),a=+n(t[o],o,t);return a+(+n(t[o+1],o+1,t)-a)*(i-o)}}function Ce(t,e){return Ee(t,.5,e)}function Fe(t){return Array.from(function*(t){for(const e of t)yield*e}(t))}function Se(t,e,n){t=+t,e=+e,n=(i=arguments.length)<2?(e=t,t=0,1):i<3?1:+n;for(var r=-1,i=0|Math.max(0,Math.ceil((e-t)/n)),o=new Array(i);++r1?r[0]+r.slice(2):r,+t.slice(n+1)]}function ze(t){return(t=Be(Math.abs(t)))?t[1]:NaN}var Ne,Oe=/^(?:(.)?([<>=^]))?([+\-( ])?([$#])?(0)?(\d+)?(,)?(\.\d+)?(~)?([a-z%])?$/i;function Re(t){if(!(e=Oe.exec(t)))throw new Error("invalid format: "+t);var e;return new Ue({fill:e[1],align:e[2],sign:e[3],symbol:e[4],zero:e[5],width:e[6],comma:e[7],precision:e[8]&&e[8].slice(1),trim:e[9],type:e[10]})}function Ue(t){this.fill=void 0===t.fill?" 
":t.fill+"",this.align=void 0===t.align?">":t.align+"",this.sign=void 0===t.sign?"-":t.sign+"",this.symbol=void 0===t.symbol?"":t.symbol+"",this.zero=!!t.zero,this.width=void 0===t.width?void 0:+t.width,this.comma=!!t.comma,this.precision=void 0===t.precision?void 0:+t.precision,this.trim=!!t.trim,this.type=void 0===t.type?"":t.type+""}function Le(t,e){var n=Be(t,e);if(!n)return t+"";var r=n[0],i=n[1];return i<0?"0."+new Array(-i).join("0")+r:r.length>i+1?r.slice(0,i+1)+"."+r.slice(i+1):r+new Array(i-r.length+2).join("0")}Re.prototype=Ue.prototype,Ue.prototype.toString=function(){return this.fill+this.align+this.sign+this.symbol+(this.zero?"0":"")+(void 0===this.width?"":Math.max(1,0|this.width))+(this.comma?",":"")+(void 0===this.precision?"":"."+Math.max(0,0|this.precision))+(this.trim?"~":"")+this.type};var qe={"%":(t,e)=>(100*t).toFixed(e),b:t=>Math.round(t).toString(2),c:t=>t+"",d:function(t){return Math.abs(t=Math.round(t))>=1e21?t.toLocaleString("en").replace(/,/g,""):t.toString(10)},e:(t,e)=>t.toExponential(e),f:(t,e)=>t.toFixed(e),g:(t,e)=>t.toPrecision(e),o:t=>Math.round(t).toString(8),p:(t,e)=>Le(100*t,e),r:Le,s:function(t,e){var n=Be(t,e);if(!n)return t+"";var r=n[0],i=n[1],o=i-(Ne=3*Math.max(-8,Math.min(8,Math.floor(i/3))))+1,a=r.length;return o===a?r:o>a?r+new Array(o-a+1).join("0"):o>0?r.slice(0,o)+"."+r.slice(o):"0."+new Array(1-o).join("0")+Be(t,Math.max(0,e+o-1))[0]},X:t=>Math.round(t).toString(16).toUpperCase(),x:t=>Math.round(t).toString(16)};function Pe(t){return t}var je,Ie,We,He=Array.prototype.map,Ye=["y","z","a","f","p","n","µ","m","","k","M","G","T","P","E","Z","Y"];function Ge(t){var e,n,r=void 0===t.grouping||void 0===t.thousands?Pe:(e=He.call(t.grouping,Number),n=t.thousands+"",function(t,r){for(var i=t.length,o=[],a=0,s=e[0],u=0;i>0&&s>0&&(u+s+1>r&&(s=Math.max(1,r-u)),o.push(t.substring(i-=s,i+s)),!((u+=s+1)>r));)s=e[a=(a+1)%e.length];return o.reverse().join(n)}),i=void 0===t.currency?"":t.currency[0]+"",o=void 
0===t.currency?"":t.currency[1]+"",a=void 0===t.decimal?".":t.decimal+"",s=void 0===t.numerals?Pe:function(t){return function(e){return e.replace(/[0-9]/g,(function(e){return t[+e]}))}}(He.call(t.numerals,String)),u=void 0===t.percent?"%":t.percent+"",l=void 0===t.minus?"−":t.minus+"",c=void 0===t.nan?"NaN":t.nan+"";function f(t){var e=(t=Re(t)).fill,n=t.align,f=t.sign,h=t.symbol,d=t.zero,p=t.width,g=t.comma,m=t.precision,y=t.trim,v=t.type;"n"===v?(g=!0,v="g"):qe[v]||(void 0===m&&(m=12),y=!0,v="g"),(d||"0"===e&&"="===n)&&(d=!0,e="0",n="=");var _="$"===h?i:"#"===h&&/[boxX]/.test(v)?"0"+v.toLowerCase():"",x="$"===h?o:/[%p]/.test(v)?u:"",b=qe[v],w=/[defgprs%]/.test(v);function k(t){var i,o,u,h=_,k=x;if("c"===v)k=b(t)+k,t="";else{var A=(t=+t)<0||1/t<0;if(t=isNaN(t)?c:b(Math.abs(t),m),y&&(t=function(t){t:for(var e,n=t.length,r=1,i=-1;r0&&(i=0)}return i>0?t.slice(0,i)+t.slice(e+1):t}(t)),A&&0==+t&&"+"!==f&&(A=!1),h=(A?"("===f?f:l:"-"===f||"("===f?"":f)+h,k=("s"===v?Ye[8+Ne/3]:"")+k+(A&&"("===f?")":""),w)for(i=-1,o=t.length;++i(u=t.charCodeAt(i))||u>57){k=(46===u?a+t.slice(i+1):t.slice(i))+k,t=t.slice(0,i);break}}g&&!d&&(t=r(t,1/0));var M=h.length+t.length+k.length,E=M>1)+h+t+k+E.slice(M);break;default:t=E+h+t+k}return s(t)}return m=void 0===m?6:/[gprs]/.test(v)?Math.max(1,Math.min(21,m)):Math.max(0,Math.min(20,m)),k.toString=function(){return t+""},k}return{format:f,formatPrefix:function(t,e){var n=f(((t=Re(t)).type="f",t)),r=3*Math.max(-8,Math.min(8,Math.floor(ze(e)/3))),i=Math.pow(10,-r),o=Ye[8+r/3];return function(t){return n(i*t)+o}}}}function Ve(t){return Math.max(0,-ze(Math.abs(t)))}function Xe(t,e){return Math.max(0,3*Math.max(-8,Math.min(8,Math.floor(ze(e)/3)))-ze(Math.abs(t)))}function Je(t,e){return t=Math.abs(t),e=Math.abs(e)-t,Math.max(0,ze(e)-ze(t))+1}!function(t){je=Ge(t),Ie=je.format,We=je.formatPrefix}({thousands:",",grouping:[3],currency:["$",""]});const Ze=new Date,Qe=new Date;function Ke(t,e,n,r){function i(e){return t(e=0===arguments.length?new 
Date:new Date(+e)),e}return i.floor=e=>(t(e=new Date(+e)),e),i.ceil=n=>(t(n=new Date(n-1)),e(n,1),t(n),n),i.round=t=>{const e=i(t),n=i.ceil(t);return t-e(e(t=new Date(+t),null==n?1:Math.floor(n)),t),i.range=(n,r,o)=>{const a=[];if(n=i.ceil(n),o=null==o?1:Math.floor(o),!(n0))return a;let s;do{a.push(s=new Date(+n)),e(n,o),t(n)}while(sKe((e=>{if(e>=e)for(;t(e),!n(e);)e.setTime(e-1)}),((t,r)=>{if(t>=t)if(r<0)for(;++r<=0;)for(;e(t,-1),!n(t););else for(;--r>=0;)for(;e(t,1),!n(t););})),n&&(i.count=(e,r)=>(Ze.setTime(+e),Qe.setTime(+r),t(Ze),t(Qe),Math.floor(n(Ze,Qe))),i.every=t=>(t=Math.floor(t),isFinite(t)&&t>0?t>1?i.filter(r?e=>r(e)%t==0:e=>i.count(0,e)%t==0):i:null)),i}const tn=Ke((()=>{}),((t,e)=>{t.setTime(+t+e)}),((t,e)=>e-t));tn.every=t=>(t=Math.floor(t),isFinite(t)&&t>0?t>1?Ke((e=>{e.setTime(Math.floor(e/t)*t)}),((e,n)=>{e.setTime(+e+n*t)}),((e,n)=>(n-e)/t)):tn:null),tn.range;const en=1e3,nn=6e4,rn=36e5,on=864e5,an=6048e5,sn=2592e6,un=31536e6,ln=Ke((t=>{t.setTime(t-t.getMilliseconds())}),((t,e)=>{t.setTime(+t+e*en)}),((t,e)=>(e-t)/en),(t=>t.getUTCSeconds()));ln.range;const cn=Ke((t=>{t.setTime(t-t.getMilliseconds()-t.getSeconds()*en)}),((t,e)=>{t.setTime(+t+e*nn)}),((t,e)=>(e-t)/nn),(t=>t.getMinutes()));cn.range;const fn=Ke((t=>{t.setUTCSeconds(0,0)}),((t,e)=>{t.setTime(+t+e*nn)}),((t,e)=>(e-t)/nn),(t=>t.getUTCMinutes()));fn.range;const hn=Ke((t=>{t.setTime(t-t.getMilliseconds()-t.getSeconds()*en-t.getMinutes()*nn)}),((t,e)=>{t.setTime(+t+e*rn)}),((t,e)=>(e-t)/rn),(t=>t.getHours()));hn.range;const dn=Ke((t=>{t.setUTCMinutes(0,0,0)}),((t,e)=>{t.setTime(+t+e*rn)}),((t,e)=>(e-t)/rn),(t=>t.getUTCHours()));dn.range;const pn=Ke((t=>t.setHours(0,0,0,0)),((t,e)=>t.setDate(t.getDate()+e)),((t,e)=>(e-t-(e.getTimezoneOffset()-t.getTimezoneOffset())*nn)/on),(t=>t.getDate()-1));pn.range;const gn=Ke((t=>{t.setUTCHours(0,0,0,0)}),((t,e)=>{t.setUTCDate(t.getUTCDate()+e)}),((t,e)=>(e-t)/on),(t=>t.getUTCDate()-1));gn.range;const 
mn=Ke((t=>{t.setUTCHours(0,0,0,0)}),((t,e)=>{t.setUTCDate(t.getUTCDate()+e)}),((t,e)=>(e-t)/on),(t=>Math.floor(t/on)));function yn(t){return Ke((e=>{e.setDate(e.getDate()-(e.getDay()+7-t)%7),e.setHours(0,0,0,0)}),((t,e)=>{t.setDate(t.getDate()+7*e)}),((t,e)=>(e-t-(e.getTimezoneOffset()-t.getTimezoneOffset())*nn)/an))}mn.range;const vn=yn(0),_n=yn(1),xn=yn(2),bn=yn(3),wn=yn(4),kn=yn(5),An=yn(6);function Mn(t){return Ke((e=>{e.setUTCDate(e.getUTCDate()-(e.getUTCDay()+7-t)%7),e.setUTCHours(0,0,0,0)}),((t,e)=>{t.setUTCDate(t.getUTCDate()+7*e)}),((t,e)=>(e-t)/an))}vn.range,_n.range,xn.range,bn.range,wn.range,kn.range,An.range;const En=Mn(0),Dn=Mn(1),Cn=Mn(2),Fn=Mn(3),Sn=Mn(4),$n=Mn(5),Tn=Mn(6);En.range,Dn.range,Cn.range,Fn.range,Sn.range,$n.range,Tn.range;const Bn=Ke((t=>{t.setDate(1),t.setHours(0,0,0,0)}),((t,e)=>{t.setMonth(t.getMonth()+e)}),((t,e)=>e.getMonth()-t.getMonth()+12*(e.getFullYear()-t.getFullYear())),(t=>t.getMonth()));Bn.range;const zn=Ke((t=>{t.setUTCDate(1),t.setUTCHours(0,0,0,0)}),((t,e)=>{t.setUTCMonth(t.getUTCMonth()+e)}),((t,e)=>e.getUTCMonth()-t.getUTCMonth()+12*(e.getUTCFullYear()-t.getUTCFullYear())),(t=>t.getUTCMonth()));zn.range;const Nn=Ke((t=>{t.setMonth(0,1),t.setHours(0,0,0,0)}),((t,e)=>{t.setFullYear(t.getFullYear()+e)}),((t,e)=>e.getFullYear()-t.getFullYear()),(t=>t.getFullYear()));Nn.every=t=>isFinite(t=Math.floor(t))&&t>0?Ke((e=>{e.setFullYear(Math.floor(e.getFullYear()/t)*t),e.setMonth(0,1),e.setHours(0,0,0,0)}),((e,n)=>{e.setFullYear(e.getFullYear()+n*t)})):null,Nn.range;const On=Ke((t=>{t.setUTCMonth(0,1),t.setUTCHours(0,0,0,0)}),((t,e)=>{t.setUTCFullYear(t.getUTCFullYear()+e)}),((t,e)=>e.getUTCFullYear()-t.getUTCFullYear()),(t=>t.getUTCFullYear()));function Rn(t,e,n,r,i,o){const a=[[ln,1,en],[ln,5,5e3],[ln,15,15e3],[ln,30,3e4],[o,1,nn],[o,5,3e5],[o,15,9e5],[o,30,18e5],[i,1,rn],[i,3,108e5],[i,6,216e5],[i,12,432e5],[r,1,on],[r,2,1728e5],[n,1,an],[e,1,sn],[e,3,7776e6],[t,1,un]];function s(e,n,r){const 
i=Math.abs(n-e)/r,o=ee((t=>{let[,,e]=t;return e})).right(a,i);if(o===a.length)return t.every(be(e/un,n/un,r));if(0===o)return tn.every(Math.max(be(e,n,r),1));const[s,u]=a[i/a[o-1][2]isFinite(t=Math.floor(t))&&t>0?Ke((e=>{e.setUTCFullYear(Math.floor(e.getUTCFullYear()/t)*t),e.setUTCMonth(0,1),e.setUTCHours(0,0,0,0)}),((e,n)=>{e.setUTCFullYear(e.getUTCFullYear()+n*t)})):null,On.range;const[Un,Ln]=Rn(On,zn,En,mn,dn,fn),[qn,Pn]=Rn(Nn,Bn,vn,pn,hn,cn),jn="year",In="quarter",Wn="month",Hn="week",Yn="date",Gn="day",Vn="dayofyear",Xn="hours",Jn="minutes",Zn="seconds",Qn="milliseconds",Kn=[jn,In,Wn,Hn,Yn,Gn,Vn,Xn,Jn,Zn,Qn],tr=Kn.reduce(((t,e,n)=>(t[e]=1+n,t)),{});function er(t){const e=V(t).slice(),n={};e.length||s("Missing time unit."),e.forEach((t=>{lt(tr,t)?n[t]=1:s(`Invalid time unit: ${t}.`)}));return(n[Hn]||n[Gn]?1:0)+(n[In]||n[Wn]||n[Yn]?1:0)+(n[Vn]?1:0)>1&&s(`Incompatible time units: ${t}`),e.sort(((t,e)=>tr[t]-tr[e])),e}const nr={[jn]:"%Y ",[In]:"Q%q ",[Wn]:"%b ",[Yn]:"%d ",[Hn]:"W%U ",[Gn]:"%a ",[Vn]:"%j ",[Xn]:"%H:00",[Jn]:"00:%M",[Zn]:":%S",[Qn]:".%L",[`${jn}-${Wn}`]:"%Y-%m ",[`${jn}-${Wn}-${Yn}`]:"%Y-%m-%d ",[`${Xn}-${Jn}`]:"%H:%M"};function rr(t,e){const n=ot({},nr,e),r=er(t),i=r.length;let o,a,s="",u=0;for(u=0;uu;--o)if(a=r.slice(u,o).join("-"),null!=n[a]){s+=n[a],u=o;break}return s.trim()}const ir=new Date;function or(t){return ir.setFullYear(t),ir.setMonth(0),ir.setDate(1),ir.setHours(0,0,0,0),ir}function ar(t){return ur(new Date(t))}function sr(t){return lr(new Date(t))}function ur(t){return pn.count(or(t.getFullYear())-1,t)}function lr(t){return vn.count(or(t.getFullYear())-1,t)}function cr(t){return or(t).getDay()}function fr(t,e,n,r,i,o,a){if(0<=t&&t<100){const s=new Date(-1,e,n,r,i,o,a);return s.setFullYear(t),s}return new Date(t,e,n,r,i,o,a)}function hr(t){return pr(new Date(t))}function dr(t){return gr(new Date(t))}function pr(t){const e=Date.UTC(t.getUTCFullYear(),0,1);return gn.count(e-1,t)}function gr(t){const 
e=Date.UTC(t.getUTCFullYear(),0,1);return En.count(e-1,t)}function mr(t){return ir.setTime(Date.UTC(t,0,1)),ir.getUTCDay()}function yr(t,e,n,r,i,o,a){if(0<=t&&t<100){const t=new Date(Date.UTC(-1,e,n,r,i,o,a));return t.setUTCFullYear(n.y),t}return new Date(Date.UTC(t,e,n,r,i,o,a))}function vr(t,e,n,r,i){const o=e||1,a=F(t),s=(t,e,i)=>function(t,e,n,r){const i=n<=1?t:r?(e,i)=>r+n*Math.floor((t(e,i)-r)/n):(e,r)=>n*Math.floor(t(e,r)/n);return e?(t,n)=>e(i(t,n),n):i}(n[i=i||t],r[i],t===a&&o,e),u=new Date,l=Bt(t),c=l[jn]?s(jn):rt(2012),f=l[Wn]?s(Wn):l[In]?s(In):h,p=l[Hn]&&l[Gn]?s(Gn,1,Hn+Gn):l[Hn]?s(Hn,1):l[Gn]?s(Gn,1):l[Yn]?s(Yn,1):l[Vn]?s(Vn,1):d,g=l[Xn]?s(Xn):h,m=l[Jn]?s(Jn):h,y=l[Zn]?s(Zn):h,v=l[Qn]?s(Qn):h;return function(t){u.setTime(+t);const e=c(u);return i(e,f(u),p(u,e),g(u),m(u),y(u),v(u))}}function _r(t,e,n){return e+7*t-(n+6)%7}const xr={[jn]:t=>t.getFullYear(),[In]:t=>Math.floor(t.getMonth()/3),[Wn]:t=>t.getMonth(),[Yn]:t=>t.getDate(),[Xn]:t=>t.getHours(),[Jn]:t=>t.getMinutes(),[Zn]:t=>t.getSeconds(),[Qn]:t=>t.getMilliseconds(),[Vn]:t=>ur(t),[Hn]:t=>lr(t),[Hn+Gn]:(t,e)=>_r(lr(t),t.getDay(),cr(e)),[Gn]:(t,e)=>_r(1,t.getDay(),cr(e))},br={[In]:t=>3*t,[Hn]:(t,e)=>_r(t,0,cr(e))};function wr(t,e){return vr(t,e||1,xr,br,fr)}const kr={[jn]:t=>t.getUTCFullYear(),[In]:t=>Math.floor(t.getUTCMonth()/3),[Wn]:t=>t.getUTCMonth(),[Yn]:t=>t.getUTCDate(),[Xn]:t=>t.getUTCHours(),[Jn]:t=>t.getUTCMinutes(),[Zn]:t=>t.getUTCSeconds(),[Qn]:t=>t.getUTCMilliseconds(),[Vn]:t=>pr(t),[Hn]:t=>gr(t),[Gn]:(t,e)=>_r(1,t.getUTCDay(),mr(e)),[Hn+Gn]:(t,e)=>_r(gr(t),t.getUTCDay(),mr(e))},Ar={[In]:t=>3*t,[Hn]:(t,e)=>_r(t,0,mr(e))};function Mr(t,e){return vr(t,e||1,kr,Ar,yr)}const Er={[jn]:Nn,[In]:Bn.every(3),[Wn]:Bn,[Hn]:vn,[Yn]:pn,[Gn]:pn,[Vn]:pn,[Xn]:hn,[Jn]:cn,[Zn]:ln,[Qn]:tn},Dr={[jn]:On,[In]:zn.every(3),[Wn]:zn,[Hn]:En,[Yn]:gn,[Gn]:gn,[Vn]:gn,[Xn]:dn,[Jn]:fn,[Zn]:ln,[Qn]:tn};function Cr(t){return Er[t]}function Fr(t){return Dr[t]}function Sr(t,e,n){return t?t.offset(e,n):void 0}function 
$r(t,e,n){return Sr(Cr(t),e,n)}function Tr(t,e,n){return Sr(Fr(t),e,n)}function Br(t,e,n,r){return t?t.range(e,n,r):void 0}function zr(t,e,n,r){return Br(Cr(t),e,n,r)}function Nr(t,e,n,r){return Br(Fr(t),e,n,r)}const Or=1e3,Rr=6e4,Ur=36e5,Lr=864e5,qr=2592e6,Pr=31536e6,jr=[jn,Wn,Yn,Xn,Jn,Zn,Qn],Ir=jr.slice(0,-1),Wr=Ir.slice(0,-1),Hr=Wr.slice(0,-1),Yr=Hr.slice(0,-1),Gr=[jn,Wn],Vr=[jn],Xr=[[Ir,1,Or],[Ir,5,5e3],[Ir,15,15e3],[Ir,30,3e4],[Wr,1,Rr],[Wr,5,3e5],[Wr,15,9e5],[Wr,30,18e5],[Hr,1,Ur],[Hr,3,108e5],[Hr,6,216e5],[Hr,12,432e5],[Yr,1,Lr],[[jn,Hn],1,6048e5],[Gr,1,qr],[Gr,3,7776e6],[Vr,1,Pr]];function Jr(t){const e=t.extent,n=t.maxbins||40,r=Math.abs(Dt(e))/n;let i,o,a=ee((t=>t[2])).right(Xr,r);return a===Xr.length?(i=Vr,o=be(e[0]/Pr,e[1]/Pr,n)):a?(a=Xr[r/Xr[a-1][2]=12)]},q:function(t){return 1+~~(t.getMonth()/3)},Q:wo,s:ko,S:ji,u:Ii,U:Wi,V:Yi,w:Gi,W:Vi,x:null,X:null,y:Xi,Y:Zi,Z:Ki,"%":bo},x={a:function(t){return a[t.getUTCDay()]},A:function(t){return o[t.getUTCDay()]},b:function(t){return u[t.getUTCMonth()]},B:function(t){return s[t.getUTCMonth()]},c:null,d:to,e:to,f:oo,g:yo,G:_o,H:eo,I:no,j:ro,L:io,m:ao,M:so,p:function(t){return i[+(t.getUTCHours()>=12)]},q:function(t){return 1+~~(t.getUTCMonth()/3)},Q:wo,s:ko,S:uo,u:lo,U:co,V:ho,w:po,W:go,x:null,X:null,y:mo,Y:vo,Z:xo,"%":bo},b={a:function(t,e,n){var r=d.exec(e.slice(n));return r?(t.w=p.get(r[0].toLowerCase()),n+r[0].length):-1},A:function(t,e,n){var r=f.exec(e.slice(n));return r?(t.w=h.get(r[0].toLowerCase()),n+r[0].length):-1},b:function(t,e,n){var r=y.exec(e.slice(n));return r?(t.m=v.get(r[0].toLowerCase()),n+r[0].length):-1},B:function(t,e,n){var r=g.exec(e.slice(n));return r?(t.m=m.get(r[0].toLowerCase()),n+r[0].length):-1},c:function(t,n,r){return A(t,e,n,r)},d:Ai,e:Ai,f:Si,g:xi,G:_i,H:Ei,I:Ei,j:Mi,L:Fi,m:ki,M:Di,p:function(t,e,n){var r=l.exec(e.slice(n));return r?(t.p=c.get(r[0].toLowerCase()),n+r[0].length):-1},q:wi,Q:Ti,s:Bi,S:Ci,u:gi,U:mi,V:yi,w:pi,W:vi,x:function(t,e,r){return 
A(t,n,e,r)},X:function(t,e,n){return A(t,r,e,n)},y:xi,Y:_i,Z:bi,"%":$i};function w(t,e){return function(n){var r,i,o,a=[],s=-1,u=0,l=t.length;for(n instanceof Date||(n=new Date(+n));++s53)return null;"w"in o||(o.w=1),"Z"in o?(i=(r=Qr(Kr(o.y,0,1))).getUTCDay(),r=i>4||0===i?Dn.ceil(r):Dn(r),r=gn.offset(r,7*(o.V-1)),o.y=r.getUTCFullYear(),o.m=r.getUTCMonth(),o.d=r.getUTCDate()+(o.w+6)%7):(i=(r=Zr(Kr(o.y,0,1))).getDay(),r=i>4||0===i?_n.ceil(r):_n(r),r=pn.offset(r,7*(o.V-1)),o.y=r.getFullYear(),o.m=r.getMonth(),o.d=r.getDate()+(o.w+6)%7)}else("W"in o||"U"in o)&&("w"in o||(o.w="u"in o?o.u%7:"W"in o?1:0),i="Z"in o?Qr(Kr(o.y,0,1)).getUTCDay():Zr(Kr(o.y,0,1)).getDay(),o.m=0,o.d="W"in o?(o.w+6)%7+7*o.W-(i+5)%7:o.w+7*o.U-(i+6)%7);return"Z"in o?(o.H+=o.Z/100|0,o.M+=o.Z%100,Qr(o)):Zr(o)}}function A(t,e,n,r){for(var i,o,a=0,s=e.length,u=n.length;a=u)return-1;if(37===(i=e.charCodeAt(a++))){if(i=e.charAt(a++),!(o=b[i in ai?e.charAt(a++):i])||(r=o(t,n,r))<0)return-1}else if(i!=n.charCodeAt(r++))return-1}return r}return _.x=w(n,_),_.X=w(r,_),_.c=w(e,_),x.x=w(n,x),x.X=w(r,x),x.c=w(e,x),{format:function(t){var e=w(t+="",_);return e.toString=function(){return t},e},parse:function(t){var e=k(t+="",!1);return e.toString=function(){return t},e},utcFormat:function(t){var e=w(t+="",x);return e.toString=function(){return t},e},utcParse:function(t){var e=k(t+="",!0);return e.toString=function(){return t},e}}}var ei,ni,ri,ii,oi,ai={"-":"",_:" ",0:"0"},si=/^\s*\d+/,ui=/^%/,li=/[\\^$*+?|[\]().{}]/g;function ci(t,e,n){var r=t<0?"-":"",i=(r?-t:t)+"",o=i.length;return r+(o[t.toLowerCase(),e])))}function pi(t,e,n){var r=si.exec(e.slice(n,n+1));return r?(t.w=+r[0],n+r[0].length):-1}function gi(t,e,n){var r=si.exec(e.slice(n,n+1));return r?(t.u=+r[0],n+r[0].length):-1}function mi(t,e,n){var r=si.exec(e.slice(n,n+2));return r?(t.U=+r[0],n+r[0].length):-1}function yi(t,e,n){var r=si.exec(e.slice(n,n+2));return r?(t.V=+r[0],n+r[0].length):-1}function vi(t,e,n){var r=si.exec(e.slice(n,n+2));return 
r?(t.W=+r[0],n+r[0].length):-1}function _i(t,e,n){var r=si.exec(e.slice(n,n+4));return r?(t.y=+r[0],n+r[0].length):-1}function xi(t,e,n){var r=si.exec(e.slice(n,n+2));return r?(t.y=+r[0]+(+r[0]>68?1900:2e3),n+r[0].length):-1}function bi(t,e,n){var r=/^(Z)|([+-]\d\d)(?::?(\d\d))?/.exec(e.slice(n,n+6));return r?(t.Z=r[1]?0:-(r[2]+(r[3]||"00")),n+r[0].length):-1}function wi(t,e,n){var r=si.exec(e.slice(n,n+1));return r?(t.q=3*r[0]-3,n+r[0].length):-1}function ki(t,e,n){var r=si.exec(e.slice(n,n+2));return r?(t.m=r[0]-1,n+r[0].length):-1}function Ai(t,e,n){var r=si.exec(e.slice(n,n+2));return r?(t.d=+r[0],n+r[0].length):-1}function Mi(t,e,n){var r=si.exec(e.slice(n,n+3));return r?(t.m=0,t.d=+r[0],n+r[0].length):-1}function Ei(t,e,n){var r=si.exec(e.slice(n,n+2));return r?(t.H=+r[0],n+r[0].length):-1}function Di(t,e,n){var r=si.exec(e.slice(n,n+2));return r?(t.M=+r[0],n+r[0].length):-1}function Ci(t,e,n){var r=si.exec(e.slice(n,n+2));return r?(t.S=+r[0],n+r[0].length):-1}function Fi(t,e,n){var r=si.exec(e.slice(n,n+3));return r?(t.L=+r[0],n+r[0].length):-1}function Si(t,e,n){var r=si.exec(e.slice(n,n+6));return r?(t.L=Math.floor(r[0]/1e3),n+r[0].length):-1}function $i(t,e,n){var r=ui.exec(e.slice(n,n+1));return r?n+r[0].length:-1}function Ti(t,e,n){var r=si.exec(e.slice(n));return r?(t.Q=+r[0],n+r[0].length):-1}function Bi(t,e,n){var r=si.exec(e.slice(n));return r?(t.s=+r[0],n+r[0].length):-1}function zi(t,e){return ci(t.getDate(),e,2)}function Ni(t,e){return ci(t.getHours(),e,2)}function Oi(t,e){return ci(t.getHours()%12||12,e,2)}function Ri(t,e){return ci(1+pn.count(Nn(t),t),e,3)}function Ui(t,e){return ci(t.getMilliseconds(),e,3)}function Li(t,e){return Ui(t,e)+"000"}function qi(t,e){return ci(t.getMonth()+1,e,2)}function Pi(t,e){return ci(t.getMinutes(),e,2)}function ji(t,e){return ci(t.getSeconds(),e,2)}function Ii(t){var e=t.getDay();return 0===e?7:e}function Wi(t,e){return ci(vn.count(Nn(t)-1,t),e,2)}function Hi(t){var e=t.getDay();return 
e>=4||0===e?wn(t):wn.ceil(t)}function Yi(t,e){return t=Hi(t),ci(wn.count(Nn(t),t)+(4===Nn(t).getDay()),e,2)}function Gi(t){return t.getDay()}function Vi(t,e){return ci(_n.count(Nn(t)-1,t),e,2)}function Xi(t,e){return ci(t.getFullYear()%100,e,2)}function Ji(t,e){return ci((t=Hi(t)).getFullYear()%100,e,2)}function Zi(t,e){return ci(t.getFullYear()%1e4,e,4)}function Qi(t,e){var n=t.getDay();return ci((t=n>=4||0===n?wn(t):wn.ceil(t)).getFullYear()%1e4,e,4)}function Ki(t){var e=t.getTimezoneOffset();return(e>0?"-":(e*=-1,"+"))+ci(e/60|0,"0",2)+ci(e%60,"0",2)}function to(t,e){return ci(t.getUTCDate(),e,2)}function eo(t,e){return ci(t.getUTCHours(),e,2)}function no(t,e){return ci(t.getUTCHours()%12||12,e,2)}function ro(t,e){return ci(1+gn.count(On(t),t),e,3)}function io(t,e){return ci(t.getUTCMilliseconds(),e,3)}function oo(t,e){return io(t,e)+"000"}function ao(t,e){return ci(t.getUTCMonth()+1,e,2)}function so(t,e){return ci(t.getUTCMinutes(),e,2)}function uo(t,e){return ci(t.getUTCSeconds(),e,2)}function lo(t){var e=t.getUTCDay();return 0===e?7:e}function co(t,e){return ci(En.count(On(t)-1,t),e,2)}function fo(t){var e=t.getUTCDay();return e>=4||0===e?Sn(t):Sn.ceil(t)}function ho(t,e){return t=fo(t),ci(Sn.count(On(t),t)+(4===On(t).getUTCDay()),e,2)}function po(t){return t.getUTCDay()}function go(t,e){return ci(Dn.count(On(t)-1,t),e,2)}function mo(t,e){return ci(t.getUTCFullYear()%100,e,2)}function yo(t,e){return ci((t=fo(t)).getUTCFullYear()%100,e,2)}function vo(t,e){return ci(t.getUTCFullYear()%1e4,e,4)}function _o(t,e){var n=t.getUTCDay();return ci((t=n>=4||0===n?Sn(t):Sn.ceil(t)).getUTCFullYear()%1e4,e,4)}function xo(){return"+0000"}function bo(){return"%"}function wo(t){return+t}function ko(t){return Math.floor(+t/1e3)}function Ao(t){const e={};return n=>e[n]||(e[n]=t(n))}function Mo(t){const e=Ao(t.format),n=t.formatPrefix;return{format:e,formatPrefix:n,formatFloat(t){const 
n=Re(t||",");if(null==n.precision){switch(n.precision=12,n.type){case"%":n.precision-=2;break;case"e":n.precision-=1}return r=e(n),i=e(".1f")(1)[1],t=>{const e=r(t),n=e.indexOf(i);if(n<0)return e;let o=function(t,e){let n,r=t.lastIndexOf("e");if(r>0)return r;for(r=t.length;--r>e;)if(n=t.charCodeAt(r),n>=48&&n<=57)return r+1}(e,n);const a=on;)if("0"!==e[o]){++o;break}return e.slice(0,o)+a}}return e(n);var r,i},formatSpan(t,r,i,o){o=Re(null==o?",f":o);const a=be(t,r,i),s=Math.max(Math.abs(t),Math.abs(r));let u;if(null==o.precision)switch(o.type){case"s":return isNaN(u=Xe(a,s))||(o.precision=u),n(o,s);case"":case"e":case"g":case"p":case"r":isNaN(u=Je(a,s))||(o.precision=u-("e"===o.type));break;case"f":case"%":isNaN(u=Ve(a))||(o.precision=u-2*("%"===o.type))}return e(o)}}}let Eo,Do;function Co(){return Eo=Mo({format:Ie,formatPrefix:We})}function Fo(t){return Mo(Ge(t))}function So(t){return arguments.length?Eo=Fo(t):Eo}function $o(t,e,n){A(n=n||{})||s(`Invalid time multi-format specifier: ${n}`);const r=e(Zn),i=e(Jn),o=e(Xn),a=e(Yn),u=e(Hn),l=e(Wn),c=e(In),f=e(jn),h=t(n[Qn]||".%L"),d=t(n[Zn]||":%S"),p=t(n[Jn]||"%I:%M"),g=t(n[Xn]||"%I %p"),m=t(n[Yn]||n[Gn]||"%a %d"),y=t(n[Hn]||"%b %d"),v=t(n[Wn]||"%B"),_=t(n[In]||"%B"),x=t(n[jn]||"%Y");return t=>(r(t)xt(t)?e(t):$o(e,Cr,t),utcFormat:t=>xt(t)?n(t):$o(n,Fr,t),timeParse:Ao(t.parse),utcParse:Ao(t.utcParse)}}function Bo(){return Do=To({format:ni,parse:ri,utcFormat:ii,utcParse:oi})}function zo(t){return To(ti(t))}function No(t){return arguments.length?Do=zo(t):Do}!function(t){ei=ti(t),ni=ei.format,ri=ei.parse,ii=ei.utcFormat,oi=ei.utcParse}({dateTime:"%x, %X",date:"%-m/%-d/%Y",time:"%-I:%M:%S 
%p",periods:["AM","PM"],days:["Sunday","Monday","Tuesday","Wednesday","Thursday","Friday","Saturday"],shortDays:["Sun","Mon","Tue","Wed","Thu","Fri","Sat"],months:["January","February","March","April","May","June","July","August","September","October","November","December"],shortMonths:["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]}),Co(),Bo();const Oo=(t,e)=>ot({},t,e);function Ro(t,e){const n=t?Fo(t):So(),r=e?zo(e):No();return Oo(n,r)}function Uo(t,e){const n=arguments.length;return n&&2!==n&&s("defaultLocale expects either zero or two arguments."),n?Oo(So(t),No(e)):Oo(So(),No())}const Lo=/^(data:|([A-Za-z]+:)?\/\/)/,qo=/^(?:(?:(?:f|ht)tps?|mailto|tel|callto|cid|xmpp|file|data):|[^a-z]|[a-z+.\-]+(?:[^a-z+.\-:]|$))/i,Po=/[\u0000-\u0020\u00A0\u1680\u180E\u2000-\u2029\u205f\u3000]/g,jo="file://";async function Io(t,e){const n=await this.sanitize(t,e),r=n.href;return n.localFile?this.file(r):this.http(r,e)}async function Wo(t,e){e=ot({},this.options,e);const n=this.fileAccess,r={href:null};let i,o,a;const u=qo.test(t.replace(Po,""));null!=t&&"string"==typeof t&&u||s("Sanitize failure, invalid URI: "+Ct(t));const l=Lo.test(t);return(a=e.baseURL)&&!l&&(t.startsWith("/")||a.endsWith("/")||(t="/"+t),t=a+t),o=(i=t.startsWith(jo))||"file"===e.mode||"http"!==e.mode&&!l&&n,i?t=t.slice(jo.length):t.startsWith("//")&&("file"===e.defaultProtocol?(t=t.slice(2),o=!0):t=(e.defaultProtocol||"http")+":"+t),Object.defineProperty(r,"localFile",{value:!!o}),r.href=t,e.target&&(r.target=e.target+""),e.rel&&(r.rel=e.rel+""),"image"===e.context&&e.crossOrigin&&(r.crossOrigin=e.crossOrigin+""),r}function Ho(t){return t?e=>new Promise(((n,r)=>{t.readFile(e,((t,e)=>{t?r(t):n(e)}))})):Yo}async function Yo(){s("No file system access.")}function Go(t){return t?async function(e,n){const r=ot({},this.options.http,n),i=n&&n.response,o=await t(e,r);return o.ok?J(o[i])?o[i]():o.text():s(o.status+""+o.statusText)}:Vo}async function Vo(){s("No HTTP fetch method 
available.")}const Xo=t=>null!=t&&t==t,Jo=t=>!(Number.isNaN(+t)||t instanceof Date),Zo={boolean:Ft,integer:S,number:S,date:$t,string:Tt,unknown:f},Qo=[t=>"true"===t||"false"===t||!0===t||!1===t,t=>Jo(t)&&Number.isInteger(+t),Jo,t=>!Number.isNaN(Date.parse(t))],Ko=["boolean","integer","number","date"];function ta(t,e){if(!t||!t.length)return"unknown";const n=t.length,r=Qo.length,i=Qo.map(((t,e)=>e+1));for(let o,a,s=0,u=0;s0===t?e:t),0)-1]}function ea(t,e){return e.reduce(((e,n)=>(e[n]=ta(t,n),e)),{})}function na(t){const e=function(e,n){const r={delimiter:t};return ra(e,n?ot(n,r):r)};return e.responseType="text",e}function ra(t,e){return e.header&&(t=e.header.map(Ct).join(e.delimiter)+"\n"+t),Ht(e.delimiter).parse(t+"")}function ia(t,e){const n=e&&e.property?l(e.property):f;return!A(t)||(r=t,"function"==typeof Buffer&&J(Buffer.isBuffer)&&Buffer.isBuffer(r))?n(JSON.parse(t)):function(t,e){!k(t)&&yt(t)&&(t=[...t]);return e&&e.copy?JSON.parse(JSON.stringify(t)):t}(n(t),e);var r}ra.responseType="text",ia.responseType="json";const oa={interior:(t,e)=>t!==e,exterior:(t,e)=>t===e};function aa(t,e){let n,r,i,o;return t=ia(t,e),e&&e.feature?(n=Gt,i=e.feature):e&&e.mesh?(n=Zt,i=e.mesh,o=oa[e.filter]):s("Missing TopoJSON feature or mesh parameter."),r=(r=t.objects[i])?n(t,r,o):s("Invalid TopoJSON object: "+i),r&&r.features||[r]}aa.responseType="json";const sa={dsv:ra,csv:na(","),tsv:na("\t"),json:ia,topojson:aa};function ua(t,e){return arguments.length>1?(sa[t]=e,this):lt(sa,t)?sa[t]:null}function la(t){const e=ua(t);return e&&e.responseType||"text"}function ca(t,e,n,r){const i=ua((e=e||{}).type||"json");return i||s("Unknown data format type: "+e.type),t=i(t,e),e.parse&&function(t,e,n,r){if(!t.length)return;const i=No();n=n||i.timeParse,r=r||i.utcParse;let o,a,s,u,l,c,f=t.columns||Object.keys(t[0]);"auto"===e&&(e=ea(t,f));f=Object.keys(e);const h=f.map((t=>{const i=e[t];let 
o,a;if(i&&(i.startsWith("date:")||i.startsWith("utc:"))){o=i.split(/:(.+)?/,2),a=o[1],("'"===a[0]&&"'"===a[a.length-1]||'"'===a[0]&&'"'===a[a.length-1])&&(a=a.slice(1,-1));return("utc"===o[0]?r:n)(a)}if(!Zo[i])throw Error("Illegal format pattern: "+t+":"+i);return Zo[i]}));for(s=0,l=t.length,c=f.length;s({options:n||{},sanitize:Wo,load:Io,fileAccess:!!e,file:Ho(e),http:Go(t)})}("undefined"!=typeof fetch&&fetch,null);function ha(t){const e=t||f,n=[],r={};return n.add=t=>{const i=e(t);return r[i]||(r[i]=1,n.push(t)),n},n.remove=t=>{const i=e(t);if(r[i]){r[i]=0;const e=n.indexOf(t);e>=0&&n.splice(e,1)}return n},n}async function da(t,e){try{await e(t)}catch(e){t.error(e)}}const pa=Symbol("vega_id");let ga=1;function ma(t){return!(!t||!ya(t))}function ya(t){return t[pa]}function va(t,e){return t[pa]=e,t}function _a(t){const e=t===Object(t)?t:{data:t};return ya(e)?e:va(e,ga++)}function xa(t){return ba(t,_a({}))}function ba(t,e){for(const n in t)e[n]=t[n];return e}function wa(t,e){return va(e,ya(t))}function ka(t,e){return t?e?(n,r)=>t(n,r)||ya(e(n))-ya(e(r)):(e,n)=>t(e,n)||ya(e)-ya(n):null}function Aa(t){return t&&t.constructor===Ma}function Ma(){const t=[],e=[],n=[],r=[],i=[];let o=null,a=!1;return{constructor:Ma,insert(e){const n=V(e),r=n.length;for(let e=0;e{p(t)&&(l[ya(t)]=-1)}));for(f=0,h=t.length;f0&&(y(g,p,d.value),s.modifies(p));for(f=0,h=i.length;f{p(t)&&l[ya(t)]>0&&y(t,d.field,d.value)})),s.modifies(d.field);if(a)s.mod=e.length||r.length?u.filter((t=>l[ya(t)]>0)):u.slice();else for(m in c)s.mod.push(c[m]);return(o||null==o&&(e.length||r.length))&&s.clean(!0),s}}}const Ea="_:mod:_";function Da(){Object.defineProperty(this,Ea,{writable:!0,value:{}})}Da.prototype={set(t,e,n,r){const i=this,o=i[t],a=i[Ea];return null!=e&&e>=0?(o[e]!==n||r)&&(o[e]=n,a[e+":"+t]=-1,a[t]=-1):(o!==n||r)&&(i[t]=n,a[t]=k(n)?1+n.length:-1),i},modified(t,e){const n=this[Ea];if(!arguments.length){for(const t in n)if(n[t])return!0;return!1}if(k(t)){for(let e=0;e=0?e+1{a instanceof 
Sa?(a!==this&&(e&&a.targets().add(this),o.push(a)),i.push({op:a,name:t,index:n})):r.set(t,n,a)};for(a in t)if(u=t[a],"pulse"===a)V(u).forEach((t=>{t instanceof Sa?t!==this&&(t.targets().add(this),o.push(t)):s("Pulse parameters must be operator instances.")})),this.source=u;else if(k(u))for(r.set(a,-1,Array(l=u.length)),c=0;c{const n=Date.now();return n-e>t?(e=n,1):0}))},debounce(t){const e=za();return this.targets().add(za(null,null,it(t,(t=>{const n=t.dataflow;e.receive(t),n&&n.run&&n.run()})))),e},between(t,e){let n=!1;return t.targets().add(za(null,null,(()=>n=!0))),e.targets().add(za(null,null,(()=>n=!1))),this.filter((()=>n))},detach(){this._filter=p,this._targets=null}};const Na={skip:!0};function Oa(t,e,n,r,i,o){const a=ot({},o,Na);let s,u;J(n)||(n=rt(n)),void 0===r?s=e=>t.touch(n(e)):J(r)?(u=new Sa(null,r,i,!1),s=e=>{u.evaluate(e);const r=n(e),i=u.value;Aa(i)?t.pulse(r,i,o):t.update(r,i,a)}):s=e=>t.update(n(e),r,a),e.apply(s)}function Ra(t,e,n,r,i,o){if(void 0===r)e.targets().add(n);else{const a=o||{},s=new Sa(null,function(t,e){return e=J(e)?e:rt(e),t?function(n,r){const i=e(n,r);return t.skip()||(t.skip(i!==this.value).value=i),i}:e}(n,r),i,!1);s.modified(a.force),s.rank=e.rank,e.targets().add(s),n&&(s.skip(!0),s.value=n.value,s.targets().add(n),t.connect(n,[s]))}}const Ua={};function La(t,e,n){this.dataflow=t,this.stamp=null==e?-1:e,this.add=[],this.rem=[],this.mod=[],this.fields=null,this.encode=n||null}function qa(t,e){const n=[];return Nt(t,e,(t=>n.push(t))),n}function Pa(t,e){const n={};return t.visit(e,(t=>{n[ya(t)]=1})),t=>n[ya(t)]?null:t}function ja(t,e){return t?(n,r)=>t(n,r)&&e(n,r):e}function Ia(t,e,n,r){const i=this;let o=0;this.dataflow=t,this.stamp=e,this.fields=null,this.encode=r||null,this.pulses=n;for(const t of n)if(t.stamp===e){if(t.fields){const e=i.fields||(i.fields={});for(const n in t.fields)e[n]=1}t.changed(i.ADD)&&(o|=i.ADD),t.changed(i.REM)&&(o|=i.REM),t.changed(i.MOD)&&(o|=i.MOD)}this.changes=o}function Wa(t){return 
t.error("Dataflow already running. Use runAsync() to chain invocations."),t}La.prototype={StopPropagation:Ua,ADD:1,REM:2,MOD:4,ADD_REM:3,ADD_MOD:5,ALL:7,REFLOW:8,SOURCE:16,NO_SOURCE:32,NO_FIELDS:64,fork(t){return new La(this.dataflow).init(this,t)},clone(){const t=this.fork(7);return t.add=t.add.slice(),t.rem=t.rem.slice(),t.mod=t.mod.slice(),t.source&&(t.source=t.source.slice()),t.materialize(23)},addAll(){let t=this;return!t.source||t.add===t.rem||!t.rem.length&&t.source.length===t.add.length||(t=new La(this.dataflow).init(this),t.add=t.source,t.rem=[]),t},init(t,e){const n=this;return n.stamp=t.stamp,n.encode=t.encode,!t.fields||64&e||(n.fields=t.fields),1&e?(n.addF=t.addF,n.add=t.add):(n.addF=null,n.add=[]),2&e?(n.remF=t.remF,n.rem=t.rem):(n.remF=null,n.rem=[]),4&e?(n.modF=t.modF,n.mod=t.mod):(n.modF=null,n.mod=[]),32&e?(n.srcF=null,n.source=null):(n.srcF=t.srcF,n.source=t.source,t.cleans&&(n.cleans=t.cleans)),n},runAfter(t){this.dataflow.runAfter(t)},changed(t){const e=t||7;return 1&e&&this.add.length||2&e&&this.rem.length||4&e&&this.mod.length},reflow(t){if(t)return this.fork(7).reflow();const e=this.add.length,n=this.source&&this.source.length;return n&&n!==e&&(this.mod=this.source,e&&this.filter(4,Pa(this,1))),this},clean(t){return arguments.length?(this.cleans=!!t,this):this.cleans},modifies(t){const e=this.fields||(this.fields={});return k(t)?t.forEach((t=>e[t]=!0)):e[t]=!0,this},modified(t,e){const n=this.fields;return!(!e&&!this.mod.length||!n)&&(arguments.length?k(t)?t.some((t=>n[t])):n[t]:!!n)},filter(t,e){const n=this;return 1&t&&(n.addF=ja(n.addF,e)),2&t&&(n.remF=ja(n.remF,e)),4&t&&(n.modF=ja(n.modF,e)),16&t&&(n.srcF=ja(n.srcF,e)),n},materialize(t){const e=this;return 1&(t=t||7)&&e.addF&&(e.add=qa(e.add,e.addF),e.addF=null),2&t&&e.remF&&(e.rem=qa(e.rem,e.remF),e.remF=null),4&t&&e.modF&&(e.mod=qa(e.mod,e.modF),e.modF=null),16&t&&e.srcF&&(e.source=e.source.filter(e.srcF),e.srcF=null),e},visit(t,e){const n=this,r=e;if(16&t)return 
Nt(n.source,n.srcF,r),n;1&t&&Nt(n.add,n.addF,r),2&t&&Nt(n.rem,n.remF,r),4&t&&Nt(n.mod,n.modF,r);const i=n.source;if(8&t&&i){const t=n.add.length+n.mod.length;t===i.length||Nt(i,t?Pa(n,5):n.srcF,r)}return n}},dt(Ia,La,{fork(t){const e=new La(this.dataflow).init(this,t&this.NO_FIELDS);return void 0!==t&&(t&e.ADD&&this.visit(e.ADD,(t=>e.add.push(t))),t&e.REM&&this.visit(e.REM,(t=>e.rem.push(t))),t&e.MOD&&this.visit(e.MOD,(t=>e.mod.push(t)))),e},changed(t){return this.changes&t},modified(t){const e=this,n=e.fields;return n&&e.changes&e.MOD?k(t)?t.some((t=>n[t])):n[t]:0},filter(){s("MultiPulse does not support filtering.")},materialize(){s("MultiPulse does not support materialization.")},visit(t,e){const n=this,r=n.pulses,i=r.length;let o=0;if(t&n.SOURCE)for(;oe=[],size:()=>e.length,peek:()=>e[0],push:n=>(e.push(n),Ga(e,0,e.length-1,t)),pop:()=>{const n=e.pop();let r;return e.length?(r=e[0],e[0]=n,function(t,e,n){const r=e,i=t.length,o=t[e];let a,s=1+(e<<1);for(;s=0&&(s=a),t[e]=t[s],s=1+((e=s)<<1);t[e]=o,Ga(t,r,e,n)}(e,0,t)):r=n,r}}}function Ga(t,e,n,r){let i,o;const a=t[n];for(;n>e&&(o=n-1>>1,i=t[o],r(a,i)<0);)t[n]=i,n=o;return t[n]=a}function Va(){this.logger(w()),this.logLevel(v),this._clock=0,this._rank=0,this._locale=Uo();try{this._loader=fa()}catch(t){}this._touched=ha(c),this._input={},this._pulse=null,this._heap=Ya(((t,e)=>t.qrank-e.qrank)),this._postrun=[]}function Xa(t){return function(){return this._log[t].apply(this,arguments)}}function Ja(t,e){Sa.call(this,t,null,e)}Va.prototype={stamp(){return this._clock},loader(t){return arguments.length?(this._loader=t,this):this._loader},locale(t){return arguments.length?(this._locale=t,this):this._locale},logger(t){return arguments.length?(this._log=t,this):this._log},error:Xa("error"),warn:Xa("warn"),info:Xa("info"),debug:Xa("debug"),logLevel:Xa("level"),cleanThreshold:1e4,add:function(t,e,n,r){let i,o=1;return t instanceof Sa?i=t:t&&t.prototype instanceof Sa?i=new t:J(t)?i=new Sa(null,t):(o=0,i=new 
Sa(t,e)),this.rank(i),o&&(r=n,n=e),n&&this.connect(i,i.parameters(n,r)),this.touch(i),i},connect:function(t,e){const n=t.rank,r=e.length;for(let i=0;i=0;)e.push(n=r[i]),n===t&&s("Cycle detected in dataflow graph.")},pulse:function(t,e,n){this.touch(t,n||Ha);const r=new La(this,this._clock+(this._pulse?0:1)),i=t.pulse&&t.pulse.source||[];return r.target=t,this._input[t.id]=e.pulse(r,i),this},touch:function(t,e){const n=e||Ha;return this._pulse?this._enqueue(t):this._touched.add(t),n.skip&&t.skip(!0),this},update:function(t,e,n){const r=n||Ha;return(t.set(e)||r.force)&&this.touch(t,r),this},changeset:Ma,ingest:function(t,e,n){return e=this.parse(e,n),this.pulse(t,this.changeset().insert(e))},parse:function(t,e){const n=this.locale();return ca(t,e,n.timeParse,n.utcParse)},preload:async function(t,e,n){const r=this,i=r._pending||function(t){let e;const n=new Promise((t=>e=t));return n.requests=0,n.done=()=>{0==--n.requests&&(t._pending=null,e(t))},t._pending=n}(r);i.requests+=1;const o=await r.request(e,n);return r.pulse(t,r.changeset().remove(p).insert(o.data||[])),i.done(),o},request:async function(t,e){const n=this;let r,i=0;try{r=await n.loader().load(t,{context:"dataflow",response:la(e&&e.type)});try{r=n.parse(r,e)}catch(e){i=-2,n.warn("Data ingestion failed",t,e)}}catch(e){i=-1,n.warn("Loading failed",t,e)}return{data:r,status:i}},events:function(t,e,n,r){const i=this,o=za(n,r),a=function(t){t.dataflow=i;try{o.receive(t)}catch(t){i.error(t)}finally{i.run()}};let s;s="string"==typeof t&&"undefined"!=typeof document?document.querySelectorAll(t):V(t);const u=s.length;for(let t=0;tr._enqueue(t,!0))),r._touched=ha(c);let a,s,u,l=0;try{for(;r._heap.size()>0;)a=r._heap.pop(),a.rank===a.qrank?(s=a.run(r._getPulse(a,t)),s.then?s=await s:s.async&&(i.push(s.async),s=Ua),s!==Ua&&a._targets&&a._targets.forEach((t=>r._enqueue(t))),++l):r._enqueue(a,!0)}catch(t){r._heap.clear(),u=t}if(r._input={},r._pulse=null,r.debug(`Pulse ${o}: ${l} 
operators`),u&&(r._postrun=[],r.error(u)),r._postrun.length){const t=r._postrun.sort(((t,e)=>e.priority-t.priority));r._postrun=[];for(let e=0;er.runAsync(null,(()=>{t.forEach((t=>{try{t(r)}catch(t){r.error(t)}}))})))),r},run:function(t,e,n){return this._pulse?Wa(this):(this.evaluate(t,e,n),this)},runAsync:async function(t,e,n){for(;this._running;)await this._running;const r=()=>this._running=null;return(this._running=this.evaluate(t,e,n)).then(r,r),this._running},runAfter:function(t,e,n){if(this._pulse||e)this._postrun.push({priority:n||0,callback:t});else try{t(this)}catch(t){this.error(t)}},_enqueue:function(t,e){const n=t.stampt.pulse)),e):this._input[t.id]||function(t,e){if(e&&e.stamp===t.stamp)return e;t=t.fork(),e&&e!==Ua&&(t.source=e.source);return t}(this._pulse,n&&n.pulse)}},dt(Ja,Sa,{run(t){if(t.stampthis.pulse=t)):e!==t.StopPropagation&&(this.pulse=e),e},evaluate(t){const e=this.marshall(t.stamp),n=this.transform(e,t);return e.clear(),n},transform(){}});const Za={};function Qa(t){const e=Ka(t);return e&&e.Definition||null}function Ka(t){return t=t&&t.toLowerCase(),lt(Za,t)?Za[t]:null}function*ts(t,e){if(null==e)for(let e of t)null!=e&&""!==e&&(e=+e)>=e&&(yield e);else{let n=-1;for(let r of t)r=e(r,++n,t),null!=r&&""!==r&&(r=+r)>=r&&(yield r)}}function es(t,e,n){const r=Float64Array.from(ts(t,n));return r.sort(Kt),e.map((t=>De(r,t)))}function ns(t,e){return es(t,[.25,.5,.75],e)}function rs(t,e){const n=t.length,r=function(t,e){const n=function(t,e){let n,r=0,i=0,o=0;if(void 0===e)for(let e of t)null!=e&&(e=+e)>=e&&(n=e-i,i+=n/++r,o+=n*(e-i));else{let a=-1;for(let s of t)null!=(s=e(s,++a,t))&&(s=+s)>=s&&(n=s-i,i+=n/++r,o+=n*(s-i))}if(r>1)return o/(r-1)}(t,e);return n?Math.sqrt(n):n}(t,e),i=ns(t,e),o=(i[2]-i[0])/1.34;return 1.06*(Math.min(r,o)||r||Math.abs(i[0])||1)*Math.pow(n,-.2)}function is(t){const e=t.maxbins||20,n=t.base||10,r=Math.log(n),i=t.divide||[5,2];let o,a,s,u,l,c,f=t.extent[0],h=t.extent[1];const 
d=t.span||h-f||Math.abs(f)||1;if(t.step)o=t.step;else if(t.steps){for(u=d/e,l=0,c=t.steps.length;le;)o*=n;for(l=0,c=i.length;l=s&&d/u<=e&&(o=u)}u=Math.log(o);const p=u>=0?0:1+~~(-u/r),g=Math.pow(n,-p-1);return(t.nice||void 0===t.nice)&&(u=Math.floor(f/o+g)*o,f=ft);const i=t.length,o=new Float64Array(i);let a,s=0,u=1,l=r(t[0]),c=l,f=l+e;for(;u=f){for(c=(l+c)/2;s>1);ia;)t[i--]=t[o]}o=a,a=r}return t}(o,e+e/4):o}t.random=Math.random;const ss=Math.sqrt(2*Math.PI),us=Math.SQRT2;let ls=NaN;function cs(e,n){e=e||0,n=null==n?1:n;let r,i,o=0,a=0;if(ls==ls)o=ls,ls=NaN;else{do{o=2*t.random()-1,a=2*t.random()-1,r=o*o+a*a}while(0===r||r>1);i=Math.sqrt(-2*Math.log(r)/r),o*=i,ls=a*i}return e+o*n}function fs(t,e,n){const r=(t-(e||0))/(n=null==n?1:n);return Math.exp(-.5*r*r)/(n*ss)}function hs(t,e,n){const r=(t-(e=e||0))/(n=null==n?1:n),i=Math.abs(r);let o;if(i>37)o=0;else{const t=Math.exp(-i*i/2);let e;i<7.07106781186547?(e=.0352624965998911*i+.700383064443688,e=e*i+6.37396220353165,e=e*i+33.912866078383,e=e*i+112.079291497871,e=e*i+221.213596169931,e=e*i+220.206867912376,o=t*e,e=.0883883476483184*i+1.75566716318264,e=e*i+16.064177579207,e=e*i+86.7807322029461,e=e*i+296.564248779674,e=e*i+637.333633378831,e=e*i+793.826512519948,e=e*i+440.413735824752,o/=e):(e=i+.65,e=i+4/e,e=i+3/e,e=i+2/e,e=i+1/e,o=t/e/2.506628274631)}return r>0?1-o:o}function ds(t,e,n){return t<0||t>1?NaN:(e||0)+(null==n?1:n)*us*function(t){let 
e,n=-Math.log((1-t)*(1+t));n<6.25?(n-=3.125,e=-364441206401782e-35,e=e*n-16850591381820166e-35,e=128584807152564e-32+e*n,e=11157877678025181e-33+e*n,e=e*n-1333171662854621e-31,e=20972767875968562e-33+e*n,e=6637638134358324e-30+e*n,e=e*n-4054566272975207e-29,e=e*n-8151934197605472e-29,e=26335093153082323e-28+e*n,e=e*n-12975133253453532e-27,e=e*n-5415412054294628e-26,e=1.0512122733215323e-9+e*n,e=e*n-4.112633980346984e-9,e=e*n-2.9070369957882005e-8,e=4.2347877827932404e-7+e*n,e=e*n-13654692000834679e-22,e=e*n-13882523362786469e-21,e=.00018673420803405714+e*n,e=e*n-.000740702534166267,e=e*n-.006033670871430149,e=.24015818242558962+e*n,e=1.6536545626831027+e*n):n<16?(n=Math.sqrt(n)-3.25,e=2.2137376921775787e-9,e=9.075656193888539e-8+e*n,e=e*n-2.7517406297064545e-7,e=1.8239629214389228e-8+e*n,e=15027403968909828e-22+e*n,e=e*n-4013867526981546e-21,e=29234449089955446e-22+e*n,e=12475304481671779e-21+e*n,e=e*n-47318229009055734e-21,e=6828485145957318e-20+e*n,e=24031110387097894e-21+e*n,e=e*n-.0003550375203628475,e=.0009532893797373805+e*n,e=e*n-.0016882755560235047,e=.002491442096107851+e*n,e=e*n-.003751208507569241,e=.005370914553590064+e*n,e=1.0052589676941592+e*n,e=3.0838856104922208+e*n):Number.isFinite(n)?(n=Math.sqrt(n)-5,e=-27109920616438573e-27,e=e*n-2.555641816996525e-10,e=1.5076572693500548e-9+e*n,e=e*n-3.789465440126737e-9,e=7.61570120807834e-9+e*n,e=e*n-1.496002662714924e-8,e=2.914795345090108e-8+e*n,e=e*n-6.771199775845234e-8,e=2.2900482228026655e-7+e*n,e=e*n-9.9298272942317e-7,e=4526062597223154e-21+e*n,e=e*n-1968177810553167e-20,e=7599527703001776e-20+e*n,e=e*n-.00021503011930044477,e=e*n-.00013871931833623122,e=1.0103004648645344+e*n,e=4.849906401408584+e*n):e=1/0;return e*t}(2*t-1)}function ps(t,e){let n,r;const i={mean(t){return arguments.length?(n=t||0,i):n},stdev(t){return arguments.length?(r=null==t?1:t,i):r},sample:()=>cs(n,r),pdf:t=>fs(t,n,r),cdf:t=>hs(t,n,r),icdf:t=>ds(t,n,r)};return i.mean(t).stdev(e)}function gs(e,n){const r=ps();let i=0;const 
o={data(t){return arguments.length?(e=t,i=t?t.length:0,o.bandwidth(n)):e},bandwidth(t){return arguments.length?(!(n=t)&&e&&(n=rs(e)),o):n},sample:()=>e[~~(t.random()*i)]+n*r.sample(),pdf(t){let o=0,a=0;for(;ams(n,r),pdf:t=>ys(t,n,r),cdf:t=>vs(t,n,r),icdf:t=>_s(t,n,r)};return i.mean(t).stdev(e)}function bs(e,n){let r,i=0;const o={weights(t){return arguments.length?(r=function(t){const e=[];let n,r=0;for(n=0;n=e&&t<=n?1/(n-e):0}function As(t,e,n){return null==n&&(n=null==e?1:e,e=0),tn?1:(t-e)/(n-e)}function Ms(t,e,n){return null==n&&(n=null==e?1:e,e=0),t>=0&&t<=1?e+t*(n-e):NaN}function Es(t,e){let n,r;const i={min(t){return arguments.length?(n=t||0,i):n},max(t){return arguments.length?(r=null==t?1:t,i):r},sample:()=>ws(n,r),pdf:t=>ks(t,n,r),cdf:t=>As(t,n,r),icdf:t=>Ms(t,n,r)};return null==e&&(e=null==t?1:t,t=0),i.min(t).max(e)}function Ds(t,e,n){let r=0,i=0;for(const o of t){const t=n(o);null==e(o)||null==t||isNaN(t)||(r+=(t-r)/++i)}return{coef:[r],predict:()=>r,rSquared:0}}function Cs(t,e,n,r){const i=r-t*t,o=Math.abs(i)<1e-24?0:(n-t*e)/i;return[e-o*t,o]}function Fs(t,e,n,r){t=t.filter((t=>{let r=e(t),i=n(t);return null!=r&&(r=+r)>=r&&null!=i&&(i=+i)>=i})),r&&t.sort(((t,n)=>e(t)-e(n)));const i=t.length,o=new Float64Array(i),a=new Float64Array(i);let s,u,l,c=0,f=0,h=0;for(l of t)o[c]=s=+e(l),a[c]=u=+n(l),++c,f+=(s-f)/c,h+=(u-h)/c;for(c=0;c=i&&null!=o&&(o=+o)>=o&&r(i,o,++a)}function $s(t,e,n,r,i){let o=0,a=0;return Ss(t,e,n,((t,e)=>{const n=e-i(t),s=e-r;o+=n*n,a+=s*s})),1-o/a}function Ts(t,e,n){let r=0,i=0,o=0,a=0,s=0;Ss(t,e,n,((t,e)=>{++s,r+=(t-r)/s,i+=(e-i)/s,o+=(t*e-o)/s,a+=(t*t-a)/s}));const u=Cs(r,i,o,a),l=t=>u[0]+u[1]*t;return{coef:u,predict:l,rSquared:$s(t,e,n,i,l)}}function Bs(t,e,n){let r=0,i=0,o=0,a=0,s=0;Ss(t,e,n,((t,e)=>{++s,t=Math.log(t),r+=(t-r)/s,i+=(e-i)/s,o+=(t*e-o)/s,a+=(t*t-a)/s}));const u=Cs(r,i,o,a),l=t=>u[0]+u[1]*Math.log(t);return{coef:u,predict:l,rSquared:$s(t,e,n,i,l)}}function zs(t,e,n){const[r,i,o,a]=Fs(t,e,n);let 
s,u,l,c=0,f=0,h=0,d=0,p=0;Ss(t,e,n,((t,e)=>{s=r[p++],u=Math.log(e),l=s*e,c+=(e*u-c)/p,f+=(l-f)/p,h+=(l*u-h)/p,d+=(s*l-d)/p}));const[g,m]=Cs(f/a,c/a,h/a,d/a),y=t=>Math.exp(g+m*(t-o));return{coef:[Math.exp(g-m*o),m],predict:y,rSquared:$s(t,e,n,a,y)}}function Ns(t,e,n){let r=0,i=0,o=0,a=0,s=0,u=0;Ss(t,e,n,((t,e)=>{const n=Math.log(t),l=Math.log(e);++u,r+=(n-r)/u,i+=(l-i)/u,o+=(n*l-o)/u,a+=(n*n-a)/u,s+=(e-s)/u}));const l=Cs(r,i,o,a),c=t=>l[0]*Math.pow(t,l[1]);return l[0]=Math.exp(l[0]),{coef:l,predict:c,rSquared:$s(t,e,n,s,c)}}function Os(t,e,n){const[r,i,o,a]=Fs(t,e,n),s=r.length;let u,l,c,f,h=0,d=0,p=0,g=0,m=0;for(u=0;u_*(t-=o)*t+x*t+b+a;return{coef:[b-x*o+_*o*o+a,x-2*_*o,_],predict:w,rSquared:$s(t,e,n,a,w)}}function Rs(t,e,n,r){if(0===r)return Ds(t,e,n);if(1===r)return Ts(t,e,n);if(2===r)return Os(t,e,n);const[i,o,a,s]=Fs(t,e,n),u=i.length,l=[],c=[],f=r+1;let h,d,p,g,m;for(h=0;hMath.abs(t[r][a])&&(a=i);for(o=r;o=r;o--)t[o][i]-=t[o][r]*t[r][i]/t[r][r]}for(i=e-1;i>=0;--i){for(s=0,o=i+1;o{t-=a;let e=s+y[0]+y[1]*t+y[2]*t*t;for(h=3;h=0;--o)for(s=e[o],u=1,i[o]+=s,a=1;a<=o;++a)u*=(o+1-a)/a,i[o-a]+=s*Math.pow(n,a)*u;return i[0]+=r,i}function Ls(t,e,n,r){const[i,o,a,s]=Fs(t,e,n,!0),u=i.length,l=Math.max(2,~~(r*u)),c=new Float64Array(u),f=new Float64Array(u),h=new Float64Array(u).fill(1);for(let t=-1;++t<=2;){const e=[0,l-1];for(let t=0;ti[a]-n?r:a;let u=0,l=0,d=0,p=0,g=0;const m=1/Math.abs(i[s]-n||1);for(let t=r;t<=a;++t){const e=i[t],r=o[t],a=qs(Math.abs(n-e)*m)*h[t],s=e*a;u+=a,l+=s,d+=r*a,p+=r*s,g+=e*s}const[y,v]=Cs(l/u,d/u,p/u,g/u);c[t]=y+v*n,f[t]=Math.abs(o[t]-c[t]),Ps(i,t+1,e)}if(2===t)break;const n=Ce(f);if(Math.abs(n)<1e-12)break;for(let t,e,r=0;r=1?1e-12:(e=1-t*t)*e}return function(t,e,n,r){const i=t.length,o=[];let a,s=0,u=0,l=[];for(;s=t.length))for(;e>i&&t[o]-r<=r-t[i];)n[0]=++i,n[1]=o,++o}const js=.5*Math.PI/180;function Is(t,e,n,r){n=n||25,r=Math.max(n,r||200);const i=e=>[e,t(e)],o=e[0],a=e[1],s=a-o,u=s/r,l=[i(o)],c=[];if(n===r){for(let 
t=1;t0;)c.push(i(o+t/n*s));let f=l[0],h=c[c.length-1];const d=1/s,p=function(t,e){let n=t,r=t;const i=e.length;for(let t=0;tr&&(r=i)}return 1/(r-n)}(f[1],c);for(;h;){const t=i((f[0]+h[0])/2);t[0]-f[0]>=u&&Ws(f,t,h,d,p)>js?c.push(t):(f=h,l.push(h),c.pop()),h=c[c.length-1]}return l}function Ws(t,e,n,r,i){const o=Math.atan2(i*(n[1]-t[1]),r*(n[0]-t[0])),a=Math.atan2(i*(e[1]-t[1]),r*(e[0]-t[0]));return Math.abs(o-a)}function Hs(t){return t&&t.length?1===t.length?t[0]:(e=t,t=>{const n=e.length;let r=1,i=String(e[0](t));for(;r{},Vs={init:Gs,add:Gs,rem:Gs,idx:0},Xs={values:{init:t=>t.cell.store=!0,value:t=>t.cell.data.values(),idx:-1},count:{value:t=>t.cell.num},__count__:{value:t=>t.missing+t.valid},missing:{value:t=>t.missing},valid:{value:t=>t.valid},sum:{init:t=>t.sum=0,value:t=>t.valid?t.sum:void 0,add:(t,e)=>t.sum+=+e,rem:(t,e)=>t.sum-=e},product:{init:t=>t.product=1,value:t=>t.valid?t.product:void 0,add:(t,e)=>t.product*=e,rem:(t,e)=>t.product/=e},mean:{init:t=>t.mean=0,value:t=>t.valid?t.mean:void 0,add:(t,e)=>(t.mean_d=e-t.mean,t.mean+=t.mean_d/t.valid),rem:(t,e)=>(t.mean_d=e-t.mean,t.mean-=t.valid?t.mean_d/t.valid:t.mean)},average:{value:t=>t.valid?t.mean:void 0,req:["mean"],idx:1},variance:{init:t=>t.dev=0,value:t=>t.valid>1?t.dev/(t.valid-1):void 0,add:(t,e)=>t.dev+=t.mean_d*(e-t.mean),rem:(t,e)=>t.dev-=t.mean_d*(e-t.mean),req:["mean"],idx:1},variancep:{value:t=>t.valid>1?t.dev/t.valid:void 0,req:["variance"],idx:2},stdev:{value:t=>t.valid>1?Math.sqrt(t.dev/(t.valid-1)):void 0,req:["variance"],idx:2},stdevp:{value:t=>t.valid>1?Math.sqrt(t.dev/t.valid):void 0,req:["variance"],idx:2},stderr:{value:t=>t.valid>1?Math.sqrt(t.dev/(t.valid*(t.valid-1))):void 
0,req:["variance"],idx:2},distinct:{value:t=>t.cell.data.distinct(t.get),req:["values"],idx:3},ci0:{value:t=>t.cell.data.ci0(t.get),req:["values"],idx:3},ci1:{value:t=>t.cell.data.ci1(t.get),req:["values"],idx:3},median:{value:t=>t.cell.data.q2(t.get),req:["values"],idx:3},q1:{value:t=>t.cell.data.q1(t.get),req:["values"],idx:3},q3:{value:t=>t.cell.data.q3(t.get),req:["values"],idx:3},min:{init:t=>t.min=void 0,value:t=>t.min=Number.isNaN(t.min)?t.cell.data.min(t.get):t.min,add:(t,e)=>{(e{e<=t.min&&(t.min=NaN)},req:["values"],idx:4},max:{init:t=>t.max=void 0,value:t=>t.max=Number.isNaN(t.max)?t.cell.data.max(t.get):t.max,add:(t,e)=>{(e>t.max||void 0===t.max)&&(t.max=e)},rem:(t,e)=>{e>=t.max&&(t.max=NaN)},req:["values"],idx:4},argmin:{init:t=>t.argmin=void 0,value:t=>t.argmin||t.cell.data.argmin(t.get),add:(t,e,n)=>{e{e<=t.min&&(t.argmin=void 0)},req:["min","values"],idx:3},argmax:{init:t=>t.argmax=void 0,value:t=>t.argmax||t.cell.data.argmax(t.get),add:(t,e,n)=>{e>t.max&&(t.argmax=n)},rem:(t,e)=>{e>=t.max&&(t.argmax=void 0)},req:["max","values"],idx:3},exponential:{init:(t,e)=>{t.exp=0,t.exp_r=e},value:t=>t.valid?t.exp*(1-t.exp_r)/(1-t.exp_r**t.valid):void 0,add:(t,e)=>t.exp=t.exp_r*t.exp+e,rem:(t,e)=>t.exp=(t.exp-e/t.exp_r**(t.valid-1))/t.exp_r},exponentialb:{value:t=>t.valid?t.exp*(1-t.exp_r):void 0,req:["exponential"],idx:1}},Js=Object.keys(Xs).filter((t=>"__count__"!==t));function Zs(t,e,n){return Xs[t](n,e)}function Qs(t,e){return t.idx-e.idx}function Ks(){this.valid=0,this.missing=0,this._ops.forEach((t=>null==t.aggregate_param?t.init(this):t.init(this,t.aggregate_param)))}function tu(t,e){null!=t&&""!==t?t==t&&(++this.valid,this._ops.forEach((n=>n.add(this,t,e)))):++this.missing}function eu(t,e){null!=t&&""!==t?t==t&&(--this.valid,this._ops.forEach((n=>n.rem(this,t,e)))):--this.missing}function nu(t){return this._out.forEach((e=>t[e.out]=e.value(this))),t}function ru(t,e){const n=e||f,r=function(t){const e={};t.forEach((t=>e[t.name]=t));const 
n=t=>{t.req&&t.req.forEach((t=>{e[t]||n(e[t]=Xs[t]())}))};return t.forEach(n),Object.values(e).sort(Qs)}(t),i=t.slice().sort(Qs);function o(t){this._ops=r,this._out=i,this.cell=t,this.init()}return o.prototype.init=Ks,o.prototype.add=tu,o.prototype.rem=eu,o.prototype.set=nu,o.prototype.get=n,o.fields=t.map((t=>t.out)),o}function iu(t){this._key=t?l(t):ya,this.reset()}[...Js,"__count__"].forEach((t=>{Xs[t]=function(t,e){return(n,r)=>ot({name:t,aggregate_param:r,out:n||t},Vs,e)}(t,Xs[t])}));const ou=iu.prototype;function au(t){Ja.call(this,null,t),this._adds=[],this._mods=[],this._alen=0,this._mlen=0,this._drop=!0,this._cross=!1,this._dims=[],this._dnames=[],this._measures=[],this._countOnly=!1,this._counts=null,this._prev=null,this._inputs=null,this._outputs=null}ou.reset=function(){this._add=[],this._rem=[],this._ext=null,this._get=null,this._q=null},ou.add=function(t){this._add.push(t)},ou.rem=function(t){this._rem.push(t)},ou.values=function(){if(this._get=null,0===this._rem.length)return this._add;const t=this._add,e=this._rem,n=this._key,r=t.length,i=e.length,o=Array(r-i),a={};let s,u,l;for(s=0;s=0;)r=t(e[i])+"",lt(n,r)||(n[r]=1,++o);return o},ou.extent=function(t){if(this._get!==t||!this._ext){const e=this.values(),n=st(e,t);this._ext=[e[n[0]],e[n[1]]],this._get=t}return this._ext},ou.argmin=function(t){return this.extent(t)[0]||{}},ou.argmax=function(t){return this.extent(t)[1]||{}},ou.min=function(t){const e=this.extent(t)[0];return null!=e?t(e):void 0},ou.max=function(t){const e=this.extent(t)[1];return null!=e?t(e):void 0},ou.quartile=function(t){return this._get===t&&this._q||(this._q=ns(this.values(),t),this._get=t),this._q},ou.q1=function(t){return this.quartile(t)[0]},ou.q2=function(t){return this.quartile(t)[1]},ou.q3=function(t){return this.quartile(t)[2]},ou.ci=function(t){return this._get===t&&this._ci||(this._ci=os(this.values(),1e3,.05,t),this._get=t),this._ci},ou.ci0=function(t){return this.ci(t)[0]},ou.ci1=function(t){return 
this.ci(t)[1]},au.Definition={type:"Aggregate",metadata:{generates:!0,changes:!0},params:[{name:"groupby",type:"field",array:!0},{name:"ops",type:"enum",array:!0,values:Js},{name:"aggregate_params",type:"number",null:!0,array:!0},{name:"fields",type:"field",null:!0,array:!0},{name:"as",type:"string",null:!0,array:!0},{name:"drop",type:"boolean",default:!0},{name:"cross",type:"boolean",default:!1},{name:"key",type:"field"}]},dt(au,Ja,{transform(t,e){const n=this,r=e.fork(e.NO_SOURCE|e.NO_FIELDS),i=t.modified();return n.stamp=r.stamp,n.value&&(i||e.modified(n._inputs,!0))?(n._prev=n.value,n.value=i?n.init(t):Object.create(null),e.visit(e.SOURCE,(t=>n.add(t)))):(n.value=n.value||n.init(t),e.visit(e.REM,(t=>n.rem(t))),e.visit(e.ADD,(t=>n.add(t)))),r.modifies(n._outputs),n._drop=!1!==t.drop,t.cross&&n._dims.length>1&&(n._drop=!1,n.cross()),e.clean()&&n._drop&&r.clean(!0).runAfter((()=>this.clean())),n.changes(r)},cross(){const t=this,e=t.value,n=t._dnames,r=n.map((()=>({}))),i=n.length;function o(t){let e,o,a,s;for(e in t)for(a=t[e].tuple,o=0;o{const e=n(t);return a(t),i.push(e),e})),this.cellkey=t.key?t.key:Hs(this._dims),this._countOnly=!0,this._counts=[],this._measures=[];const u=t.fields||[null],l=t.ops||["count"],c=t.aggregate_params||[null],f=t.as||[],h=u.length,d={};let p,g,m,y,v,_,x;for(h!==l.length&&s("Unmatched number of fields and aggregate ops."),x=0;xru(t,t.field))),Object.create(null)},cellkey:Hs(),cell(t,e){let n=this.value[t];return n?0===n.num&&this._drop&&n.stampo.push(t),remove:t=>a[r(t)]=++s,size:()=>i.length,data:(t,e)=>(s&&(i=i.filter((t=>!a[r(t)])),a={},s=0),e&&t&&i.sort(t),o.length&&(i=t?At(t,i,o.sort(t)):i.concat(o),o=[]),i)}}function lu(t){Ja.call(this,[],t)}function cu(t){Sa.call(this,null,fu,t)}function fu(t){return this.value&&!t.modified()?this.value:Q(t.fields,t.orders)}function hu(t){Ja.call(this,null,t)}function 
du(t){Ja.call(this,null,t)}su.Definition={type:"Bin",metadata:{modifies:!0},params:[{name:"field",type:"field",required:!0},{name:"interval",type:"boolean",default:!0},{name:"anchor",type:"number"},{name:"maxbins",type:"number",default:20},{name:"base",type:"number",default:10},{name:"divide",type:"number",array:!0,default:[5,2]},{name:"extent",type:"number",array:!0,length:2,required:!0},{name:"span",type:"number"},{name:"step",type:"number"},{name:"steps",type:"number",array:!0},{name:"minstep",type:"number",default:0},{name:"nice",type:"boolean",default:!0},{name:"name",type:"string"},{name:"as",type:"string",array:!0,length:2,default:["bin0","bin1"]}]},dt(su,Ja,{transform(t,e){const n=!1!==t.interval,i=this._bins(t),o=i.start,a=i.step,s=t.as||["bin0","bin1"],u=s[0],l=s[1];let c;return c=t.modified()?(e=e.reflow(!0)).SOURCE:e.modified(r(t.field))?e.ADD_MOD:e.ADD,e.visit(c,n?t=>{const e=i(t);t[u]=e,t[l]=null==e?null:o+a*(1+(e-o)/a)}:t=>t[u]=i(t)),e.modifies(n?s:u)},_bins(t){if(this.value&&!t.modified())return this.value;const i=t.field,o=is(t),a=o.step;let s,u,l=o.start,c=l+Math.ceil((o.stop-l)/a)*a;null!=(s=t.anchor)&&(u=s-(l+a*Math.floor((s-l)/a)),l+=u,c+=u);const f=function(t){let e=S(i(t));return null==e?null:ec?1/0:(e=Math.max(l,Math.min(e,c-a)),l+a*Math.floor(1e-14+(e-l)/a))};return f.start=l,f.stop=o.stop,f.step=a,this.value=e(f,r(i),t.name||"bin_"+n(i))}}),lu.Definition={type:"Collect",metadata:{source:!0},params:[{name:"sort",type:"compare"}]},dt(lu,Ja,{transform(t,e){const n=e.fork(e.ALL),r=uu(ya,this.value,n.materialize(n.ADD).add),i=t.sort,o=e.changed()||i&&(t.modified("sort")||e.modified(i.fields));return 
n.visit(n.REM,r.remove),this.modified(o),this.value=n.source=r.data(ka(i),o),e.source&&e.source.root&&(this.value.root=e.source.root),n}}),dt(cu,Sa),hu.Definition={type:"CountPattern",metadata:{generates:!0,changes:!0},params:[{name:"field",type:"field",required:!0},{name:"case",type:"enum",values:["upper","lower","mixed"],default:"mixed"},{name:"pattern",type:"string",default:'[\\w"]+'},{name:"stopwords",type:"string",default:""},{name:"as",type:"string",array:!0,length:2,default:["text","count"]}]},dt(hu,Ja,{transform(t,e){const n=e=>n=>{for(var r,i=function(t,e,n){switch(e){case"upper":t=t.toUpperCase();break;case"lower":t=t.toLowerCase()}return t.match(n)}(s(n),t.case,o)||[],u=0,l=i.length;ui[t]=1+(i[t]||0))),c=n((t=>i[t]-=1));return r?e.visit(e.SOURCE,l):(e.visit(e.ADD,l),e.visit(e.REM,c)),this._finish(e,u)},_parameterCheck(t,e){let n=!1;return!t.modified("stopwords")&&this._stop||(this._stop=new RegExp("^"+(t.stopwords||"")+"$","i"),n=!0),!t.modified("pattern")&&this._match||(this._match=new RegExp(t.pattern||"[\\w']+","g"),n=!0),(t.modified("field")||e.modified(t.field.fields))&&(n=!0),n&&(this._counts={}),n},_finish(t,e){const n=this._counts,r=this._tuples||(this._tuples={}),i=e[0],o=e[1],a=t.fork(t.NO_SOURCE|t.NO_FIELDS);let s,u,l;for(s in n)u=r[s],l=n[s]||0,!u&&l?(r[s]=u=_a({}),u[i]=s,u[o]=l,a.add.push(u)):0===l?(u&&a.rem.push(u),n[s]=null,r[s]=null):u[o]!==l&&(u[o]=l,a.mod.push(u));return a.modifies(e)}}),du.Definition={type:"Cross",metadata:{generates:!0},params:[{name:"filter",type:"expr"},{name:"as",type:"string",array:!0,length:2,default:["a","b"]}]},dt(du,Ja,{transform(t,e){const n=e.fork(e.NO_SOURCE),r=t.as||["a","b"],i=r[0],o=r[1],a=!this.value||e.changed(e.ADD_REM)||t.modified("as")||t.modified("filter");let s=this.value;return a?(s&&(n.rem=s),s=e.materialize(e.SOURCE).source,n.add=this.value=function(t,e,n,r){for(var i,o,a=[],s={},u=t.length,l=0;lmu(t,e)))):typeof r[n]===gu&&r[n](t[n]);return r}function yu(t){Ja.call(this,null,t)}const 
vu=[{key:{function:"normal"},params:[{name:"mean",type:"number",default:0},{name:"stdev",type:"number",default:1}]},{key:{function:"lognormal"},params:[{name:"mean",type:"number",default:0},{name:"stdev",type:"number",default:1}]},{key:{function:"uniform"},params:[{name:"min",type:"number",default:0},{name:"max",type:"number",default:1}]},{key:{function:"kde"},params:[{name:"field",type:"field",required:!0},{name:"from",type:"data"},{name:"bandwidth",type:"number",default:0}]}],_u={key:{function:"mixture"},params:[{name:"distributions",type:"param",array:!0,params:vu},{name:"weights",type:"number",array:!0}]};function xu(t,e){return t?t.map(((t,r)=>e[r]||n(t))):null}function bu(t,e,n){const r=[],i=t=>t(u);let o,a,s,u,l,c;if(null==e)r.push(t.map(n));else for(o={},a=0,s=t.length;at.materialize(t.SOURCE).source}(e)),i=t.steps||t.minsteps||25,o=t.steps||t.maxsteps||200;let a=t.method||"pdf";"pdf"!==a&&"cdf"!==a&&s("Invalid density method: "+a),t.extent||r.data||s("Missing density extent parameter."),a=r[a];const u=t.as||["value","density"],l=Is(a,t.extent||at(r.data()),i,o).map((t=>{const e={};return e[u[0]]=t[0],e[u[1]]=t[1],_a(e)}));this.value&&(n.rem=this.value),this.value=n.add=n.source=l}return n}});function wu(t){Ja.call(this,null,t)}wu.Definition={type:"DotBin",metadata:{modifies:!0},params:[{name:"field",type:"field",required:!0},{name:"groupby",type:"field",array:!0},{name:"step",type:"number"},{name:"smooth",type:"boolean",default:!1},{name:"as",type:"string",default:"bin"}]};function ku(t){Sa.call(this,null,Au,t),this.modified(!0)}function Au(t){const i=t.expr;return this.value&&!t.modified("expr")?this.value:e((e=>i(e,t)),r(i),n(i))}function Mu(t){Ja.call(this,[void 0,void 0],t)}function Eu(t,e){Sa.call(this,t),this.parent=e,this.count=0}function Du(t){Ja.call(this,{},t),this._keys=ft();const e=this._targets=[];e.active=0,e.forEach=t=>{for(let n=0,r=e.active;nl(t))):l(t.name,t.as)}function Su(t){Ja.call(this,ft(),t)}function 
$u(t){Ja.call(this,[],t)}function Tu(t){Ja.call(this,[],t)}function Bu(t){Ja.call(this,null,t)}function zu(t){Ja.call(this,[],t)}dt(wu,Ja,{transform(t,e){if(this.value&&!t.modified()&&!e.changed())return e;const n=e.materialize(e.SOURCE).source,r=bu(e.source,t.groupby,f),i=t.smooth||!1,o=t.field,a=t.step||((t,e)=>Dt(at(t,e))/30)(n,o),s=ka(((t,e)=>o(t)-o(e))),u=t.as||"bin",l=r.length;let c,h=1/0,d=-1/0,p=0;for(;pd&&(d=e),t[++c][u]=e}return this.value={start:h,stop:d,step:a},e.reflow(!0).modifies(u)}}),dt(ku,Sa),Mu.Definition={type:"Extent",metadata:{},params:[{name:"field",type:"field",required:!0}]},dt(Mu,Ja,{transform(t,e){const r=this.value,i=t.field,o=e.changed()||e.modified(i.fields)||t.modified("field");let a=r[0],s=r[1];if((o||null==a)&&(a=1/0,s=-1/0),e.visit(o?e.SOURCE:e.ADD,(t=>{const e=S(i(t));null!=e&&(es&&(s=e))})),!Number.isFinite(a)||!Number.isFinite(s)){let t=n(i);t&&(t=` for field "${t}"`),e.dataflow.warn(`Infinite extent${t}: [${a}, ${s}]`),a=s=void 0}this.value=[a,s]}}),dt(Eu,Sa,{connect(t){return this.detachSubflow=t.detachSubflow,this.targets().add(t),t.source=this},add(t){this.count+=1,this.value.add.push(t)},rem(t){this.count-=1,this.value.rem.push(t)},mod(t){this.value.mod.push(t)},init(t){this.value.init(t,t.NO_SOURCE)},evaluate(){return this.value}}),dt(Du,Ja,{activate(t){this._targets[this._targets.active++]=t},subflow(t,e,n,r){const i=this.value;let o,a,s=lt(i,t)&&i[t];return s?s.value.stampt&&t.count>0));this.initTargets(t)}},initTargets(t){const e=this._targets,n=e.length,r=t?t.length:0;let i=0;for(;ithis.subflow(t,i,e);return this._group=t.group||{},this.initTargets(),e.visit(e.REM,(t=>{const e=ya(t),n=o.get(e);void 0!==n&&(o.delete(e),s(n).rem(t))})),e.visit(e.ADD,(t=>{const e=r(t);o.set(ya(t),e),s(e).add(t)})),a||e.modified(r.fields)?e.visit(e.MOD,(t=>{const e=ya(t),n=o.get(e),i=r(t);n===i?s(i).mod(t):(o.set(e,i),s(n).rem(t),s(i).add(t))})):e.changed(e.MOD)&&e.visit(e.MOD,(t=>{s(o.get(ya(t))).mod(t)})),a&&e.visit(e.REFLOW,(t=>{const 
e=ya(t),n=o.get(e),i=r(t);n!==i&&(o.set(e,i),s(n).rem(t),s(i).add(t))})),e.clean()?n.runAfter((()=>{this.clean(),o.clean()})):o.empty>n.cleanThreshold&&n.runAfter(o.clean),e}}),dt(Cu,Sa),Su.Definition={type:"Filter",metadata:{changes:!0},params:[{name:"expr",type:"expr",required:!0}]},dt(Su,Ja,{transform(t,e){const n=e.dataflow,r=this.value,i=e.fork(),o=i.add,a=i.rem,s=i.mod,u=t.expr;let l=!0;function c(e){const n=ya(e),i=u(e,t),c=r.get(n);i&&c?(r.delete(n),o.push(e)):i||c?l&&i&&!c&&s.push(e):(r.set(n,1),a.push(e))}return e.visit(e.REM,(t=>{const e=ya(t);r.has(e)?r.delete(e):a.push(t)})),e.visit(e.ADD,(e=>{u(e,t)?o.push(e):r.set(ya(e),1)})),e.visit(e.MOD,c),t.modified()&&(l=!1,e.visit(e.REFLOW,c)),r.empty>n.cleanThreshold&&n.runAfter(r.clean),i}}),$u.Definition={type:"Flatten",metadata:{generates:!0},params:[{name:"fields",type:"field",array:!0,required:!0},{name:"index",type:"string"},{name:"as",type:"string",array:!0}]},dt($u,Ja,{transform(t,e){const n=e.fork(e.NO_SOURCE),r=t.fields,i=xu(r,t.as||[]),o=t.index||null,a=i.length;return n.rem=this.value,e.visit(e.SOURCE,(t=>{const e=r.map((e=>e(t))),s=e.reduce(((t,e)=>Math.max(t,e.length)),0);let u,l,c,f=0;for(;f{for(let e,n=0;ne[r]=n(e,t)))}}),dt(zu,Ja,{transform(t,e){const n=e.fork(e.ALL),r=t.generator;let i,o,a,s=this.value,u=t.size-s.length;if(u>0){for(i=[];--u>=0;)i.push(a=_a(r(t))),s.push(a);n.add=n.add.length?n.materialize(n.ADD).add.concat(i):i}else o=s.slice(0,-u),n.rem=n.rem.length?n.materialize(n.REM).rem.concat(o):o,s=s.slice(-u);return n.source=this.value=s,n}});const Nu={value:"value",median:Ce,mean:function(t,e){let n=0,r=0;if(void 0===e)for(let e of t)null!=e&&(e=+e)>=e&&(++n,r+=e);else{let i=-1;for(let o of t)null!=(o=e(o,++i,t))&&(o=+o)>=o&&(++n,r+=o)}if(n)return r/n},min:ke,max:we},Ou=[];function Ru(t){Ja.call(this,[],t)}function Uu(t){au.call(this,t)}function Lu(t){Ja.call(this,null,t)}function qu(t){Sa.call(this,null,Pu,t)}function Pu(t){return 
this.value&&!t.modified()?this.value:bt(t.fields,t.flat)}function ju(t){Ja.call(this,[],t),this._pending=null}function Iu(t,e,n){n.forEach(_a);const r=e.fork(e.NO_FIELDS&e.NO_SOURCE);return r.rem=t.value,t.value=r.source=r.add=n,t._pending=null,r.rem.length&&r.clean(!0),r}function Wu(t){Ja.call(this,{},t)}function Hu(t){Sa.call(this,null,Yu,t)}function Yu(t){if(this.value&&!t.modified())return this.value;const e=t.extents,n=e.length;let r,i,o=1/0,a=-1/0;for(r=0;ra&&(a=i[1]);return[o,a]}function Gu(t){Sa.call(this,null,Vu,t)}function Vu(t){return this.value&&!t.modified()?this.value:t.values.reduce(((t,e)=>t.concat(e)),[])}function Xu(t){Ja.call(this,null,t)}function Ju(t){au.call(this,t)}function Zu(t){Du.call(this,t)}function Qu(t){Ja.call(this,null,t)}function Ku(t){Ja.call(this,null,t)}function tl(t){Ja.call(this,null,t)}Ru.Definition={type:"Impute",metadata:{changes:!0},params:[{name:"field",type:"field",required:!0},{name:"key",type:"field",required:!0},{name:"keyvals",array:!0},{name:"groupby",type:"field",array:!0},{name:"method",type:"enum",default:"value",values:["value","mean","median","max","min"]},{name:"value",default:0}]},dt(Ru,Ja,{transform(t,e){var r,i,o,a,u,l,c,f,h,d,p=e.fork(e.ALL),g=function(t){var e,n=t.method||Nu.value;if(null!=Nu[n])return n===Nu.value?(e=void 0!==t.value?t.value:0,()=>e):Nu[n];s("Unrecognized imputation method: "+n)}(t),m=function(t){const e=t.field;return t=>t?e(t):NaN}(t),y=n(t.field),v=n(t.key),_=(t.groupby||[]).map(n),x=function(t,e,n,r){var i,o,a,s,u,l,c,f,h=t=>t(f),d=[],p=r?r.slice():[],g={},m={};for(p.forEach(((t,e)=>g[t]=e+1)),s=0,c=t.length;sn.add(t)))):(i=n.value=n.value||this.init(t),e.visit(e.REM,(t=>n.rem(t))),e.visit(e.ADD,(t=>n.add(t)))),n.changes(),e.visit(e.SOURCE,(t=>{ot(t,i[n.cellkey(t)].tuple)})),e.reflow(r).modifies(this._outputs)},changes(){const t=this._adds,e=this._mods;let n,r;for(n=0,r=this._alen;n{const n=gs(e,u)[l],r=t.counts?e.length:1;Is(n,h||at(e),d,p).forEach((t=>{const n={};for(let 
t=0;t(this._pending=V(t.data),t=>t.touch(this))));return{async:e}}return n.request(t.url,t.format).then((t=>Iu(this,e,V(t.data))))}}),Wu.Definition={type:"Lookup",metadata:{modifies:!0},params:[{name:"index",type:"index",params:[{name:"from",type:"data",required:!0},{name:"key",type:"field",required:!0}]},{name:"values",type:"field",array:!0},{name:"fields",type:"field",array:!0,required:!0},{name:"as",type:"string",array:!0},{name:"default",default:null}]},dt(Wu,Ja,{transform(t,e){const r=t.fields,i=t.index,o=t.values,a=null==t.default?null:t.default,u=t.modified(),l=r.length;let c,f,h,d=u?e.SOURCE:e.ADD,p=e,g=t.as;return o?(f=o.length,l>1&&!g&&s('Multi-field lookup requires explicit "as" parameter.'),g&&g.length!==l*f&&s('The "as" parameter has too few output field names.'),g=g||o.map(n),c=function(t){for(var e,n,s=0,u=0;se.modified(t.fields))),d|=h?e.MOD:0),e.visit(d,c),p.modifies(g)}}),dt(Hu,Sa),dt(Gu,Sa),dt(Xu,Ja,{transform(t,e){return this.modified(t.modified()),this.value=t,e.fork(e.NO_SOURCE|e.NO_FIELDS)}}),Ju.Definition={type:"Pivot",metadata:{generates:!0,changes:!0},params:[{name:"groupby",type:"field",array:!0},{name:"field",type:"field",required:!0},{name:"value",type:"field",required:!0},{name:"op",type:"enum",values:Js,default:"sum"},{name:"limit",type:"number",default:0},{name:"key",type:"field"}]},dt(Ju,au,{_transform:au.prototype.transform,transform(t,n){return this._transform(function(t,n){const i=t.field,o=t.value,a=("count"===t.op?"__count__":t.op)||"sum",s=r(i).concat(r(o)),u=function(t,e,n){const r={},i=[];return n.visit(n.SOURCE,(e=>{const n=t(e);r[n]||(r[n]=1,i.push(n))})),i.sort(K),e?i.slice(0,e):i}(i,t.limit||0,n);n.changed()&&t.set("__pivot__",null,null,!0);return{key:t.key,groupby:t.groupby,ops:u.map((()=>a)),fields:u.map((t=>function(t,n,r,i){return e((e=>n(e)===t?r(e):NaN),i,t+"")}(t,i,o,s))),as:u.map((t=>t+"")),modified:t.modified.bind(t)}}(t,n),n)}}),dt(Zu,Du,{transform(t,e){const 
n=t.subflow,i=t.field,o=t=>this.subflow(ya(t),n,e,t);return(t.modified("field")||i&&e.modified(r(i)))&&s("PreFacet does not support field modification."),this.initTargets(),i?(e.visit(e.MOD,(t=>{const e=o(t);i(t).forEach((t=>e.mod(t)))})),e.visit(e.ADD,(t=>{const e=o(t);i(t).forEach((t=>e.add(_a(t))))})),e.visit(e.REM,(t=>{const e=o(t);i(t).forEach((t=>e.rem(t)))}))):(e.visit(e.MOD,(t=>o(t).mod(t))),e.visit(e.ADD,(t=>o(t).add(t))),e.visit(e.REM,(t=>o(t).rem(t)))),e.clean()&&e.runAfter((()=>this.clean())),e}}),Qu.Definition={type:"Project",metadata:{generates:!0,changes:!0},params:[{name:"fields",type:"field",array:!0},{name:"as",type:"string",null:!0,array:!0}]},dt(Qu,Ja,{transform(t,e){const n=e.fork(e.NO_SOURCE),r=t.fields,i=xu(t.fields,t.as||[]),o=r?(t,e)=>function(t,e,n,r){for(let i=0,o=n.length;i{const e=ya(t);n.rem.push(a[e]),a[e]=null})),e.visit(e.ADD,(t=>{const e=o(t,_a({}));a[ya(t)]=e,n.add.push(e)})),e.visit(e.MOD,(t=>{n.mod.push(o(t,a[ya(t)]))})),n}}),dt(Ku,Ja,{transform(t,e){return this.value=t.value,t.modified("value")?e.fork(e.NO_SOURCE|e.NO_FIELDS):e.StopPropagation}}),tl.Definition={type:"Quantile",metadata:{generates:!0,changes:!0},params:[{name:"groupby",type:"field",array:!0},{name:"field",type:"field",required:!0},{name:"probs",type:"number",array:!0},{name:"step",type:"number",default:.01},{name:"as",type:"string",array:!0,default:["prob","value"]}]};function el(t){Ja.call(this,null,t)}function nl(t){Ja.call(this,[],t),this.count=0}function rl(t){Ja.call(this,null,t)}function il(t){Ja.call(this,null,t),this.modified(!0)}function ol(t){Ja.call(this,null,t)}dt(tl,Ja,{transform(t,e){const r=e.fork(e.NO_SOURCE|e.NO_FIELDS),i=t.as||["prob","value"];if(this.value&&!t.modified()&&!e.changed())return r.source=this.value,r;const o=bu(e.materialize(e.SOURCE).source,t.groupby,t.field),a=(t.groupby||[]).map(n),s=[],u=t.step||.01,l=t.probs||Se(u/2,1-1e-14,u),c=l.length;return o.forEach((t=>{const e=es(t,l);for(let n=0;n{const 
e=ya(t);n.rem.push(r[e]),r[e]=null})),e.visit(e.ADD,(t=>{const e=xa(t);r[ya(t)]=e,n.add.push(e)})),e.visit(e.MOD,(t=>{const e=r[ya(t)];for(const r in t)e[r]=t[r],n.modifies(r);n.mod.push(e)}))),n}}),nl.Definition={type:"Sample",metadata:{},params:[{name:"size",type:"number",default:1e3}]},dt(nl,Ja,{transform(e,n){const r=n.fork(n.NO_SOURCE),i=e.modified("size"),o=e.size,a=this.value.reduce(((t,e)=>(t[ya(e)]=1,t)),{});let s=this.value,u=this.count,l=0;function c(e){let n,i;s.length=l&&(n=s[i],a[ya(n)]&&r.rem.push(n),s[i]=e)),++u}if(n.rem.length&&(n.visit(n.REM,(t=>{const e=ya(t);a[e]&&(a[e]=-1,r.rem.push(t)),--u})),s=s.filter((t=>-1!==a[ya(t)]))),(n.rem.length||i)&&s.length{a[ya(t)]||c(t)})),l=-1),i&&s.length>o){const t=s.length-o;for(let e=0;e{a[ya(t)]&&r.mod.push(t)})),n.add.length&&n.visit(n.ADD,c),(n.add.length||l<0)&&(r.add=s.filter((t=>!a[ya(t)]))),this.count=u,this.value=r.source=s,r}}),rl.Definition={type:"Sequence",metadata:{generates:!0,changes:!0},params:[{name:"start",type:"number",required:!0},{name:"stop",type:"number",required:!0},{name:"step",type:"number",default:1},{name:"as",type:"string",default:"data"}]},dt(rl,Ja,{transform(t,e){if(this.value&&!t.modified())return;const n=e.materialize().fork(e.MOD),r=t.as||"data";return n.rem=this.value?e.rem.concat(this.value):e.rem,this.value=Se(t.start,t.stop,t.step||1).map((t=>{const e={};return e[r]=t,_a(e)})),n.add=e.add.concat(this.value),n}}),dt(il,Ja,{transform(t,e){return this.value=e.source,e.changed()?e.fork(e.NO_SOURCE|e.NO_FIELDS):e.StopPropagation}});const al=["unit0","unit1"];function sl(t){Ja.call(this,ft(),t)}function 
ul(t){Ja.call(this,null,t)}ol.Definition={type:"TimeUnit",metadata:{modifies:!0},params:[{name:"field",type:"field",required:!0},{name:"interval",type:"boolean",default:!0},{name:"units",type:"enum",values:Kn,array:!0},{name:"step",type:"number",default:1},{name:"maxbins",type:"number",default:40},{name:"extent",type:"date",array:!0},{name:"timezone",type:"enum",default:"local",values:["local","utc"]},{name:"as",type:"string",array:!0,length:2,default:al}]},dt(ol,Ja,{transform(t,e){const n=t.field,i=!1!==t.interval,o="utc"===t.timezone,a=this._floor(t,e),s=(o?Fr:Cr)(a.unit).offset,u=t.as||al,l=u[0],c=u[1],f=a.step;let h=a.start||1/0,d=a.stop||-1/0,p=e.ADD;return(t.modified()||e.changed(e.REM)||e.modified(r(n)))&&(p=(e=e.reflow(!0)).SOURCE,h=1/0,d=-1/0),e.visit(p,(t=>{const e=n(t);let r,o;null==e?(t[l]=null,i&&(t[c]=null)):(t[l]=r=o=a(e),i&&(t[c]=o=s(r,f)),rd&&(d=o))})),a.start=h,a.stop=d,e.modifies(i?u:l)},_floor(t,e){const n="utc"===t.timezone,{units:r,step:i}=t.units?{units:t.units,step:t.step||1}:Jr({extent:t.extent||at(e.materialize(e.SOURCE).source,t.field),maxbins:t.maxbins}),o=er(r),a=this.value||{},s=(n?Mr:wr)(o,i);return s.unit=F(o),s.units=o,s.step=i,s.start=a.start,s.stop=a.stop,this.value=s}}),dt(sl,Ja,{transform(t,e){const n=e.dataflow,r=t.field,i=this.value,o=t=>i.set(r(t),t);let a=!0;return t.modified("field")||e.modified(r.fields)?(i.clear(),e.visit(e.SOURCE,o)):e.changed()?(e.visit(e.REM,(t=>i.delete(r(t)))),e.visit(e.ADD,o)):a=!1,this.modified(a),i.empty>n.cleanThreshold&&n.runAfter(i.clean),e.fork()}}),dt(ul,Ja,{transform(t,e){(!this.value||t.modified("field")||t.modified("sort")||e.changed()||t.sort&&e.modified(t.sort.fields))&&(this.value=(t.sort?e.source.slice().sort(ka(t.sort)):e.source).map(t.field))}});const ll={row_number:function(){return{next:t=>t.index+1}},rank:function(){let t;return{init:()=>t=1,next:e=>{const n=e.index,r=e.data;return n&&e.compare(r[n-1],r[n])?t=n+1:t}}},dense_rank:function(){let t;return{init:()=>t=1,next:e=>{const 
n=e.index,r=e.data;return n&&e.compare(r[n-1],r[n])?++t:t}}},percent_rank:function(){const t=ll.rank(),e=t.next;return{init:t.init,next:t=>(e(t)-1)/(t.data.length-1)}},cume_dist:function(){let t;return{init:()=>t=0,next:e=>{const n=e.data,r=e.compare;let i=e.index;if(t0||s("ntile num must be greater than zero.");const n=ll.cume_dist(),r=n.next;return{init:n.init,next:t=>Math.ceil(e*r(t))}},lag:function(t,e){return e=+e||1,{next:n=>{const r=n.index-e;return r>=0?t(n.data[r]):null}}},lead:function(t,e){return e=+e||1,{next:n=>{const r=n.index+e,i=n.data;return rt(e.data[e.i0])}},last_value:function(t){return{next:e=>t(e.data[e.i1-1])}},nth_value:function(t,e){return(e=+e)>0||s("nth_value nth must be greater than zero."),{next:n=>{const r=n.i0+(e-1);return re=null,next:n=>{const r=t(n.data[n.index]);return null!=r?e=r:e}}},next_value:function(t){let e,n;return{init:()=>(e=null,n=-1),next:r=>{const i=r.data;return r.index<=n?e:(n=function(t,e,n){for(let r=e.length;nf[t]=1))}y(t.sort),e.forEach(((t,e)=>{const r=i[e],f=o[e],v=a[e]||null,_=n(r),x=Ys(t,_,u[e]);if(y(r),l.push(x),lt(ll,t))c.push(function(t,e,n,r){const i=ll[t](e,n);return{init:i.init||h,update:function(t,e){e[r]=i.next(t)}}}(t,r,f,x));else{if(null==r&&"count"!==t&&s("Null aggregate field specified."),"count"===t)return void p.push(x);m=!1;let e=d[_];e||(e=d[_]=[],e.field=r,g.push(e)),e.push(Zs(t,v,x))}})),(p.length||g.length)&&(this.cell=function(t,e,n){t=t.map((t=>ru(t,t.field)));const r={num:0,agg:null,store:!1,count:e};if(!n)for(var i=t.length,o=r.agg=Array(i),a=0;a0&&!i(o[n],o[n-1])&&(t.i0=e.left(o,o[n])),rt.init())),this.cell&&this.cell.init()},hl.update=function(t,e){const n=this.cell,r=this.windows,i=t.data,o=r&&r.length;let a;if(n){for(a=t.p0;athis.group(i(t));let a=this.state;a&&!n||(a=this.state=new fl(t)),n||e.modified(a.inputs)?(this.value={},e.visit(e.SOURCE,(t=>o(t).add(t)))):(e.visit(e.REM,(t=>o(t).remove(t))),e.visit(e.ADD,(t=>o(t).add(t))));for(let 
e=0,n=this._mlen;e=1?Cl:t<=-1?-Cl:Math.asin(t)}const $l=Math.PI,Tl=2*$l,Bl=1e-6,zl=Tl-Bl;function Nl(t){this._+=t[0];for(let e=1,n=t.length;e=0))throw new Error(`invalid digits: ${t}`);if(e>15)return Nl;const n=10**e;return function(t){this._+=t[0];for(let e=1,r=t.length;eBl)if(Math.abs(c*s-u*l)>Bl&&i){let h=n-o,d=r-a,p=s*s+u*u,g=h*h+d*d,m=Math.sqrt(p),y=Math.sqrt(f),v=i*Math.tan(($l-Math.acos((p+f-g)/(2*m*y)))/2),_=v/y,x=v/m;Math.abs(_-1)>Bl&&this._append`L${t+_*l},${e+_*c}`,this._append`A${i},${i},0,0,${+(c*h>l*d)},${this._x1=t+x*s},${this._y1=e+x*u}`}else this._append`L${this._x1=t},${this._y1=e}`;else;}arc(t,e,n,r,i,o){if(t=+t,e=+e,o=!!o,(n=+n)<0)throw new Error(`negative radius: ${n}`);let a=n*Math.cos(r),s=n*Math.sin(r),u=t+a,l=e+s,c=1^o,f=o?r-i:i-r;null===this._x1?this._append`M${u},${l}`:(Math.abs(this._x1-u)>Bl||Math.abs(this._y1-l)>Bl)&&this._append`L${u},${l}`,n&&(f<0&&(f=f%Tl+Tl),f>zl?this._append`A${n},${n},0,1,${c},${t-a},${e-s}A${n},${n},0,1,${c},${this._x1=u},${this._y1=l}`:f>Bl&&this._append`A${n},${n},0,${+(f>=$l)},${c},${this._x1=t+n*Math.cos(i)},${this._y1=e+n*Math.sin(i)}`)}rect(t,e,n,r){this._append`M${this._x0=this._x1=+t},${this._y0=this._y1=+e}h${n=+n}v${+r}h${-n}Z`}toString(){return this._}};function Rl(){return new Ol}function Ul(t){let e=3;return t.digits=function(n){if(!arguments.length)return e;if(null==n)e=null;else{const t=Math.floor(n);if(!(t>=0))throw new RangeError(`invalid digits: ${n}`);e=t}return t},()=>new Ol(e)}function Ll(t){return t.innerRadius}function ql(t){return t.outerRadius}function Pl(t){return t.startAngle}function jl(t){return t.endAngle}function Il(t){return t&&t.padAngle}function Wl(t,e,n,r,i,o,a){var s=t-n,u=e-r,l=(a?o:-o)/Ml(s*s+u*u),c=l*u,f=-l*s,h=t+c,d=e+f,p=n+c,g=r+f,m=(h+p)/2,y=(d+g)/2,v=p-h,_=g-d,x=v*v+_*_,b=i-o,w=h*g-p*d,k=(_<0?-1:1)*Ml(wl(0,b*b*x-w*w)),A=(w*_-v*k)/x,M=(-w*v-_*k)/x,E=(w*_+v*k)/x,D=(-w*v+_*k)/x,C=A-m,F=M-y,S=E-m,$=D-y;return 
C*C+F*F>S*S+$*$&&(A=E,M=D),{cx:A,cy:M,x01:-c,y01:-f,x11:A*(i/b-1),y11:M*(i/b-1)}}function Hl(t){return"object"==typeof t&&"length"in t?t:Array.from(t)}function Yl(t){this._context=t}function Gl(t){return new Yl(t)}function Vl(t){return t[0]}function Xl(t){return t[1]}function Jl(t,e){var n=vl(!0),r=null,i=Gl,o=null,a=Ul(s);function s(s){var u,l,c,f=(s=Hl(s)).length,h=!1;for(null==r&&(o=i(c=a())),u=0;u<=f;++u)!(u=f;--h)s.point(y[h],v[h]);s.lineEnd(),s.areaEnd()}m&&(y[c]=+t(d,c,l),v[c]=+e(d,c,l),s.point(r?+r(d,c,l):y[c],n?+n(d,c,l):v[c]))}if(p)return s=null,p+""||null}function c(){return Jl().defined(i).curve(a).context(o)}return t="function"==typeof t?t:void 0===t?Vl:vl(+t),e="function"==typeof e?e:vl(void 0===e?0:+e),n="function"==typeof n?n:void 0===n?Xl:vl(+n),l.x=function(e){return arguments.length?(t="function"==typeof e?e:vl(+e),r=null,l):t},l.x0=function(e){return arguments.length?(t="function"==typeof e?e:vl(+e),l):t},l.x1=function(t){return arguments.length?(r=null==t?null:"function"==typeof t?t:vl(+t),l):r},l.y=function(t){return arguments.length?(e="function"==typeof t?t:vl(+t),n=null,l):e},l.y0=function(t){return arguments.length?(e="function"==typeof t?t:vl(+t),l):e},l.y1=function(t){return arguments.length?(n=null==t?null:"function"==typeof t?t:vl(+t),l):n},l.lineX0=l.lineY0=function(){return c().x(t).y(e)},l.lineY1=function(){return c().x(t).y(n)},l.lineX1=function(){return c().x(r).y(e)},l.defined=function(t){return arguments.length?(i="function"==typeof t?t:vl(!!t),l):i},l.curve=function(t){return arguments.length?(a=t,null!=o&&(s=a(o)),l):a},l.context=function(t){return arguments.length?(null==t?o=s=null:s=a(o=t),l):o},l}Rl.prototype=Ol.prototype,Yl.prototype={areaStart:function(){this._line=0},areaEnd:function(){this._line=NaN},lineStart:function(){this._point=0},lineEnd:function(){(this._line||0!==this._line&&1===this._point)&&this._context.closePath(),this._line=1-this._line},point:function(t,e){switch(t=+t,e=+e,this._point){case 
0:this._point=1,this._line?this._context.lineTo(t,e):this._context.moveTo(t,e);break;case 1:this._point=2;default:this._context.lineTo(t,e)}}};var Ql={draw(t,e){const n=Ml(e/Dl);t.moveTo(n,0),t.arc(0,0,n,0,Fl)}};function Kl(){}function tc(t,e,n){t._context.bezierCurveTo((2*t._x0+t._x1)/3,(2*t._y0+t._y1)/3,(t._x0+2*t._x1)/3,(t._y0+2*t._y1)/3,(t._x0+4*t._x1+e)/6,(t._y0+4*t._y1+n)/6)}function ec(t){this._context=t}function nc(t){this._context=t}function rc(t){this._context=t}function ic(t,e){this._basis=new ec(t),this._beta=e}ec.prototype={areaStart:function(){this._line=0},areaEnd:function(){this._line=NaN},lineStart:function(){this._x0=this._x1=this._y0=this._y1=NaN,this._point=0},lineEnd:function(){switch(this._point){case 3:tc(this,this._x1,this._y1);case 2:this._context.lineTo(this._x1,this._y1)}(this._line||0!==this._line&&1===this._point)&&this._context.closePath(),this._line=1-this._line},point:function(t,e){switch(t=+t,e=+e,this._point){case 0:this._point=1,this._line?this._context.lineTo(t,e):this._context.moveTo(t,e);break;case 1:this._point=2;break;case 2:this._point=3,this._context.lineTo((5*this._x0+this._x1)/6,(5*this._y0+this._y1)/6);default:tc(this,t,e)}this._x0=this._x1,this._x1=t,this._y0=this._y1,this._y1=e}},nc.prototype={areaStart:Kl,areaEnd:Kl,lineStart:function(){this._x0=this._x1=this._x2=this._x3=this._x4=this._y0=this._y1=this._y2=this._y3=this._y4=NaN,this._point=0},lineEnd:function(){switch(this._point){case 1:this._context.moveTo(this._x2,this._y2),this._context.closePath();break;case 2:this._context.moveTo((this._x2+2*this._x3)/3,(this._y2+2*this._y3)/3),this._context.lineTo((this._x3+2*this._x2)/3,(this._y3+2*this._y2)/3),this._context.closePath();break;case 3:this.point(this._x2,this._y2),this.point(this._x3,this._y3),this.point(this._x4,this._y4)}},point:function(t,e){switch(t=+t,e=+e,this._point){case 0:this._point=1,this._x2=t,this._y2=e;break;case 1:this._point=2,this._x3=t,this._y3=e;break;case 
2:this._point=3,this._x4=t,this._y4=e,this._context.moveTo((this._x0+4*this._x1+t)/6,(this._y0+4*this._y1+e)/6);break;default:tc(this,t,e)}this._x0=this._x1,this._x1=t,this._y0=this._y1,this._y1=e}},rc.prototype={areaStart:function(){this._line=0},areaEnd:function(){this._line=NaN},lineStart:function(){this._x0=this._x1=this._y0=this._y1=NaN,this._point=0},lineEnd:function(){(this._line||0!==this._line&&3===this._point)&&this._context.closePath(),this._line=1-this._line},point:function(t,e){switch(t=+t,e=+e,this._point){case 0:this._point=1;break;case 1:this._point=2;break;case 2:this._point=3;var n=(this._x0+4*this._x1+t)/6,r=(this._y0+4*this._y1+e)/6;this._line?this._context.lineTo(n,r):this._context.moveTo(n,r);break;case 3:this._point=4;default:tc(this,t,e)}this._x0=this._x1,this._x1=t,this._y0=this._y1,this._y1=e}},ic.prototype={lineStart:function(){this._x=[],this._y=[],this._basis.lineStart()},lineEnd:function(){var t=this._x,e=this._y,n=t.length-1;if(n>0)for(var r,i=t[0],o=e[0],a=t[n]-i,s=e[n]-o,u=-1;++u<=n;)r=u/n,this._basis.point(this._beta*t[u]+(1-this._beta)*(i+r*a),this._beta*e[u]+(1-this._beta)*(o+r*s));this._x=this._y=null,this._basis.lineEnd()},point:function(t,e){this._x.push(+t),this._y.push(+e)}};var oc=function t(e){function n(t){return 1===e?new ec(t):new ic(t,e)}return n.beta=function(e){return t(+e)},n}(.85);function ac(t,e,n){t._context.bezierCurveTo(t._x1+t._k*(t._x2-t._x0),t._y1+t._k*(t._y2-t._y0),t._x2+t._k*(t._x1-e),t._y2+t._k*(t._y1-n),t._x2,t._y2)}function sc(t,e){this._context=t,this._k=(1-e)/6}sc.prototype={areaStart:function(){this._line=0},areaEnd:function(){this._line=NaN},lineStart:function(){this._x0=this._x1=this._x2=this._y0=this._y1=this._y2=NaN,this._point=0},lineEnd:function(){switch(this._point){case 2:this._context.lineTo(this._x2,this._y2);break;case 
3:ac(this,this._x1,this._y1)}(this._line||0!==this._line&&1===this._point)&&this._context.closePath(),this._line=1-this._line},point:function(t,e){switch(t=+t,e=+e,this._point){case 0:this._point=1,this._line?this._context.lineTo(t,e):this._context.moveTo(t,e);break;case 1:this._point=2,this._x1=t,this._y1=e;break;case 2:this._point=3;default:ac(this,t,e)}this._x0=this._x1,this._x1=this._x2,this._x2=t,this._y0=this._y1,this._y1=this._y2,this._y2=e}};var uc=function t(e){function n(t){return new sc(t,e)}return n.tension=function(e){return t(+e)},n}(0);function lc(t,e){this._context=t,this._k=(1-e)/6}lc.prototype={areaStart:Kl,areaEnd:Kl,lineStart:function(){this._x0=this._x1=this._x2=this._x3=this._x4=this._x5=this._y0=this._y1=this._y2=this._y3=this._y4=this._y5=NaN,this._point=0},lineEnd:function(){switch(this._point){case 1:this._context.moveTo(this._x3,this._y3),this._context.closePath();break;case 2:this._context.lineTo(this._x3,this._y3),this._context.closePath();break;case 3:this.point(this._x3,this._y3),this.point(this._x4,this._y4),this.point(this._x5,this._y5)}},point:function(t,e){switch(t=+t,e=+e,this._point){case 0:this._point=1,this._x3=t,this._y3=e;break;case 1:this._point=2,this._context.moveTo(this._x4=t,this._y4=e);break;case 2:this._point=3,this._x5=t,this._y5=e;break;default:ac(this,t,e)}this._x0=this._x1,this._x1=this._x2,this._x2=t,this._y0=this._y1,this._y1=this._y2,this._y2=e}};var cc=function t(e){function n(t){return new lc(t,e)}return n.tension=function(e){return t(+e)},n}(0);function fc(t,e){this._context=t,this._k=(1-e)/6}fc.prototype={areaStart:function(){this._line=0},areaEnd:function(){this._line=NaN},lineStart:function(){this._x0=this._x1=this._x2=this._y0=this._y1=this._y2=NaN,this._point=0},lineEnd:function(){(this._line||0!==this._line&&3===this._point)&&this._context.closePath(),this._line=1-this._line},point:function(t,e){switch(t=+t,e=+e,this._point){case 0:this._point=1;break;case 1:this._point=2;break;case 
2:this._point=3,this._line?this._context.lineTo(this._x2,this._y2):this._context.moveTo(this._x2,this._y2);break;case 3:this._point=4;default:ac(this,t,e)}this._x0=this._x1,this._x1=this._x2,this._x2=t,this._y0=this._y1,this._y1=this._y2,this._y2=e}};var hc=function t(e){function n(t){return new fc(t,e)}return n.tension=function(e){return t(+e)},n}(0);function dc(t,e,n){var r=t._x1,i=t._y1,o=t._x2,a=t._y2;if(t._l01_a>El){var s=2*t._l01_2a+3*t._l01_a*t._l12_a+t._l12_2a,u=3*t._l01_a*(t._l01_a+t._l12_a);r=(r*s-t._x0*t._l12_2a+t._x2*t._l01_2a)/u,i=(i*s-t._y0*t._l12_2a+t._y2*t._l01_2a)/u}if(t._l23_a>El){var l=2*t._l23_2a+3*t._l23_a*t._l12_a+t._l12_2a,c=3*t._l23_a*(t._l23_a+t._l12_a);o=(o*l+t._x1*t._l23_2a-e*t._l12_2a)/c,a=(a*l+t._y1*t._l23_2a-n*t._l12_2a)/c}t._context.bezierCurveTo(r,i,o,a,t._x2,t._y2)}function pc(t,e){this._context=t,this._alpha=e}pc.prototype={areaStart:function(){this._line=0},areaEnd:function(){this._line=NaN},lineStart:function(){this._x0=this._x1=this._x2=this._y0=this._y1=this._y2=NaN,this._l01_a=this._l12_a=this._l23_a=this._l01_2a=this._l12_2a=this._l23_2a=this._point=0},lineEnd:function(){switch(this._point){case 2:this._context.lineTo(this._x2,this._y2);break;case 3:this.point(this._x2,this._y2)}(this._line||0!==this._line&&1===this._point)&&this._context.closePath(),this._line=1-this._line},point:function(t,e){if(t=+t,e=+e,this._point){var n=this._x2-t,r=this._y2-e;this._l23_a=Math.sqrt(this._l23_2a=Math.pow(n*n+r*r,this._alpha))}switch(this._point){case 0:this._point=1,this._line?this._context.lineTo(t,e):this._context.moveTo(t,e);break;case 1:this._point=2;break;case 2:this._point=3;default:dc(this,t,e)}this._l01_a=this._l12_a,this._l12_a=this._l23_a,this._l01_2a=this._l12_2a,this._l12_2a=this._l23_2a,this._x0=this._x1,this._x1=this._x2,this._x2=t,this._y0=this._y1,this._y1=this._y2,this._y2=e}};var gc=function t(e){function n(t){return e?new pc(t,e):new sc(t,0)}return n.alpha=function(e){return t(+e)},n}(.5);function 
mc(t,e){this._context=t,this._alpha=e}mc.prototype={areaStart:Kl,areaEnd:Kl,lineStart:function(){this._x0=this._x1=this._x2=this._x3=this._x4=this._x5=this._y0=this._y1=this._y2=this._y3=this._y4=this._y5=NaN,this._l01_a=this._l12_a=this._l23_a=this._l01_2a=this._l12_2a=this._l23_2a=this._point=0},lineEnd:function(){switch(this._point){case 1:this._context.moveTo(this._x3,this._y3),this._context.closePath();break;case 2:this._context.lineTo(this._x3,this._y3),this._context.closePath();break;case 3:this.point(this._x3,this._y3),this.point(this._x4,this._y4),this.point(this._x5,this._y5)}},point:function(t,e){if(t=+t,e=+e,this._point){var n=this._x2-t,r=this._y2-e;this._l23_a=Math.sqrt(this._l23_2a=Math.pow(n*n+r*r,this._alpha))}switch(this._point){case 0:this._point=1,this._x3=t,this._y3=e;break;case 1:this._point=2,this._context.moveTo(this._x4=t,this._y4=e);break;case 2:this._point=3,this._x5=t,this._y5=e;break;default:dc(this,t,e)}this._l01_a=this._l12_a,this._l12_a=this._l23_a,this._l01_2a=this._l12_2a,this._l12_2a=this._l23_2a,this._x0=this._x1,this._x1=this._x2,this._x2=t,this._y0=this._y1,this._y1=this._y2,this._y2=e}};var yc=function t(e){function n(t){return e?new mc(t,e):new lc(t,0)}return n.alpha=function(e){return t(+e)},n}(.5);function vc(t,e){this._context=t,this._alpha=e}vc.prototype={areaStart:function(){this._line=0},areaEnd:function(){this._line=NaN},lineStart:function(){this._x0=this._x1=this._x2=this._y0=this._y1=this._y2=NaN,this._l01_a=this._l12_a=this._l23_a=this._l01_2a=this._l12_2a=this._l23_2a=this._point=0},lineEnd:function(){(this._line||0!==this._line&&3===this._point)&&this._context.closePath(),this._line=1-this._line},point:function(t,e){if(t=+t,e=+e,this._point){var n=this._x2-t,r=this._y2-e;this._l23_a=Math.sqrt(this._l23_2a=Math.pow(n*n+r*r,this._alpha))}switch(this._point){case 0:this._point=1;break;case 1:this._point=2;break;case 
2:this._point=3,this._line?this._context.lineTo(this._x2,this._y2):this._context.moveTo(this._x2,this._y2);break;case 3:this._point=4;default:dc(this,t,e)}this._l01_a=this._l12_a,this._l12_a=this._l23_a,this._l01_2a=this._l12_2a,this._l12_2a=this._l23_2a,this._x0=this._x1,this._x1=this._x2,this._x2=t,this._y0=this._y1,this._y1=this._y2,this._y2=e}};var _c=function t(e){function n(t){return e?new vc(t,e):new fc(t,0)}return n.alpha=function(e){return t(+e)},n}(.5);function xc(t){this._context=t}function bc(t){return t<0?-1:1}function wc(t,e,n){var r=t._x1-t._x0,i=e-t._x1,o=(t._y1-t._y0)/(r||i<0&&-0),a=(n-t._y1)/(i||r<0&&-0),s=(o*i+a*r)/(r+i);return(bc(o)+bc(a))*Math.min(Math.abs(o),Math.abs(a),.5*Math.abs(s))||0}function kc(t,e){var n=t._x1-t._x0;return n?(3*(t._y1-t._y0)/n-e)/2:e}function Ac(t,e,n){var r=t._x0,i=t._y0,o=t._x1,a=t._y1,s=(o-r)/3;t._context.bezierCurveTo(r+s,i+s*e,o-s,a-s*n,o,a)}function Mc(t){this._context=t}function Ec(t){this._context=new Dc(t)}function Dc(t){this._context=t}function Cc(t){this._context=t}function Fc(t){var e,n,r=t.length-1,i=new Array(r),o=new Array(r),a=new Array(r);for(i[0]=0,o[0]=2,a[0]=t[0]+2*t[1],e=1;e=0;--e)i[e]=(a[e]-i[e+1])/o[e];for(o[r-1]=(t[r]+i[r-1])/2,e=0;e=0&&(this._t=1-this._t,this._line=1-this._line)},point:function(t,e){switch(t=+t,e=+e,this._point){case 0:this._point=1,this._line?this._context.lineTo(t,e):this._context.moveTo(t,e);break;case 1:this._point=2;default:if(this._t<=0)this._context.lineTo(this._x,e),this._context.lineTo(t,e);else{var n=this._x*(1-this._t)+t*this._t;this._context.lineTo(n,this._y),this._context.lineTo(n,e)}}this._x=t,this._y=e}};const Tc=()=>"undefined"!=typeof Image?Image:null;function Bc(t,e){switch(arguments.length){case 0:break;case 1:this.range(t);break;default:this.range(e).domain(t)}return this}function zc(t,e){switch(arguments.length){case 0:break;case 1:"function"==typeof t?this.interpolator(t):this.range(t);break;default:this.domain(t),"function"==typeof 
e?this.interpolator(e):this.range(e)}return this}const Nc=Symbol("implicit");function Oc(){var t=new ue,e=[],n=[],r=Nc;function i(i){let o=t.get(i);if(void 0===o){if(r!==Nc)return r;t.set(i,o=e.push(i)-1)}return n[o%n.length]}return i.domain=function(n){if(!arguments.length)return e.slice();e=[],t=new ue;for(const r of n)t.has(r)||t.set(r,e.push(r)-1);return i},i.range=function(t){return arguments.length?(n=Array.from(t),i):n.slice()},i.unknown=function(t){return arguments.length?(r=t,i):r},i.copy=function(){return Oc(e,n).unknown(r)},Bc.apply(i,arguments),i}function Rc(t,e,n){t.prototype=e.prototype=n,n.constructor=t}function Uc(t,e){var n=Object.create(t.prototype);for(var r in e)n[r]=e[r];return n}function Lc(){}var qc=.7,Pc=1/qc,jc="\\s*([+-]?\\d+)\\s*",Ic="\\s*([+-]?(?:\\d*\\.)?\\d+(?:[eE][+-]?\\d+)?)\\s*",Wc="\\s*([+-]?(?:\\d*\\.)?\\d+(?:[eE][+-]?\\d+)?)%\\s*",Hc=/^#([0-9a-f]{3,8})$/,Yc=new RegExp(`^rgb\\(${jc},${jc},${jc}\\)$`),Gc=new RegExp(`^rgb\\(${Wc},${Wc},${Wc}\\)$`),Vc=new RegExp(`^rgba\\(${jc},${jc},${jc},${Ic}\\)$`),Xc=new RegExp(`^rgba\\(${Wc},${Wc},${Wc},${Ic}\\)$`),Jc=new RegExp(`^hsl\\(${Ic},${Wc},${Wc}\\)$`),Zc=new 
RegExp(`^hsla\\(${Ic},${Wc},${Wc},${Ic}\\)$`),Qc={aliceblue:15792383,antiquewhite:16444375,aqua:65535,aquamarine:8388564,azure:15794175,beige:16119260,bisque:16770244,black:0,blanchedalmond:16772045,blue:255,blueviolet:9055202,brown:10824234,burlywood:14596231,cadetblue:6266528,chartreuse:8388352,chocolate:13789470,coral:16744272,cornflowerblue:6591981,cornsilk:16775388,crimson:14423100,cyan:65535,darkblue:139,darkcyan:35723,darkgoldenrod:12092939,darkgray:11119017,darkgreen:25600,darkgrey:11119017,darkkhaki:12433259,darkmagenta:9109643,darkolivegreen:5597999,darkorange:16747520,darkorchid:10040012,darkred:9109504,darksalmon:15308410,darkseagreen:9419919,darkslateblue:4734347,darkslategray:3100495,darkslategrey:3100495,darkturquoise:52945,darkviolet:9699539,deeppink:16716947,deepskyblue:49151,dimgray:6908265,dimgrey:6908265,dodgerblue:2003199,firebrick:11674146,floralwhite:16775920,forestgreen:2263842,fuchsia:16711935,gainsboro:14474460,ghostwhite:16316671,gold:16766720,goldenrod:14329120,gray:8421504,green:32768,greenyellow:11403055,grey:8421504,honeydew:15794160,hotpink:16738740,indianred:13458524,indigo:4915330,ivory:16777200,khaki:15787660,lavender:15132410,lavenderblush:16773365,lawngreen:8190976,lemonchiffon:16775885,lightblue:11393254,lightcoral:15761536,lightcyan:14745599,lightgoldenrodyellow:16448210,lightgray:13882323,lightgreen:9498256,lightgrey:13882323,lightpink:16758465,lightsalmon:16752762,lightseagreen:2142890,lightskyblue:8900346,lightslategray:7833753,lightslategrey:7833753,lightsteelblue:11584734,lightyellow:16777184,lime:65280,limegreen:3329330,linen:16445670,magenta:16711935,maroon:8388608,mediumaquamarine:6737322,mediumblue:205,mediumorchid:12211667,mediumpurple:9662683,mediumseagreen:3978097,mediumslateblue:8087790,mediumspringgreen:64154,mediumturquoise:4772300,mediumvioletred:13047173,midnightblue:1644912,mintcream:16121850,mistyrose:16770273,moccasin:16770229,navajowhite:16768685,navy:128,oldlace:16643558,olive:8421376,olivedrab:7048739,ora
nge:16753920,orangered:16729344,orchid:14315734,palegoldenrod:15657130,palegreen:10025880,paleturquoise:11529966,palevioletred:14381203,papayawhip:16773077,peachpuff:16767673,peru:13468991,pink:16761035,plum:14524637,powderblue:11591910,purple:8388736,rebeccapurple:6697881,red:16711680,rosybrown:12357519,royalblue:4286945,saddlebrown:9127187,salmon:16416882,sandybrown:16032864,seagreen:3050327,seashell:16774638,sienna:10506797,silver:12632256,skyblue:8900331,slateblue:6970061,slategray:7372944,slategrey:7372944,snow:16775930,springgreen:65407,steelblue:4620980,tan:13808780,teal:32896,thistle:14204888,tomato:16737095,turquoise:4251856,violet:15631086,wheat:16113331,white:16777215,whitesmoke:16119285,yellow:16776960,yellowgreen:10145074};function Kc(){return this.rgb().formatHex()}function tf(){return this.rgb().formatRgb()}function ef(t){var e,n;return t=(t+"").trim().toLowerCase(),(e=Hc.exec(t))?(n=e[1].length,e=parseInt(e[1],16),6===n?nf(e):3===n?new sf(e>>8&15|e>>4&240,e>>4&15|240&e,(15&e)<<4|15&e,1):8===n?rf(e>>24&255,e>>16&255,e>>8&255,(255&e)/255):4===n?rf(e>>12&15|e>>8&240,e>>8&15|e>>4&240,e>>4&15|240&e,((15&e)<<4|15&e)/255):null):(e=Yc.exec(t))?new sf(e[1],e[2],e[3],1):(e=Gc.exec(t))?new sf(255*e[1]/100,255*e[2]/100,255*e[3]/100,1):(e=Vc.exec(t))?rf(e[1],e[2],e[3],e[4]):(e=Xc.exec(t))?rf(255*e[1]/100,255*e[2]/100,255*e[3]/100,e[4]):(e=Jc.exec(t))?df(e[1],e[2]/100,e[3]/100,1):(e=Zc.exec(t))?df(e[1],e[2]/100,e[3]/100,e[4]):Qc.hasOwnProperty(t)?nf(Qc[t]):"transparent"===t?new sf(NaN,NaN,NaN,0):null}function nf(t){return new sf(t>>16&255,t>>8&255,255&t,1)}function rf(t,e,n,r){return r<=0&&(t=e=n=NaN),new sf(t,e,n,r)}function of(t){return t instanceof Lc||(t=ef(t)),t?new sf((t=t.rgb()).r,t.g,t.b,t.opacity):new sf}function af(t,e,n,r){return 1===arguments.length?of(t):new sf(t,e,n,null==r?1:r)}function sf(t,e,n,r){this.r=+t,this.g=+e,this.b=+n,this.opacity=+r}function uf(){return`#${hf(this.r)}${hf(this.g)}${hf(this.b)}`}function lf(){const 
t=cf(this.opacity);return`${1===t?"rgb(":"rgba("}${ff(this.r)}, ${ff(this.g)}, ${ff(this.b)}${1===t?")":`, ${t})`}`}function cf(t){return isNaN(t)?1:Math.max(0,Math.min(1,t))}function ff(t){return Math.max(0,Math.min(255,Math.round(t)||0))}function hf(t){return((t=ff(t))<16?"0":"")+t.toString(16)}function df(t,e,n,r){return r<=0?t=e=n=NaN:n<=0||n>=1?t=e=NaN:e<=0&&(t=NaN),new mf(t,e,n,r)}function pf(t){if(t instanceof mf)return new mf(t.h,t.s,t.l,t.opacity);if(t instanceof Lc||(t=ef(t)),!t)return new mf;if(t instanceof mf)return t;var e=(t=t.rgb()).r/255,n=t.g/255,r=t.b/255,i=Math.min(e,n,r),o=Math.max(e,n,r),a=NaN,s=o-i,u=(o+i)/2;return s?(a=e===o?(n-r)/s+6*(n0&&u<1?0:a,new mf(a,s,u,t.opacity)}function gf(t,e,n,r){return 1===arguments.length?pf(t):new mf(t,e,n,null==r?1:r)}function mf(t,e,n,r){this.h=+t,this.s=+e,this.l=+n,this.opacity=+r}function yf(t){return(t=(t||0)%360)<0?t+360:t}function vf(t){return Math.max(0,Math.min(1,t||0))}function _f(t,e,n){return 255*(t<60?e+(n-e)*t/60:t<180?n:t<240?e+(n-e)*(240-t)/60:e)}Rc(Lc,ef,{copy(t){return Object.assign(new this.constructor,this,t)},displayable(){return this.rgb().displayable()},hex:Kc,formatHex:Kc,formatHex8:function(){return this.rgb().formatHex8()},formatHsl:function(){return pf(this).formatHsl()},formatRgb:tf,toString:tf}),Rc(sf,af,Uc(Lc,{brighter(t){return t=null==t?Pc:Math.pow(Pc,t),new sf(this.r*t,this.g*t,this.b*t,this.opacity)},darker(t){return t=null==t?qc:Math.pow(qc,t),new sf(this.r*t,this.g*t,this.b*t,this.opacity)},rgb(){return this},clamp(){return new sf(ff(this.r),ff(this.g),ff(this.b),cf(this.opacity))},displayable(){return-.5<=this.r&&this.r<255.5&&-.5<=this.g&&this.g<255.5&&-.5<=this.b&&this.b<255.5&&0<=this.opacity&&this.opacity<=1},hex:uf,formatHex:uf,formatHex8:function(){return`#${hf(this.r)}${hf(this.g)}${hf(this.b)}${hf(255*(isNaN(this.opacity)?1:this.opacity))}`},formatRgb:lf,toString:lf})),Rc(mf,gf,Uc(Lc,{brighter(t){return t=null==t?Pc:Math.pow(Pc,t),new 
mf(this.h,this.s,this.l*t,this.opacity)},darker(t){return t=null==t?qc:Math.pow(qc,t),new mf(this.h,this.s,this.l*t,this.opacity)},rgb(){var t=this.h%360+360*(this.h<0),e=isNaN(t)||isNaN(this.s)?0:this.s,n=this.l,r=n+(n<.5?n:1-n)*e,i=2*n-r;return new sf(_f(t>=240?t-240:t+120,i,r),_f(t,i,r),_f(t<120?t+240:t-120,i,r),this.opacity)},clamp(){return new mf(yf(this.h),vf(this.s),vf(this.l),cf(this.opacity))},displayable(){return(0<=this.s&&this.s<=1||isNaN(this.s))&&0<=this.l&&this.l<=1&&0<=this.opacity&&this.opacity<=1},formatHsl(){const t=cf(this.opacity);return`${1===t?"hsl(":"hsla("}${yf(this.h)}, ${100*vf(this.s)}%, ${100*vf(this.l)}%${1===t?")":`, ${t})`}`}}));const xf=Math.PI/180,bf=180/Math.PI,wf=.96422,kf=1,Af=.82521,Mf=4/29,Ef=6/29,Df=3*Ef*Ef,Cf=Ef*Ef*Ef;function Ff(t){if(t instanceof $f)return new $f(t.l,t.a,t.b,t.opacity);if(t instanceof Rf)return Uf(t);t instanceof sf||(t=of(t));var e,n,r=Nf(t.r),i=Nf(t.g),o=Nf(t.b),a=Tf((.2225045*r+.7168786*i+.0606169*o)/kf);return r===i&&i===o?e=n=a:(e=Tf((.4360747*r+.3850649*i+.1430804*o)/wf),n=Tf((.0139322*r+.0971045*i+.7141733*o)/Af)),new $f(116*a-16,500*(e-a),200*(a-n),t.opacity)}function Sf(t,e,n,r){return 1===arguments.length?Ff(t):new $f(t,e,n,null==r?1:r)}function $f(t,e,n,r){this.l=+t,this.a=+e,this.b=+n,this.opacity=+r}function Tf(t){return t>Cf?Math.pow(t,1/3):t/Df+Mf}function Bf(t){return t>Ef?t*t*t:Df*(t-Mf)}function zf(t){return 255*(t<=.0031308?12.92*t:1.055*Math.pow(t,1/2.4)-.055)}function Nf(t){return(t/=255)<=.04045?t/12.92:Math.pow((t+.055)/1.055,2.4)}function Of(t,e,n,r){return 1===arguments.length?function(t){if(t instanceof Rf)return new Rf(t.h,t.c,t.l,t.opacity);if(t instanceof $f||(t=Ff(t)),0===t.a&&0===t.b)return new Rf(NaN,0=1?(n=1,e-1):Math.floor(n*e),i=t[r],o=t[r+1],a=r>0?t[r-1]:2*i-o,s=r()=>t;function Kf(t,e){return function(n){return t+n*e}}function th(t,e){var n=e-t;return n?Kf(t,n>180||n<-180?n-360*Math.round(n/360):n):Qf(isNaN(t)?e:t)}function eh(t){return 1==(t=+t)?nh:function(e,n){return 
n-e?function(t,e,n){return t=Math.pow(t,n),e=Math.pow(e,n)-t,n=1/n,function(r){return Math.pow(t+r*e,n)}}(e,n,t):Qf(isNaN(e)?n:e)}}function nh(t,e){var n=e-t;return n?Kf(t,n):Qf(isNaN(t)?e:t)}var rh=function t(e){var n=eh(e);function r(t,e){var r=n((t=af(t)).r,(e=af(e)).r),i=n(t.g,e.g),o=n(t.b,e.b),a=nh(t.opacity,e.opacity);return function(e){return t.r=r(e),t.g=i(e),t.b=o(e),t.opacity=a(e),t+""}}return r.gamma=t,r}(1);function ih(t){return function(e){var n,r,i=e.length,o=new Array(i),a=new Array(i),s=new Array(i);for(n=0;no&&(i=e.slice(o,i),s[a]?s[a]+=i:s[++a]=i),(n=n[0])===(r=r[0])?s[a]?s[a]+=r:s[++a]=r:(s[++a]=null,u.push({i:a,x:fh(n,r)})),o=ph.lastIndex;return o180?e+=360:e-t>180&&(t+=360),o.push({i:n.push(i(n)+"rotate(",null,r)-2,x:fh(t,e)})):e&&n.push(i(n)+"rotate("+e+r)}(o.rotate,a.rotate,s,u),function(t,e,n,o){t!==e?o.push({i:n.push(i(n)+"skewX(",null,r)-2,x:fh(t,e)}):e&&n.push(i(n)+"skewX("+e+r)}(o.skewX,a.skewX,s,u),function(t,e,n,r,o,a){if(t!==n||e!==r){var s=o.push(i(o)+"scale(",null,",",null,")");a.push({i:s-4,x:fh(t,n)},{i:s-2,x:fh(e,r)})}else 1===n&&1===r||o.push(i(o)+"scale("+n+","+r+")")}(o.scaleX,o.scaleY,a.scaleX,a.scaleY,s,u),o=a=null,function(t){for(var e,n=-1,r=u.length;++ne&&(n=t,t=e,e=n),function(n){return Math.max(t,Math.min(e,n))}}(a[0],a[t-1])),r=t>2?Ih:jh,i=o=null,f}function f(e){return null==e||isNaN(e=+e)?n:(i||(i=r(a.map(t),s,u)))(t(l(e)))}return f.invert=function(n){return l(e((o||(o=r(s,a.map(t),fh)))(n)))},f.domain=function(t){return arguments.length?(a=Array.from(t,Uh),c()):a.slice()},f.range=function(t){return arguments.length?(s=Array.from(t),c()):s.slice()},f.rangeRound=function(t){return s=Array.from(t),u=yh,c()},f.clamp=function(t){return arguments.length?(l=!!t||qh,c()):l!==qh},f.interpolate=function(t){return arguments.length?(u=t,c()):u},f.unknown=function(t){return arguments.length?(n=t,f):n},function(n,r){return t=n,e=r,c()}}function Yh(){return Hh()(qh,qh)}function Gh(t,e,n,r){var 
i,o=be(t,e,n);switch((r=Re(null==r?",f":r)).type){case"s":var a=Math.max(Math.abs(t),Math.abs(e));return null!=r.precision||isNaN(i=Xe(o,a))||(r.precision=i),We(r,a);case"":case"e":case"g":case"p":case"r":null!=r.precision||isNaN(i=Je(o,Math.max(Math.abs(t),Math.abs(e))))||(r.precision=i-("e"===r.type));break;case"f":case"%":null!=r.precision||isNaN(i=Ve(o))||(r.precision=i-2*("%"===r.type))}return Ie(r)}function Vh(t){var e=t.domain;return t.ticks=function(t){var n=e();return _e(n[0],n[n.length-1],null==t?10:t)},t.tickFormat=function(t,n){var r=e();return Gh(r[0],r[r.length-1],null==t?10:t,n)},t.nice=function(n){null==n&&(n=10);var r,i,o=e(),a=0,s=o.length-1,u=o[a],l=o[s],c=10;for(l0;){if((i=xe(u,l,n))===r)return o[a]=u,o[s]=l,e(o);if(i>0)u=Math.floor(u/i)*i,l=Math.ceil(l/i)*i;else{if(!(i<0))break;u=Math.ceil(u*i)/i,l=Math.floor(l*i)/i}r=i}return t},t}function Xh(t,e){var n,r=0,i=(t=t.slice()).length-1,o=t[r],a=t[i];return a-t(-e,n)}function nd(t){const e=t(Jh,Zh),n=e.domain;let r,i,o=10;function a(){return r=function(t){return t===Math.E?Math.log:10===t&&Math.log10||2===t&&Math.log2||(t=Math.log(t),e=>Math.log(e)/t)}(o),i=function(t){return 10===t?td:t===Math.E?Math.exp:e=>Math.pow(t,e)}(o),n()[0]<0?(r=ed(r),i=ed(i),t(Qh,Kh)):t(Jh,Zh),e}return e.base=function(t){return arguments.length?(o=+t,a()):o},e.domain=function(t){return arguments.length?(n(t),a()):n()},e.ticks=t=>{const e=n();let a=e[0],s=e[e.length-1];const u=s0){for(;f<=h;++f)for(l=1;ls)break;p.push(c)}}else for(;f<=h;++f)for(l=o-1;l>=1;--l)if(c=f>0?l/i(-f):l*i(f),!(cs)break;p.push(c)}2*p.length{if(null==t&&(t=10),null==n&&(n=10===o?"s":","),"function"!=typeof n&&(o%1||null!=(n=Re(n)).precision||(n.trim=!0),n=Ie(n)),t===1/0)return n;const a=Math.max(1,o*t/e.ticks().length);return t=>{let e=t/i(Math.round(r(t)));return e*on(Xh(n(),{floor:t=>i(Math.floor(r(t))),ceil:t=>i(Math.ceil(r(t)))})),e}function rd(t){return function(e){return Math.sign(e)*Math.log1p(Math.abs(e/t))}}function id(t){return 
function(e){return Math.sign(e)*Math.expm1(Math.abs(e))*t}}function od(t){var e=1,n=t(rd(e),id(e));return n.constant=function(n){return arguments.length?t(rd(e=+n),id(e)):e},Vh(n)}function ad(t){return function(e){return e<0?-Math.pow(-e,t):Math.pow(e,t)}}function sd(t){return t<0?-Math.sqrt(-t):Math.sqrt(t)}function ud(t){return t<0?-t*t:t*t}function ld(t){var e=t(qh,qh),n=1;return e.exponent=function(e){return arguments.length?1===(n=+e)?t(qh,qh):.5===n?t(sd,ud):t(ad(n),ad(1/n)):n},Vh(e)}function cd(){var t=ld(Hh());return t.copy=function(){return Wh(t,cd()).exponent(t.exponent())},Bc.apply(t,arguments),t}function fd(t){return new Date(t)}function hd(t){return t instanceof Date?+t:+new Date(+t)}function dd(t,e,n,r,i,o,a,s,u,l){var c=Yh(),f=c.invert,h=c.domain,d=l(".%L"),p=l(":%S"),g=l("%I:%M"),m=l("%I %p"),y=l("%a %d"),v=l("%b %d"),_=l("%B"),x=l("%Y");function b(t){return(u(t)0?r:1:0}const bd="linear",wd="log",kd="pow",Ad="sqrt",Md="symlog",Ed="time",Dd="utc",Cd="sequential",Fd="diverging",Sd="quantile",$d="quantize",Td="threshold",Bd="ordinal",zd="point",Nd="band",Od="bin-ordinal",Rd="continuous",Ud="discrete",Ld="discretizing",qd="interpolating",Pd="temporal";function jd(){const t=Oc().unknown(void 0),e=t.domain,n=t.range;let r,i,o=[0,1],a=!1,s=0,u=0,l=.5;function c(){const t=e().length,c=o[1]d+r*t));return n(c?p.reverse():p)}return delete t.unknown,t.domain=function(t){return arguments.length?(e(t),c()):e()},t.range=function(t){return arguments.length?(o=[+t[0],+t[1]],c()):o.slice()},t.rangeRound=function(t){return o=[+t[0],+t[1]],a=!0,c()},t.bandwidth=function(){return i},t.step=function(){return r},t.round=function(t){return arguments.length?(a=!!t,c()):a},t.padding=function(t){return arguments.length?(u=Math.max(0,Math.min(1,t)),s=u,c()):s},t.paddingInner=function(t){return arguments.length?(s=Math.max(0,Math.min(1,t)),c()):s},t.paddingOuter=function(t){return arguments.length?(u=Math.max(0,Math.min(1,t)),c()):u},t.align=function(t){return 
arguments.length?(l=Math.max(0,Math.min(1,t)),c()):l},t.invertRange=function(t){if(null==t[0]||null==t[1])return;const r=o[1]o[1-r])?void 0:(u=Math.max(0,oe(a,f)-1),l=f===h?u:oe(a,h)-1,f-a[u]>i+1e-10&&++u,r&&(c=u,u=s-l,l=s-c),u>l?void 0:e().slice(u,l+1))},t.invert=function(e){const n=t.invertRange([e,e]);return n?n[0]:n},t.copy=function(){return jd().domain(e()).range(o).round(a).paddingInner(s).paddingOuter(u).align(l)},c()}function Id(t){const e=t.copy;return t.padding=t.paddingOuter,delete t.paddingInner,t.copy=function(){return Id(e())},t}var Wd=Array.prototype.map;const Hd=Array.prototype.slice;const Yd=new Map,Gd=Symbol("vega_scale");function Vd(t){return t[Gd]=!0,t}function Xd(t,e,n){return arguments.length>1?(Yd.set(t,function(t,e,n){const r=function(){const n=e();return n.invertRange||(n.invertRange=n.invert?function(t){return function(e){let n,r=e[0],i=e[1];return i=s&&n[o]<=u&&(l<0&&(l=o),r=o);if(!(l<0))return s=t.invertExtent(n[l]),u=t.invertExtent(n[r]),[void 0===s[0]?s[1]:s[0],void 0===u[1]?u[0]:u[1]]}}(n):void 0),n.type=t,Vd(n)};return r.metadata=Bt(V(n)),r}(t,e,n)),this):Jd(t)?Yd.get(t):void 0}function Jd(t){return Yd.has(t)}function Zd(t,e){const n=Yd.get(t);return n&&n.metadata[e]}function Qd(t){return Zd(t,Rd)}function Kd(t){return Zd(t,Ud)}function tp(t){return Zd(t,Ld)}function ep(t){return Zd(t,wd)}function np(t){return Zd(t,qd)}function rp(t){return Zd(t,Sd)}Xd("identity",(function t(e){var n;function r(t){return null==t||isNaN(t=+t)?n:t}return r.invert=r,r.domain=r.range=function(t){return arguments.length?(e=Array.from(t,Uh),r):e.slice()},r.unknown=function(t){return arguments.length?(n=t,r):n},r.copy=function(){return t(e).unknown(n)},e=arguments.length?Array.from(e,Uh):[0,1],Vh(r)})),Xd(bd,(function t(){var e=Yh();return e.copy=function(){return Wh(e,t())},Bc.apply(e,arguments),Vh(e)}),Rd),Xd(wd,(function t(){const e=nd(Hh()).domain([1,10]);return 
e.copy=()=>Wh(e,t()).base(e.base()),Bc.apply(e,arguments),e}),[Rd,wd]),Xd(kd,cd,Rd),Xd(Ad,(function(){return cd.apply(null,arguments).exponent(.5)}),Rd),Xd(Md,(function t(){var e=od(Hh());return e.copy=function(){return Wh(e,t()).constant(e.constant())},Bc.apply(e,arguments)}),Rd),Xd(Ed,(function(){return Bc.apply(dd(qn,Pn,Nn,Bn,vn,pn,hn,cn,ln,ni).domain([new Date(2e3,0,1),new Date(2e3,0,2)]),arguments)}),[Rd,Pd]),Xd(Dd,(function(){return Bc.apply(dd(Un,Ln,On,zn,En,gn,dn,fn,ln,ii).domain([Date.UTC(2e3,0,1),Date.UTC(2e3,0,2)]),arguments)}),[Rd,Pd]),Xd(Cd,md,[Rd,qd]),Xd(`${Cd}-${bd}`,md,[Rd,qd]),Xd(`${Cd}-${wd}`,(function t(){var e=nd(pd()).domain([1,10]);return e.copy=function(){return gd(e,t()).base(e.base())},zc.apply(e,arguments)}),[Rd,qd,wd]),Xd(`${Cd}-${kd}`,yd,[Rd,qd]),Xd(`${Cd}-${Ad}`,(function(){return yd.apply(null,arguments).exponent(.5)}),[Rd,qd]),Xd(`${Cd}-${Md}`,(function t(){var e=od(pd());return e.copy=function(){return gd(e,t()).constant(e.constant())},zc.apply(e,arguments)}),[Rd,qd]),Xd(`${Fd}-${bd}`,(function t(){var e=Vh(vd()(qh));return e.copy=function(){return gd(e,t())},zc.apply(e,arguments)}),[Rd,qd]),Xd(`${Fd}-${wd}`,(function t(){var e=nd(vd()).domain([.1,1,10]);return e.copy=function(){return gd(e,t()).base(e.base())},zc.apply(e,arguments)}),[Rd,qd,wd]),Xd(`${Fd}-${kd}`,_d,[Rd,qd]),Xd(`${Fd}-${Ad}`,(function(){return _d.apply(null,arguments).exponent(.5)}),[Rd,qd]),Xd(`${Fd}-${Md}`,(function t(){var e=od(vd());return e.copy=function(){return gd(e,t()).constant(e.constant())},zc.apply(e,arguments)}),[Rd,qd]),Xd(Sd,(function t(){var e,n=[],r=[],i=[];function o(){var t=0,e=Math.max(1,r.length);for(i=new Array(e-1);++t0?i[e-1]:n[0],e=i?[o[i-1],r]:[o[e-1],o[e]]},s.unknown=function(t){return arguments.length?(e=t,s):s},s.thresholds=function(){return o.slice()},s.copy=function(){return t().domain([n,r]).range(a).unknown(e)},Bc.apply(Vh(s),arguments)}),Ld),Xd(Td,(function t(){var e,n=[.5],r=[0,1],i=1;function o(t){return 
null!=t&&t<=t?r[oe(n,t,0,i)]:e}return o.domain=function(t){return arguments.length?(n=Array.from(t),i=Math.min(n.length,r.length-1),o):n.slice()},o.range=function(t){return arguments.length?(r=Array.from(t),i=Math.min(n.length,r.length-1),o):r.slice()},o.invertExtent=function(t){var e=r.indexOf(t);return[n[e-1],n[e]]},o.unknown=function(t){return arguments.length?(e=t,o):e},o.copy=function(){return t().domain(n).range(r).unknown(e)},Bc.apply(o,arguments)}),Ld),Xd(Od,(function t(){let e=[],n=[];function r(t){return null==t||t!=t?void 0:n[(oe(e,t)-1)%n.length]}return r.domain=function(t){return arguments.length?(e=function(t){return Wd.call(t,S)}(t),r):e.slice()},r.range=function(t){return arguments.length?(n=Hd.call(t),r):n.slice()},r.tickFormat=function(t,n){return Gh(e[0],F(e),null==t?10:t,n)},r.copy=function(){return t().domain(r.domain()).range(r.range())},r}),[Ud,Ld]),Xd(Bd,Oc,Ud),Xd(Nd,jd,Ud),Xd(zd,(function(){return Id(jd().paddingInner(1))}),Ud);const ip=["clamp","base","constant","exponent"];function op(t,e){const n=e[0],r=F(e)-n;return function(e){return t(n+e*r)}}function ap(t,e,n){return Oh(lp(e||"rgb",n),t)}function sp(t,e){const n=new Array(e),r=e+1;for(let i=0;it[e]?a[e](t[e]()):0)),a):rt(.5)}function lp(t,e){const n=Rh[function(t){return"interpolate"+t.toLowerCase().split("-").map((t=>t[0].toUpperCase()+t.slice(1))).join("")}(t)];return null!=e&&n&&n.gamma?n.gamma(e):n}function cp(t){const e=t.length/6|0,n=new Array(e);for(let 
r=0;r1?(hp[t]=e,this):hp[t]}fp({category10:"1f77b4ff7f0e2ca02cd627289467bd8c564be377c27f7f7fbcbd2217becf",category20:"1f77b4aec7e8ff7f0effbb782ca02c98df8ad62728ff98969467bdc5b0d58c564bc49c94e377c2f7b6d27f7f7fc7c7c7bcbd22dbdb8d17becf9edae5",category20b:"393b795254a36b6ecf9c9ede6379398ca252b5cf6bcedb9c8c6d31bd9e39e7ba52e7cb94843c39ad494ad6616be7969c7b4173a55194ce6dbdde9ed6",category20c:"3182bd6baed69ecae1c6dbefe6550dfd8d3cfdae6bfdd0a231a35474c476a1d99bc7e9c0756bb19e9ac8bcbddcdadaeb636363969696bdbdbdd9d9d9",tableau10:"4c78a8f58518e4575672b7b254a24beeca3bb279a2ff9da69d755dbab0ac",tableau20:"4c78a89ecae9f58518ffbf7954a24b88d27ab79a20f2cf5b43989483bcb6e45756ff9d9879706ebab0acd67195fcbfd2b279a2d6a5c99e765fd8b5a5",accent:"7fc97fbeaed4fdc086ffff99386cb0f0027fbf5b17666666",dark2:"1b9e77d95f027570b3e7298a66a61ee6ab02a6761d666666",paired:"a6cee31f78b4b2df8a33a02cfb9a99e31a1cfdbf6fff7f00cab2d66a3d9affff99b15928",pastel1:"fbb4aeb3cde3ccebc5decbe4fed9a6ffffcce5d8bdfddaecf2f2f2",pastel2:"b3e2cdfdcdaccbd5e8f4cae4e6f5c9fff2aef1e2cccccccc",set1:"e41a1c377eb84daf4a984ea3ff7f00ffff33a65628f781bf999999",set2:"66c2a5fc8d628da0cbe78ac3a6d854ffd92fe5c494b3b3b3",set3:"8dd3c7ffffb3bebadafb807280b1d3fdb462b3de69fccde5d9d9d9bc80bdccebc5ffed6f"},cp),fp({blues:"cfe1f2bed8eca8cee58fc1de74b2d75ba3cf4592c63181bd206fb2125ca40a4a90",greens:"d3eecdc0e6baabdda594d3917bc77d60ba6c46ab5e329a512089430e7735036429",greys:"e2e2e2d4d4d4c4c4c4b1b1b19d9d9d8888887575756262624d4d4d3535351e1e1e",oranges:"fdd8b3fdc998fdb87bfda55efc9244f87f2cf06b18e4580bd14904b93d029f3303",purples:"e2e1efd4d4e8c4c5e0b4b3d6a3a0cc928ec3827cb97566ae684ea25c3696501f8c",reds:"fdc9b4fcb49afc9e80fc8767fa7051f6573fec3f2fdc2a25c81b1db21218970b13",blueGreen:"d5efedc1e8e0a7ddd18bd2be70c6a958ba9144ad77319c5d2089460e7736036429",bluePurple:"ccddecbad0e4a8c2dd9ab0d4919cc98d85be8b6db28a55a6873c99822287730f71",greenBlue:"d3eecec5e8c3b1e1bb9bd8bb82cec269c2ca51b2cd3c9fc7288abd1675b10b60a1",orangeRed:"fddcaffdcf9bfdc18afdad77fb9562f67d53ee6545e24932d32d1
ebf130da70403",purpleBlue:"dbdaebc8cee4b1c3de97b7d87bacd15b9fc93a90c01e7fb70b70ab056199045281",purpleBlueGreen:"dbd8eac8cee4b0c3de93b7d872acd1549fc83892bb1c88a3097f8702736b016353",purpleRed:"dcc9e2d3b3d7ce9eccd186c0da6bb2e14da0e23189d91e6fc61159ab07498f023a",redPurple:"fccfccfcbec0faa9b8f98faff571a5ec539ddb3695c41b8aa908808d0179700174",yellowGreen:"e4f4acd1eca0b9e2949ed68880c97c62bb6e47aa5e3297502083440e723b036034",yellowOrangeBrown:"feeaa1fedd84fecc63feb746fca031f68921eb7215db5e0bc54c05ab3d038f3204",yellowOrangeRed:"fee087fed16ffebd59fea849fd903efc7335f9522bee3423de1b20ca0b22af0225",blueOrange:"134b852f78b35da2cb9dcae1d2e5eff2f0ebfce0bafbbf74e8932fc5690d994a07",brownBlueGreen:"704108a0651ac79548e3c78af3e6c6eef1eac9e9e48ed1c74da79e187a72025147",purpleGreen:"5b1667834792a67fb6c9aed3e6d6e8eff0efd9efd5aedda971bb75368e490e5e29",purpleOrange:"4114696647968f83b7b9b4d6dadbebf3eeeafce0bafbbf74e8932fc5690d994a07",redBlue:"8c0d25bf363adf745ef4ae91fbdbc9f2efeed2e5ef9dcae15da2cb2f78b3134b85",redGrey:"8c0d25bf363adf745ef4ae91fcdccbfaf4f1e2e2e2c0c0c0969696646464343434",yellowGreenBlue:"eff9bddbf1b4bde5b594d5b969c5be45b4c22c9ec02182b82163aa23479c1c3185",redYellowBlue:"a50026d4322cf16e43fcac64fedd90faf8c1dcf1ecabd6e875abd04a74b4313695",redYellowGreen:"a50026d4322cf16e43fcac63fedd8df9f7aed7ee8ea4d86e64bc6122964f006837",pinkYellowGreen:"8e0152c0267edd72adf0b3d6faddedf5f3efe1f2cab6de8780bb474f9125276419",spectral:"9e0142d13c4bf0704afcac63fedd8dfbf8b0e0f3a1a9dda269bda94288b55e4fa2",viridis:"440154470e61481a6c482575472f7d443a834144873d4e8a39568c35608d31688e2d708e2a788e27818e23888e21918d1f988b1fa08822a8842ab07f35b77943bf7154c56866cc5d7ad1518fd744a5db36bcdf27d2e21be9e51afde725",magma:"0000040404130b0924150e3720114b2c11603b0f704a107957157e651a80721f817f24828c29819a2e80a8327db6377ac43c75d1426fde4968e95462f1605df76f5cfa7f5efc8f65fe9f6dfeaf78febf84fece91fddea0fcedaffcfdbf",inferno:"0000040403130c0826170c3b240c4f330a5f420a68500d6c5d126e6b176e781c6d86216b932667a12b62ae305cbb3755c73e4cd24644dd51
3ae65c30ed6925f3771af8850ffb9506fca50afcb519fac62df6d645f2e661f3f484fcffa4",plasma:"0d088723069033059742039d5002a25d01a66a00a87801a88405a7900da49c179ea72198b12a90ba3488c33d80cb4779d35171da5a69e16462e76e5bed7953f2834cf68f44fa9a3dfca636fdb32ffec029fcce25f9dc24f5ea27f0f921",cividis:"00205100235800265d002961012b65042e670831690d346b11366c16396d1c3c6e213f6e26426e2c456e31476e374a6e3c4d6e42506e47536d4c566d51586e555b6e5a5e6e5e616e62646f66676f6a6a706e6d717270717573727976737c79747f7c75827f758682768985778c8877908b78938e789691789a94789e9778a19b78a59e77a9a177aea575b2a874b6ab73bbaf71c0b26fc5b66dc9b96acebd68d3c065d8c462ddc85fe2cb5ce7cf58ebd355f0d652f3da4ff7de4cfae249fce647",rainbow:"6e40aa883eb1a43db3bf3cafd83fa4ee4395fe4b83ff576eff6659ff7847ff8c38f3a130e2b72fcfcc36bee044aff05b8ff4576ff65b52f6673af27828ea8d1ddfa319d0b81cbecb23abd82f96e03d82e14c6edb5a5dd0664dbf6e40aa",sinebow:"ff4040fc582af47218e78d0bd5a703bfbf00a7d5038de70b72f41858fc2a40ff402afc5818f4720be78d03d5a700bfbf03a7d50b8de71872f42a58fc4040ff582afc7218f48d0be7a703d5bf00bfd503a7e70b8df41872fc2a58ff4040",turbo:"23171b32204a3e2a71453493493eae4b49c54a53d7485ee44569ee4074f53c7ff8378af93295f72e9ff42ba9ef28b3e926bce125c5d925cdcf27d5c629dcbc2de3b232e9a738ee9d3ff39347f68950f9805afc7765fd6e70fe667cfd5e88fc5795fb51a1f84badf545b9f140c5ec3cd0e637dae034e4d931ecd12ef4c92bfac029ffb626ffad24ffa223ff9821ff8d1fff821dff771cfd6c1af76118f05616e84b14df4111d5380fcb2f0dc0260ab61f07ac1805a313029b0f00950c00910b00",browns:"eedbbdecca96e9b97ae4a865dc9856d18954c7784cc0673fb85536ad44339f3632",tealBlues:"bce4d89dd3d181c3cb65b3c245a2b9368fae347da0306a932c5985",teals:"bbdfdfa2d4d58ac9c975bcbb61b0af4da5a43799982b8b8c1e7f7f127273006667",warmGreys:"dcd4d0cec5c1c0b8b4b3aaa7a59c9998908c8b827f7e7673726866665c5a59504e",goldGreen:"f4d166d5ca60b6c35c98bb597cb25760a6564b9c533f8f4f33834a257740146c36",goldOrange:"f4d166f8be5cf8aa4cf5983bf3852aef701be2621fd65322c54923b142239e3a26",goldRed:"f4d166f6be59f9aa51fc964ef6834bee734ae56249db5247cf4244c43141b71d3e",lightGreyRed:
"efe9e6e1dad7d5cbc8c8bdb9bbaea9cd967ddc7b43e15f19df4011dc000b",lightGreyTeal:"e4eaead6dcddc8ced2b7c2c7a6b4bc64b0bf22a6c32295c11f85be1876bc",lightMulti:"e0f1f2c4e9d0b0de9fd0e181f6e072f6c053f3993ef77440ef4a3c",lightOrange:"f2e7daf7d5baf9c499fab184fa9c73f68967ef7860e8645bde515bd43d5b",lightTealBlue:"e3e9e0c0dccf9aceca7abfc859afc0389fb9328dad2f7ca0276b95255988",darkBlue:"3232322d46681a5c930074af008cbf05a7ce25c0dd38daed50f3faffffff",darkGold:"3c3c3c584b37725e348c7631ae8b2bcfa424ecc31ef9de30fff184ffffff",darkGreen:"3a3a3a215748006f4d048942489e4276b340a6c63dd2d836ffeb2cffffaa",darkMulti:"3737371f5287197d8c29a86995ce3fffe800ffffff",darkRed:"3434347036339e3c38cc4037e75d1eec8620eeab29f0ce32ffeb2c"},(t=>ap(cp(t))));const pp="symbol",gp="discrete",mp=t=>k(t)?t.map((t=>String(t))):String(t),yp=(t,e)=>t[1]-e[1],vp=(t,e)=>e[1]-t[1];function _p(t,e,n){let r;return vt(e)&&(t.bins&&(e=Math.max(e,t.bins.length)),null!=n&&(e=Math.min(e,Math.floor(Dt(t.domain())/n||1)+1))),A(e)&&(r=e.step,e=e.interval),xt(e)&&(e=t.type===Ed?Cr(e):t.type==Dd?Fr(e):s("Only time and utc scales accept interval strings."),r&&(e=e.every(r))),e}function xp(t,e,n){let r=t.range(),i=r[0],o=F(r),a=yp;if(i>o&&(r=o,o=i,i=r,a=vp),i=Math.floor(i),o=Math.ceil(o),e=e.map((e=>[e,t(e)])).filter((t=>i<=t[1]&&t[1]<=o)).sort(a).map((t=>t[0])),n>0&&e.length>1){const t=[e[0],F(e)];for(;e.length>n&&e.length>=3;)e=e.filter(((t,e)=>!(e%2)));e.length<3&&(e=t)}return e}function bp(t,e){return t.bins?xp(t,t.bins):t.ticks?t.ticks(e):t.domain()}function wp(t,e,n,r,i,o){const a=e.type;let s=mp;if(a===Ed||i===Ed)s=t.timeFormat(r);else if(a===Dd||i===Dd)s=t.utcFormat(r);else if(ep(a)){const i=t.formatFloat(r);if(o||e.bins)s=i;else{const t=kp(e,n,!1);s=e=>t(e)?i(e):""}}else if(e.tickFormat){const i=e.domain();s=t.formatSpan(i[0],i[i.length-1],n,r)}else r&&(s=t.format(r));return s}function kp(t,e,n){const r=bp(t,e),i=t.base(),o=Math.log(i),a=Math.max(1,i*e/r.length),s=t=>{let e=t/Math.pow(i,Math.round(Math.log(t)/o));return 
e*iAp[t.type]||t.bins;function Cp(t,e,n,r,i,o,a){const s=Mp[e.type]&&o!==Ed&&o!==Dd?function(t,e,n){const r=e[Mp[e.type]](),i=r.length;let o,a=i>1?r[1]-r[0]:r[0];for(o=1;o(e,n,r)=>{const i=Sp(r[n+1],Sp(r.max,1/0)),o=Bp(e,t),a=Bp(i,t);return o&&a?o+" – "+a:a?"< "+a:"≥ "+o},Sp=(t,e)=>null!=t?t:e,$p=t=>(e,n)=>n?t(e):null,Tp=t=>e=>t(e),Bp=(t,e)=>Number.isFinite(t)?e(t):null;function zp(t,e,n,r){const i=r||e.type;return xt(n)&&function(t){return Zd(t,Pd)}(i)&&(n=n.replace(/%a/g,"%A").replace(/%b/g,"%B")),n||i!==Ed?n||i!==Dd?Cp(t,e,5,null,n,r,!0):t.utcFormat("%A, %d %B %Y, %X UTC"):t.timeFormat("%A, %d %B %Y, %X")}function Np(t,e,n){n=n||{};const r=Math.max(3,n.maxlen||7),i=zp(t,e,n.format,n.formatType);if(tp(e.type)){const t=Ep(e).slice(1).map(i),n=t.length;return`${n} boundar${1===n?"y":"ies"}: ${t.join(", ")}`}if(Kd(e.type)){const t=e.domain(),n=t.length;return`${n} value${1===n?"":"s"}: ${n>r?t.slice(0,r-2).map(i).join(", ")+", ending with "+t.slice(-1).map(i):t.map(i).join(", ")}`}{const t=e.domain();return`values from ${i(t[0])} to ${i(F(t))}`}}let Op=0;const Rp="p_";function Up(t){return t&&t.gradient}function Lp(t,e,n){const r=t.gradient;let i=t.id,o="radial"===r?Rp:"";return i||(i=t.id="gradient_"+Op++,"radial"===r?(t.x1=qp(t.x1,.5),t.y1=qp(t.y1,.5),t.r1=qp(t.r1,0),t.x2=qp(t.x2,.5),t.y2=qp(t.y2,.5),t.r2=qp(t.r2,.5),o=Rp):(t.x1=qp(t.x1,0),t.y1=qp(t.y1,0),t.x2=qp(t.x2,1),t.y2=qp(t.y2,0))),e[i]=t,"url("+(n||"")+"#"+o+i+")"}function qp(t,e){return null!=t?t:e}function Pp(t,e){var n,r=[];return n={gradient:"linear",x1:t?t[0]:0,y1:t?t[1]:0,x2:e?e[0]:1,y2:e?e[1]:0,stops:r,stop:function(t,e){return r.push({offset:t,color:e}),n}}}const jp={basis:{curve:function(t){return new ec(t)}},"basis-closed":{curve:function(t){return new nc(t)}},"basis-open":{curve:function(t){return new 
rc(t)}},bundle:{curve:oc,tension:"beta",value:.85},cardinal:{curve:uc,tension:"tension",value:0},"cardinal-open":{curve:hc,tension:"tension",value:0},"cardinal-closed":{curve:cc,tension:"tension",value:0},"catmull-rom":{curve:gc,tension:"alpha",value:.5},"catmull-rom-closed":{curve:yc,tension:"alpha",value:.5},"catmull-rom-open":{curve:_c,tension:"alpha",value:.5},linear:{curve:Gl},"linear-closed":{curve:function(t){return new xc(t)}},monotone:{horizontal:function(t){return new Ec(t)},vertical:function(t){return new Mc(t)}},natural:{curve:function(t){return new Cc(t)}},step:{curve:function(t){return new Sc(t,.5)}},"step-after":{curve:function(t){return new Sc(t,1)}},"step-before":{curve:function(t){return new Sc(t,0)}}};function Ip(t,e,n){var r=lt(jp,t)&&jp[t],i=null;return r&&(i=r.curve||r[e||"vertical"],r.tension&&null!=n&&(i=i[r.tension](n))),i}const Wp={m:2,l:2,h:1,v:1,z:0,c:6,s:4,q:4,t:2,a:7},Hp=/[mlhvzcsqta]([^mlhvzcsqta]+|$)/gi,Yp=/^[+-]?(([0-9]*\.[0-9]+)|([0-9]+\.)|([0-9]+))([eE][+-]?[0-9]+)?/,Gp=/^((\s+,?\s*)|(,\s*))/,Vp=/^[01]/;function Xp(t){const e=[];return(t.match(Hp)||[]).forEach((t=>{let n=t[0];const r=n.toLowerCase(),i=Wp[r],o=function(t,e,n){const r=[];for(let i=0;e&&i1&&(g=Math.sqrt(g),n*=g,r*=g);const m=h/n,y=f/n,v=-f/r,_=h/r,x=m*s+y*u,b=v*s+_*u,w=m*t+y*e,k=v*t+_*e;let A=1/((w-x)*(w-x)+(k-b)*(k-b))-.25;A<0&&(A=0);let M=Math.sqrt(A);o==i&&(M=-M);const E=.5*(x+w)-M*(k-b),D=.5*(b+k)+M*(w-x),C=Math.atan2(b-D,x-E);let F=Math.atan2(k-D,w-E)-C;F<0&&1===o?F+=Qp:F>0&&0===o&&(F-=Qp);const S=Math.ceil(Math.abs(F/(Zp+.001))),$=[];for(let t=0;t+t}function vg(t,e,n){return Math.max(e,Math.min(t,n))}function _g(){var t=dg,e=pg,n=gg,r=mg,i=yg(0),o=i,a=i,s=i,u=null;function l(l,c,f){var h,d=null!=c?c:+t.call(this,l),p=null!=f?f:+e.call(this,l),g=+n.call(this,l),m=+r.call(this,l),y=Math.min(g,m)/2,v=vg(+i.call(this,l),0,y),_=vg(+o.call(this,l),0,y),x=vg(+a.call(this,l),0,y),b=vg(+s.call(this,l),0,y);if(u||(u=h=Rl()),v<=0&&_<=0&&x<=0&&b<=0)u.rect(d,p,g,m);else{var 
w=d+g,k=p+m;u.moveTo(d+v,p),u.lineTo(w-_,p),u.bezierCurveTo(w-hg*_,p,w,p+hg*_,w,p+_),u.lineTo(w,k-b),u.bezierCurveTo(w,k-hg*b,w-hg*b,k,w-b,k),u.lineTo(d+x,k),u.bezierCurveTo(d+hg*x,k,d,k-hg*x,d,k-x),u.lineTo(d,p+v),u.bezierCurveTo(d,p+hg*v,d+hg*v,p,d+v,p),u.closePath()}if(h)return u=null,h+""||null}return l.x=function(e){return arguments.length?(t=yg(e),l):t},l.y=function(t){return arguments.length?(e=yg(t),l):e},l.width=function(t){return arguments.length?(n=yg(t),l):n},l.height=function(t){return arguments.length?(r=yg(t),l):r},l.cornerRadius=function(t,e,n,r){return arguments.length?(i=yg(t),o=null!=e?yg(e):i,s=null!=n?yg(n):i,a=null!=r?yg(r):o,l):i},l.context=function(t){return arguments.length?(u=null==t?null:t,l):u},l}function xg(){var t,e,n,r,i,o,a,s,u=null;function l(t,e,n){const r=n/2;if(i){var l=a-e,c=t-o;if(l||c){var f=Math.hypot(l,c),h=(l/=f)*s,d=(c/=f)*s,p=Math.atan2(c,l);u.moveTo(o-h,a-d),u.lineTo(t-l*r,e-c*r),u.arc(t,e,r,p-Math.PI,p),u.lineTo(o+h,a+d),u.arc(o,a,s,p,p+Math.PI)}else u.arc(t,e,r,0,Qp);u.closePath()}else i=1;o=t,a=e,s=r}function c(o){var a,s,c,f=o.length,h=!1;for(null==u&&(u=c=Rl()),a=0;a<=f;++a)!(at.x||0,kg=t=>t.y||0,Ag=t=>!(!1===t.defined),Mg=function(){var t=Ll,e=ql,n=vl(0),r=null,i=Pl,o=jl,a=Il,s=null,u=Ul(l);function l(){var l,c,f=+t.apply(this,arguments),h=+e.apply(this,arguments),d=i.apply(this,arguments)-Cl,p=o.apply(this,arguments)-Cl,g=_l(p-d),m=p>d;if(s||(s=l=u()),hEl)if(g>Fl-El)s.moveTo(h*bl(d),h*Al(d)),s.arc(0,0,h,d,p,!m),f>El&&(s.moveTo(f*bl(p),f*Al(p)),s.arc(0,0,f,p,d,m));else{var y,v,_=d,x=p,b=d,w=p,k=g,A=g,M=a.apply(this,arguments)/2,E=M>El&&(r?+r.apply(this,arguments):Ml(f*f+h*h)),D=kl(_l(h-f)/2,+n.apply(this,arguments)),C=D,F=D;if(E>El){var S=Sl(E/f*Al(M)),$=Sl(E/h*Al(M));(k-=2*S)>El?(b+=S*=m?1:-1,w-=S):(k=0,b=w=(d+p)/2),(A-=2*$)>El?(_+=$*=m?1:-1,x-=$):(A=0,_=x=(d+p)/2)}var T=h*bl(_),B=h*Al(_),z=f*bl(w),N=f*Al(w);if(D>El){var 
O,R=h*bl(x),U=h*Al(x),L=f*bl(b),q=f*Al(b);if(g1?0:t<-1?Dl:Math.acos(t)}((P*I+j*W)/(Ml(P*P+j*j)*Ml(I*I+W*W)))/2),Y=Ml(O[0]*O[0]+O[1]*O[1]);C=kl(D,(f-Y)/(H-1)),F=kl(D,(h-Y)/(H+1))}else C=F=0}A>El?F>El?(y=Wl(L,q,T,B,h,F,m),v=Wl(R,U,z,N,h,F,m),s.moveTo(y.cx+y.x01,y.cy+y.y01),FEl&&k>El?C>El?(y=Wl(z,N,R,U,f,-C,m),v=Wl(T,B,L,q,f,-C,m),s.lineTo(y.cx+y.x01,y.cy+y.y01),Ct.startAngle||0)).endAngle((t=>t.endAngle||0)).padAngle((t=>t.padAngle||0)).innerRadius((t=>t.innerRadius||0)).outerRadius((t=>t.outerRadius||0)).cornerRadius((t=>t.cornerRadius||0)),Eg=Zl().x(wg).y1(kg).y0((t=>(t.y||0)+(t.height||0))).defined(Ag),Dg=Zl().y(kg).x1(wg).x0((t=>(t.x||0)+(t.width||0))).defined(Ag),Cg=Jl().x(wg).y(kg).defined(Ag),Fg=_g().x(wg).y(kg).width((t=>t.width||0)).height((t=>t.height||0)).cornerRadius((t=>bg(t.cornerRadiusTopLeft,t.cornerRadius)||0),(t=>bg(t.cornerRadiusTopRight,t.cornerRadius)||0),(t=>bg(t.cornerRadiusBottomRight,t.cornerRadius)||0),(t=>bg(t.cornerRadiusBottomLeft,t.cornerRadius)||0)),Sg=function(t,e){let n=null,r=Ul(i);function i(){let i;if(n||(n=i=r()),t.apply(this,arguments).draw(n,+e.apply(this,arguments)),i)return n=null,i+""||null}return t="function"==typeof t?t:vl(t||Ql),e="function"==typeof e?e:vl(void 0===e?64:+e),i.type=function(e){return arguments.length?(t="function"==typeof e?e:vl(e),i):t},i.size=function(t){return arguments.length?(e="function"==typeof t?t:vl(+t),i):e},i.context=function(t){return arguments.length?(n=null==t?null:t,i):n},i}().type((t=>cg(t.shape||"circle"))).size((t=>bg(t.size,64))),$g=xg().x(wg).y(kg).defined(Ag).size((t=>t.size||1));function Tg(t){return t.cornerRadius||t.cornerRadiusTopLeft||t.cornerRadiusTopRight||t.cornerRadiusBottomRight||t.cornerRadiusBottomLeft}function Bg(t,e,n,r){return Fg.context(t)(e,n,r)}var zg=1;function Ng(){zg=1}function Og(t,e,n){var r=e.clip,i=t._defs,o=e.clip_id||(e.clip_id="clip"+zg++),a=i.clipping[o]||(i.clipping[o]={id:o});return 
J(r)?a.path=r(null):Tg(n)?a.path=Bg(null,n,0,0):(a.width=n.width||0,a.height=n.height||0),"url(#"+o+")"}function Rg(t){this.clear(),t&&this.union(t)}function Ug(t){this.mark=t,this.bounds=this.bounds||new Rg}function Lg(t){Ug.call(this,t),this.items=this.items||[]}function qg(t){this._pending=0,this._loader=t||fa()}function Pg(t){t._pending+=1}function jg(t){t._pending-=1}function Ig(t,e,n){if(e.stroke&&0!==e.opacity&&0!==e.strokeOpacity){const r=null!=e.strokeWidth?+e.strokeWidth:1;t.expand(r+(n?function(t,e){return t.strokeJoin&&"miter"!==t.strokeJoin?0:e}(e,r):0))}return t}Rg.prototype={clone(){return new Rg(this)},clear(){return this.x1=+Number.MAX_VALUE,this.y1=+Number.MAX_VALUE,this.x2=-Number.MAX_VALUE,this.y2=-Number.MAX_VALUE,this},empty(){return this.x1===+Number.MAX_VALUE&&this.y1===+Number.MAX_VALUE&&this.x2===-Number.MAX_VALUE&&this.y2===-Number.MAX_VALUE},equals(t){return this.x1===t.x1&&this.y1===t.y1&&this.x2===t.x2&&this.y2===t.y2},set(t,e,n,r){return nthis.x2&&(this.x2=t),e>this.y2&&(this.y2=e),this},expand(t){return this.x1-=t,this.y1-=t,this.x2+=t,this.y2+=t,this},round(){return this.x1=Math.floor(this.x1),this.y1=Math.floor(this.y1),this.x2=Math.ceil(this.x2),this.y2=Math.ceil(this.y2),this},scale(t){return this.x1*=t,this.y1*=t,this.x2*=t,this.y2*=t,this},translate(t,e){return this.x1+=t,this.x2+=t,this.y1+=e,this.y2+=e,this},rotate(t,e,n){const r=this.rotatedPoints(t,e,n);return this.clear().add(r[0],r[1]).add(r[2],r[3]).add(r[4],r[5]).add(r[6],r[7])},rotatedPoints(t,e,n){var{x1:r,y1:i,x2:o,y2:a}=this,s=Math.cos(t),u=Math.sin(t),l=e-e*s+n*u,c=n-e*u-n*s;return[s*r-u*i+l,u*r+s*i+c,s*r-u*a+l,u*r+s*a+c,s*o-u*i+l,u*o+s*i+c,s*o-u*a+l,u*o+s*a+c]},union(t){return t.x1this.x2&&(this.x2=t.x2),t.y2>this.y2&&(this.y2=t.y2),this},intersect(t){return t.x1>this.x1&&(this.x1=t.x1),t.y1>this.y1&&(this.y1=t.y1),t.x2=t.x2&&this.y1<=t.y1&&this.y2>=t.y2},alignsWith(t){return t&&(this.x1==t.x1||this.x2==t.x2||this.y1==t.y1||this.y2==t.y2)},intersects(t){return 
t&&!(this.x2t.x2||this.y2t.y2)},contains(t,e){return!(tthis.x2||ethis.y2)},width(){return this.x2-this.x1},height(){return this.y2-this.y1}},dt(Lg,Ug),qg.prototype={pending(){return this._pending},sanitizeURL(t){const e=this;return Pg(e),e._loader.sanitize(t,{context:"href"}).then((t=>(jg(e),t))).catch((()=>(jg(e),null)))},loadImage(t){const e=this,n=Tc();return Pg(e),e._loader.sanitize(t,{context:"image"}).then((t=>{const r=t.href;if(!r||!n)throw{url:r};const i=new n,o=lt(t,"crossOrigin")?t.crossOrigin:"anonymous";return null!=o&&(i.crossOrigin=o),i.onload=()=>jg(e),i.onerror=()=>jg(e),i.src=r,i})).catch((t=>(jg(e),{complete:!1,width:0,height:0,src:t&&t.url||""})))},ready(){const t=this;return new Promise((e=>{!function n(r){t.pending()?setTimeout((()=>{n(!0)}),10):e(r)}(!1)}))}};const Wg=Qp-1e-8;let Hg,Yg,Gg,Vg,Xg,Jg,Zg,Qg;const Kg=(t,e)=>Hg.add(t,e),tm=(t,e)=>Kg(Yg=t,Gg=e),em=t=>Kg(t,Hg.y1),nm=t=>Kg(Hg.x1,t),rm=(t,e)=>Xg*t+Zg*e,im=(t,e)=>Jg*t+Qg*e,om=(t,e)=>Kg(rm(t,e),im(t,e)),am=(t,e)=>tm(rm(t,e),im(t,e));function sm(t,e){return Hg=t,e?(Vg=e*Jp,Xg=Qg=Math.cos(Vg),Jg=Math.sin(Vg),Zg=-Jg):(Xg=Qg=1,Vg=Jg=Zg=0),um}const um={beginPath(){},closePath(){},moveTo:am,lineTo:am,rect(t,e,n,r){Vg?(om(t+n,e),om(t+n,e+r),om(t,e+r),am(t,e)):(Kg(t+n,e+r),tm(t,e))},quadraticCurveTo(t,e,n,r){const i=rm(t,e),o=im(t,e),a=rm(n,r),s=im(n,r);lm(Yg,i,a,em),lm(Gg,o,s,nm),tm(a,s)},bezierCurveTo(t,e,n,r,i,o){const a=rm(t,e),s=im(t,e),u=rm(n,r),l=im(n,r),c=rm(i,o),f=im(i,o);cm(Yg,a,u,c,em),cm(Gg,s,l,f,nm),tm(c,f)},arc(t,e,n,r,i,o){if(r+=Vg,i+=Vg,Yg=n*Math.cos(i)+t,Gg=n*Math.sin(i)+e,Math.abs(i-r)>Wg)Kg(t-n,e-n),Kg(t+n,e+n);else{const a=r=>Kg(n*Math.cos(r)+t,n*Math.sin(r)+e);let s,u;if(a(r),a(i),i!==r)if((r%=Qp)<0&&(r+=Qp),(i%=Qp)<0&&(i+=Qp),ii;++u,s-=Zp)a(s);else for(s=r-r%Zp+Zp,u=0;u<4&&s1e-14?(u=a*a+s*o,u>=0&&(u=Math.sqrt(u),l=(-a+u)/o,c=(-a-u)/o)):l=.5*s/a,0m)return!1;d>g&&(g=d)}else if(f>0){if(d0&&(t.globalAlpha=n,t.fillStyle=wm(t,e,e.fill),!0)}var Am=[];function Mm(t,e,n){var 
r=null!=(r=e.strokeWidth)?r:1;return!(r<=0)&&((n*=null==e.strokeOpacity?1:e.strokeOpacity)>0&&(t.globalAlpha=n,t.strokeStyle=wm(t,e,e.stroke),t.lineWidth=r,t.lineCap=e.strokeCap||"butt",t.lineJoin=e.strokeJoin||"miter",t.miterLimit=e.strokeMiterLimit||10,t.setLineDash&&(t.setLineDash(e.strokeDash||Am),t.lineDashOffset=e.strokeDashOffset||0),!0))}function Em(t,e){return t.zindex-e.zindex||t.index-e.index}function Dm(t){if(!t.zdirty)return t.zitems;var e,n,r,i=t.items,o=[];for(n=0,r=i.length;n=0;)if(n=e(i[r]))return n;if(i===o)for(r=(i=t.items).length;--r>=0;)if(!i[r].zindex&&(n=e(i[r])))return n;return null}function Sm(t){return function(e,n,r){Cm(n,(n=>{r&&!r.intersects(n.bounds)||Tm(t,e,n,n)}))}}function $m(t){return function(e,n,r){!n.items.length||r&&!r.intersects(n.bounds)||Tm(t,e,n.items[0],n.items)}}function Tm(t,e,n,r){var i=null==n.opacity?1:n.opacity;0!==i&&(t(e,r)||(_m(e,n),n.fill&&km(e,n,i)&&e.fill(),n.stroke&&Mm(e,n,i)&&e.stroke()))}function Bm(t){return t=t||p,function(e,n,r,i,o,a){return r*=e.pixelRatio,i*=e.pixelRatio,Fm(n,(n=>{const s=n.bounds;if((!s||s.contains(o,a))&&s)return t(e,n,r,i,o,a)?n:void 0}))}}function zm(t,e){return function(n,r,i,o){var a,s,u=Array.isArray(r)?r[0]:r,l=null==e?u.fill:e,c=u.stroke&&n.isPointInStroke;return c&&(a=u.strokeWidth,s=u.strokeCap,n.lineWidth=null!=a?a:1,n.lineCap=null!=s?s:"butt"),!t(n,r)&&(l&&n.isPointInPath(i,o)||c&&n.isPointInStroke(i,o))}}function Nm(t){return Bm(zm(t))}function Om(t,e){return"translate("+t+","+e+")"}function Rm(t){return"rotate("+t+")"}function Um(t){return Om(t.x||0,t.y||0)}function Lm(t,e,n){function r(t,n){var r=n.x||0,i=n.y||0,o=n.angle||0;t.translate(r,i),o&&t.rotate(o*=Jp),t.beginPath(),e(t,n),o&&t.rotate(-o),t.translate(-r,-i)}return{type:t,tag:"path",nested:!1,attr:function(t,n){t("transform",function(t){return Om(t.x||0,t.y||0)+(t.angle?" 
"+Rm(t.angle):"")}(n)),t("d",e(null,n))},bound:function(t,n){return e(sm(t,n.angle),n),Ig(t,n).translate(n.x||0,n.y||0)},draw:Sm(r),pick:Nm(r),isect:n||pm(r)}}var qm=Lm("arc",(function(t,e){return Mg.context(t)(e)}));function Pm(t,e,n){function r(t,n){t.beginPath(),e(t,n)}const i=zm(r);return{type:t,tag:"path",nested:!0,attr:function(t,n){var r=n.mark.items;r.length&&t("d",e(null,r))},bound:function(t,n){var r=n.items;return 0===r.length?t:(e(sm(t),r),Ig(t,r[0]))},draw:$m(r),pick:function(t,e,n,r,o,a){var s=e.items,u=e.bounds;return!s||!s.length||u&&!u.contains(o,a)?null:(n*=t.pixelRatio,r*=t.pixelRatio,i(t,s,n,r)?s[0]:null)},isect:gm,tip:n}}var jm=Pm("area",(function(t,e){const n=e[0],r=n.interpolate||"linear";return("horizontal"===n.orient?Dg:Eg).curve(Ip(r,n.orient,n.tension)).context(t)(e)}),(function(t,e){for(var n,r,i="horizontal"===t[0].orient?e[1]:e[0],o="horizontal"===t[0].orient?"y":"x",a=t.length,s=1/0;--a>=0;)!1!==t[a].defined&&(r=Math.abs(t[a][o]-i)).5&&e<1.5?.5-Math.abs(e-1):0}function Hm(t,e){const n=Wm(e);t("d",Bg(null,e,n,n))}function Ym(t,e,n,r){const i=Wm(e);t.beginPath(),Bg(t,e,(n||0)+i,(r||0)+i)}const Gm=zm(Ym),Vm=zm(Ym,!1),Xm=zm(Ym,!0);var Jm={type:"group",tag:"g",nested:!1,attr:function(t,e){t("transform",Um(e))},bound:function(t,e){if(!e.clip&&e.items){const n=e.items,r=n.length;for(let e=0;e{const i=e.x||0,o=e.y||0,a=e.strokeForeground,s=null==e.opacity?1:e.opacity;(e.stroke||e.fill)&&s&&(Ym(t,e,i,o),_m(t,e),e.fill&&km(t,e,s)&&t.fill(),e.stroke&&!a&&Mm(t,e,s)&&t.stroke()),t.save(),t.translate(i,o),e.clip&&Im(t,e),n&&n.translate(-i,-o),Cm(e,(e=>{("group"===e.marktype||null==r||r.includes(e.marktype))&&this.draw(t,e,n,r)})),n&&n.translate(i,o),t.restore(),a&&e.stroke&&s&&(Ym(t,e,i,o),_m(t,e),Mm(t,e,s)&&t.stroke())}))},pick:function(t,e,n,r,i,o){if(e.bounds&&!e.bounds.contains(i,o)||!e.items)return null;const a=n*t.pixelRatio,s=r*t.pixelRatio;return Fm(e,(u=>{let l,c,f;const h=u.bounds;if(h&&!h.contains(i,o))return;c=u.x||0,f=u.y||0;const 
d=c+(u.width||0),p=f+(u.height||0),g=u.clip;if(g&&(id||op))return;if(t.save(),t.translate(c,f),c=i-c,f=o-f,g&&Tg(u)&&!Xm(t,u,a,s))return t.restore(),null;const m=u.strokeForeground,y=!1!==e.interactive;return y&&m&&u.stroke&&Vm(t,u,a,s)?(t.restore(),u):(l=Fm(u,(t=>function(t,e,n){return(!1!==t.interactive||"group"===t.marktype)&&t.bounds&&t.bounds.contains(e,n)}(t,c,f)?this.pick(t,n,r,c,f):null)),!l&&y&&(u.fill||!m&&u.stroke)&&Gm(t,u,a,s)&&(l=u),t.restore(),l||null)}))},isect:mm,content:function(t,e,n){t("clip-path",e.clip?Og(n,e,e):null)},background:function(t,e){t("class","background"),t("aria-hidden",!0),Hm(t,e)},foreground:function(t,e){t("class","foreground"),t("aria-hidden",!0),e.strokeForeground?Hm(t,e):t("d","")}},Zm={xmlns:"http://www.w3.org/2000/svg","xmlns:xlink":"http://www.w3.org/1999/xlink",version:"1.1"};function Qm(t,e){var n=t.image;return(!n||t.url&&t.url!==n.url)&&(n={complete:!1,width:0,height:0},e.loadImage(t.url).then((e=>{t.image=e,t.image.url=t.url}))),n}function Km(t,e){return null!=t.width?t.width:e&&e.width?!1!==t.aspect&&t.height?t.height*e.width/e.height:e.width:0}function ty(t,e){return null!=t.height?t.height:e&&e.height?!1!==t.aspect&&t.width?t.width*e.height/e.width:e.height:0}function ey(t,e){return"center"===t?e/2:"right"===t?e:0}function ny(t,e){return"middle"===t?e/2:"bottom"===t?e:0}var ry={type:"image",tag:"image",nested:!1,attr:function(t,e,n){const r=Qm(e,n),i=Km(e,r),o=ty(e,r),a=(e.x||0)-ey(e.align,i),s=(e.y||0)-ny(e.baseline,o);t("href",!r.src&&r.toDataURL?r.toDataURL():r.src||"",Zm["xmlns:xlink"],"xlink:href"),t("transform",Om(a,s)),t("width",i),t("height",o),t("preserveAspectRatio",!1===e.aspect?"none":"xMidYMid")},bound:function(t,e){const n=e.image,r=Km(e,n),i=ty(e,n),o=(e.x||0)-ey(e.align,r),a=(e.y||0)-ny(e.baseline,i);return t.set(o,a,o+r,a+i)},draw:function(t,e,n){Cm(e,(e=>{if(n&&!n.intersects(e.bounds))return;const r=Qm(e,this);let i=Km(e,r),o=ty(e,r);if(0===i||0===o)return;let 
a,s,u,l,c=(e.x||0)-ey(e.align,i),f=(e.y||0)-ny(e.baseline,o);!1!==e.aspect&&(s=r.width/r.height,u=e.width/e.height,s==s&&u==u&&s!==u&&(u=0;)if(!1!==t[o].defined&&(n=t[o].x-e[0])*n+(r=t[o].y-e[1])*r{if(!n||n.intersects(e.bounds)){var r=null==e.opacity?1:e.opacity;r&&ly(t,e,r)&&(_m(t,e),t.stroke())}}))},pick:Bm((function(t,e,n,r){return!!t.isPointInStroke&&(ly(t,e,1)&&t.isPointInStroke(n,r))})),isect:ym},fy=Lm("shape",(function(t,e){return(e.mark.shape||e.shape).context(t)(e)})),hy=Lm("symbol",(function(t,e){return Sg.context(t)(e)}),gm);const dy=kt();var py={height:xy,measureWidth:vy,estimateWidth:my,width:my,canvas:gy};function gy(t){py.width=t&&hm?vy:my}function my(t,e){return yy(Ay(t,e),xy(t))}function yy(t,e){return~~(.8*t.length*e)}function vy(t,e){return xy(t)<=0||!(e=Ay(t,e))?0:_y(e,Ey(t))}function _y(t,e){const n=`(${e}) ${t}`;let r=dy.get(n);return void 0===r&&(hm.font=e,r=hm.measureText(t).width,dy.set(n,r)),r}function xy(t){return null!=t.fontSize?+t.fontSize||0:11}function by(t){return null!=t.lineHeight?t.lineHeight:xy(t)+2}function wy(t){return e=t.lineBreak&&t.text&&!k(t.text)?t.text.split(t.lineBreak):t.text,k(e)?e.length>1?e:e[0]:e;var e}function ky(t){const e=wy(t);return(k(e)?e.length-1:0)*by(t)}function Ay(t,e){const n=null==e?"":(e+"").trim();return t.limit>0&&n.length?function(t,e){var n=+t.limit,r=function(t){if(py.width===vy){const e=Ey(t);return t=>_y(t,e)}if(py.width===my){const e=xy(t);return t=>yy(t,e)}return e=>py.width(t,e)}(t);if(r(e)>>1,r(e.slice(i))>n?s=i+1:u=i;return o+e.slice(s)}for(;s>>1),r(e.slice(0,i))Math.max(t,py.width(e,n))),0)):r=py.width(e,f),"center"===o?l-=r/2:"right"===o&&(l-=r),t.set(l+=s,c+=u,l+r,c+i),e.angle&&!n)t.rotate(e.angle*Jp,s,u);else if(2===n)return t.rotatedPoints(e.angle*Jp,s,u);return t}var Ty={type:"text",tag:"text",nested:!1,attr:function(t,e){var n,r=e.dx||0,i=(e.dy||0)+Dy(e),o=Sy(e),a=o.x1,s=o.y1,u=e.angle||0;t("text-anchor",Cy[e.align]||"start"),u?(n=Om(a,s)+" "+Rm(u),(r||i)&&(n+=" 
"+Om(r,i))):n=Om(a+r,s+i),t("transform",n)},bound:$y,draw:function(t,e,n){Cm(e,(e=>{var r,i,o,a,s,u,l,c=null==e.opacity?1:e.opacity;if(!(n&&!n.intersects(e.bounds)||0===c||e.fontSize<=0||null==e.text||0===e.text.length)){if(t.font=Ey(e),t.textAlign=e.align||"left",i=(r=Sy(e)).x1,o=r.y1,e.angle&&(t.save(),t.translate(i,o),t.rotate(e.angle*Jp),i=o=0),i+=e.dx||0,o+=(e.dy||0)+Dy(e),u=wy(e),_m(t,e),k(u))for(s=by(e),a=0;a=0;)if(!1!==t[i].defined&&(n=t[i].x-e[0])*n+(r=t[i].y-e[1])*r<(n=t[i].size||1)*n)return t[i];return null})),zy={arc:qm,area:jm,group:Jm,image:ry,line:iy,path:ay,rect:uy,rule:cy,shape:fy,symbol:hy,text:Ty,trail:By};function Ny(t,e,n){var r=zy[t.mark.marktype],i=e||r.bound;return r.nested&&(t=t.mark),i(t.bounds||(t.bounds=new Rg),t,n)}var Oy={mark:null};function Ry(t,e,n){var r,i,o,a,s=zy[t.marktype],u=s.bound,l=t.items,c=l&&l.length;if(s.nested)return c?o=l[0]:(Oy.mark=t,o=Oy),a=Ny(o,u,n),e=e&&e.union(a)||a;if(e=e||t.bounds&&t.bounds.clear()||new Rg,c)for(r=0,i=l.length;re;)t.removeChild(n[--r]);return t}function Vy(t){return"mark-"+t.marktype+(t.role?" role-"+t.role:"")+(t.name?" 
"+t.name:"")}function Xy(t,e){const n=e.getBoundingClientRect();return[t.clientX-n.left-(e.clientLeft||0),t.clientY-n.top-(e.clientTop||0)]}function Jy(t,e){this._active=null,this._handlers={},this._loader=t||fa(),this._tooltip=e||Zy}function Zy(t,e,n,r){t.element().setAttribute("title",r||"")}function Qy(t){this._el=null,this._bgcolor=null,this._loader=new qg(t)}jy.prototype={toJSON(t){return Ly(this.root,t||0)},mark(t,e,n){const r=Iy(t,e=e||this.root.items[0]);return e.items[n]=r,r.zindex&&(r.group.zdirty=!0),r}},Jy.prototype={initialize(t,e,n){return this._el=t,this._obj=n||null,this.origin(e)},element(){return this._el},canvas(){return this._el&&this._el.firstChild},origin(t){return arguments.length?(this._origin=t||[0,0],this):this._origin.slice()},scene(t){return arguments.length?(this._scene=t,this):this._scene},on(){},off(){},_handlerIndex(t,e,n){for(let r=t?t.length:0;--r>=0;)if(t[r].type===e&&(!n||t[r].handler===n))return r;return-1},handlers(t){const e=this._handlers,n=[];if(t)n.push(...e[this.eventName(t)]);else for(const t in e)n.push(...e[t]);return n},eventName(t){const e=t.indexOf(".");return e<0?t:t.slice(0,e)},handleHref(t,e,n){this._loader.sanitize(n,{context:"href"}).then((e=>{const n=new MouseEvent(t.type,t),r=Wy(null,"a");for(const t in e)r.setAttribute(t,e[t]);r.dispatchEvent(n)})).catch((()=>{}))},handleTooltip(t,e,n){if(e&&null!=e.tooltip){e=function(t,e,n,r){var i,o,a=t&&t.mark;if(a&&(i=zy[a.marktype]).tip){for((o=Xy(e,n))[0]-=r[0],o[1]-=r[1];t=t.mark.group;)o[0]-=t.x||0,o[1]-=t.y||0;t=i.tip(a.items,o)}return t}(e,t,this.canvas(),this._origin);const r=n&&e&&e.tooltip||null;this._tooltip.call(this._obj,this,t,e,r)}},getItemBoundingClientRect(t){const e=this.canvas();if(!e)return;const n=e.getBoundingClientRect(),r=this._origin,i=t.bounds,o=i.width(),a=i.height();let 
s=i.x1+r[0]+n.left,u=i.y1+r[1]+n.top;for(;t.mark&&(t=t.mark.group);)s+=t.x||0,u+=t.y||0;return{x:s,y:u,width:o,height:a,left:s,top:u,right:s+o,bottom:u+a}}},Qy.prototype={initialize(t,e,n,r,i){return this._el=t,this.resize(e,n,r,i)},element(){return this._el},canvas(){return this._el&&this._el.firstChild},background(t){return 0===arguments.length?this._bgcolor:(this._bgcolor=t,this)},resize(t,e,n,r){return this._width=t,this._height=e,this._origin=n||[0,0],this._scale=r||1,this},dirty(){},render(t,e){const n=this;return n._call=function(){n._render(t,e)},n._call(),n._call=null,n},_render(){},renderAsync(t,e){const n=this.render(t,e);return this._ready?this._ready.then((()=>n)):Promise.resolve(n)},_load(t,e){var n=this,r=n._loader[t](e);if(!n._ready){const t=n._call;n._ready=n._loader.ready().then((e=>{e&&t(),n._ready=null}))}return r},sanitizeURL(t){return this._load("sanitizeURL",t)},loadImage(t){return this._load("loadImage",t)}};const Ky="dragenter",tv="dragleave",ev="dragover",nv="pointerdown",rv="pointermove",iv="pointerout",ov="pointerover",av="mousedown",sv="mousemove",uv="mouseout",lv="mouseover",cv="click",fv="mousewheel",hv="touchstart",dv="touchmove",pv="touchend",gv=rv,mv=iv,yv=cv;function vv(t,e){Jy.call(this,t,e),this._down=null,this._touch=null,this._first=!0,this._events={}}function _v(t,e){(t=>t===hv||t===dv||t===pv?[hv,dv,pv]:[t])(e).forEach((e=>function(t,e){const n=t.canvas();n&&!t._events[e]&&(t._events[e]=1,n.addEventListener(e,t[e]?n=>t[e](n):n=>t.fire(e,n)))}(t,e)))}function xv(t,e,n){e.forEach((e=>t.fire(e,n)))}function bv(t,e,n){return function(r){const i=this._active,o=this.pickEvent(r);o===i||(i&&i.exit||xv(this,n,r),this._active=o,xv(this,e,r)),xv(this,t,r)}}function wv(t){return function(e){xv(this,t,e),this._active=null}}function kv(t,e,n,r,i,o){const a="undefined"!=typeof HTMLElement&&t instanceof HTMLElement&&null!=t.parentNode,s=t.getContext("2d"),u=a?"undefined"!=typeof 
window&&window.devicePixelRatio||1:i;t.width=e*u,t.height=n*u;for(const t in o)s[t]=o[t];return a&&1!==u&&(t.style.width=e+"px",t.style.height=n+"px"),s.pixelRatio=u,s.setTransform(u,0,0,u,u*r[0],u*r[1]),t}function Av(t){Qy.call(this,t),this._options={},this._redraw=!1,this._dirty=new Rg,this._tempb=new Rg}dt(vv,Jy,{initialize(t,e,n){return this._canvas=t&&Hy(t,"canvas"),[cv,av,nv,rv,iv,tv].forEach((t=>_v(this,t))),Jy.prototype.initialize.call(this,t,e,n)},canvas(){return this._canvas},context(){return this._canvas.getContext("2d")},events:["keydown","keypress","keyup",Ky,tv,ev,nv,"pointerup",rv,iv,ov,av,"mouseup",sv,uv,lv,cv,"dblclick","wheel",fv,hv,dv,pv],DOMMouseScroll(t){this.fire(fv,t)},pointermove:bv([rv,sv],[ov,lv],[iv,uv]),dragover:bv([ev],[Ky],[tv]),pointerout:wv([iv,uv]),dragleave:wv([tv]),pointerdown(t){this._down=this._active,this.fire(nv,t)},mousedown(t){this._down=this._active,this.fire(av,t)},click(t){this._down===this._active&&(this.fire(cv,t),this._down=null)},touchstart(t){this._touch=this.pickEvent(t.changedTouches[0]),this._first&&(this._active=this._touch,this._first=!1),this.fire(hv,t,!0)},touchmove(t){this.fire(dv,t,!0)},touchend(t){this.fire(pv,t,!0),this._touch=null},fire(t,e,n){const r=n?this._touch:this._active,i=this._handlers[t];if(e.vegaType=t,t===yv&&r&&r.href?this.handleHref(e,r,r.href):t!==gv&&t!==mv||this.handleTooltip(e,r,t!==mv),i)for(let t=0,n=i.length;t=0&&r.splice(i,1),this},pickEvent(t){const e=Xy(t,this._canvas),n=this._origin;return this.pick(this._scene,e[0],e[1],e[0]-n[0],e[1]-n[1])},pick(t,e,n,r,i){const o=this.context();return zy[t.marktype].pick.call(this,o,t,e,n,r,i)}});const Mv=Qy.prototype;function Ev(t,e){Jy.call(this,t,e);const n=this;n._hrefHandler=Dv(n,((t,e)=>{e&&e.href&&n.handleHref(t,e,e.href)})),n._tooltipHandler=Dv(n,((t,e)=>{n.handleTooltip(t,e,t.type!==mv)}))}dt(Av,Qy,{initialize(t,e,n,r,i,o){return 
this._options=o||{},this._canvas=this._options.externalContext?null:$c(1,1,this._options.type),t&&this._canvas&&(Gy(t,0).appendChild(this._canvas),this._canvas.setAttribute("class","marks")),Mv.initialize.call(this,t,e,n,r,i)},resize(t,e,n,r){if(Mv.resize.call(this,t,e,n,r),this._canvas)kv(this._canvas,this._width,this._height,this._origin,this._scale,this._options.context);else{const t=this._options.externalContext;t||s("CanvasRenderer is missing a valid canvas or context"),t.scale(this._scale,this._scale),t.translate(this._origin[0],this._origin[1])}return this._redraw=!0,this},canvas(){return this._canvas},context(){return this._options.externalContext||(this._canvas?this._canvas.getContext("2d"):null)},dirty(t){const e=this._tempb.clear().union(t.bounds);let n=t.mark.group;for(;n;)e.translate(n.x||0,n.y||0),n=n.mark.group;this._dirty.union(e)},_render(t,e){const n=this.context(),r=this._origin,i=this._width,o=this._height,a=this._dirty,s=((t,e,n)=>(new Rg).set(0,0,e,n).translate(-t[0],-t[1]))(r,i,o);n.save();const u=this._redraw||a.empty()?(this._redraw=!1,s.expand(1)):function(t,e,n){return e.expand(1).round(),t.pixelRatio%1&&e.scale(t.pixelRatio).round().scale(1/t.pixelRatio),e.translate(-n[0]%1,-n[1]%1),t.beginPath(),t.rect(e.x1,e.y1,e.width(),e.height()),t.clip(),e}(n,s.intersect(a),r);return this.clear(-r[0],-r[1],i,o),this.draw(n,t,u,e),n.restore(),a.clear(),this},draw(t,e,n,r){if("group"!==e.marktype&&null!=r&&!r.includes(e.marktype))return;const i=zy[e.marktype];e.clip&&function(t,e){var n=e.clip;t.save(),J(n)?(t.beginPath(),n(t),t.clip()):Im(t,e.group)}(t,e),i.draw.call(this,t,e,n,r),e.clip&&t.restore()},clear(t,e,n,r){const i=this._options,o=this.context();"pdf"===i.type||i.externalContext||o.clearRect(t,e,n,r),null!=this._bgcolor&&(o.fillStyle=this._bgcolor,o.fillRect(t,e,n,r))}});const Dv=(t,e)=>n=>{let r=n.target.__data__;r=Array.isArray(r)?r[0]:r,n.vegaType=n.type,e.call(t._obj,n,r)};dt(Ev,Jy,{initialize(t,e,n){let r=this._svg;return 
r&&(r.removeEventListener(yv,this._hrefHandler),r.removeEventListener(gv,this._tooltipHandler),r.removeEventListener(mv,this._tooltipHandler)),this._svg=r=t&&Hy(t,"svg"),r&&(r.addEventListener(yv,this._hrefHandler),r.addEventListener(gv,this._tooltipHandler),r.addEventListener(mv,this._tooltipHandler)),Jy.prototype.initialize.call(this,t,e,n)},canvas(){return this._svg},on(t,e){const n=this.eventName(t),r=this._handlers;if(this._handlerIndex(r[n],t,e)<0){const i={type:t,handler:e,listener:Dv(this,e)};(r[n]||(r[n]=[])).push(i),this._svg&&this._svg.addEventListener(n,i.listener)}return this},off(t,e){const n=this.eventName(t),r=this._handlers[n],i=this._handlerIndex(r,t,e);return i>=0&&(this._svg&&this._svg.removeEventListener(n,r[i].listener),r.splice(i,1)),this}});const Cv="aria-hidden",Fv="aria-label",Sv="role",$v="aria-roledescription",Tv="graphics-object",Bv="graphics-symbol",zv=(t,e,n)=>({[Sv]:t,[$v]:e,[Fv]:n||void 0}),Nv=Bt(["axis-domain","axis-grid","axis-label","axis-tick","axis-title","legend-band","legend-entry","legend-gradient","legend-label","legend-title","legend-symbol","title"]),Ov={axis:{desc:"axis",caption:function(t){const e=t.datum,n=t.orient,r=e.title?Pv(t):null,i=t.context,o=i.scales[e.scale].value,a=i.dataflow.locale(),s=o.type;return("left"===n||"right"===n?"Y":"X")+"-axis"+(r?` titled '${r}'`:"")+` for a ${Kd(s)?"discrete":s} scale`+` with ${Np(a,o,t)}`}},legend:{desc:"legend",caption:function(t){const e=t.datum,n=e.title?Pv(t):null,r=`${e.type||""} legend`.trim(),i=e.scales,o=Object.keys(i),a=t.context,s=a.scales[i[o[0]]].value,u=a.dataflow.locale();return l=r,(l.length?l[0].toUpperCase()+l.slice(1):l)+(n?` titled '${n}'`:"")+` for ${function(t){return t=t.map((t=>t+("fill"===t||"stroke"===t?" 
color":""))),t.length<2?t[0]:t.slice(0,-1).join(", ")+" and "+F(t)}(o)}`+` with ${Np(u,s,t)}`;var l}},"title-text":{desc:"title",caption:t=>`Title text '${qv(t)}'`},"title-subtitle":{desc:"subtitle",caption:t=>`Subtitle text '${qv(t)}'`}},Rv={ariaRole:Sv,ariaRoleDescription:$v,description:Fv};function Uv(t,e){const n=!1===e.aria;if(t(Cv,n||void 0),n||null==e.description)for(const e in Rv)t(Rv[e],void 0);else{const n=e.mark.marktype;t(Fv,e.description),t(Sv,e.ariaRole||("group"===n?Tv:Bv)),t($v,e.ariaRoleDescription||`${n} mark`)}}function Lv(t){return!1===t.aria?{[Cv]:!0}:Nv[t.role]?null:Ov[t.role]?function(t,e){try{const n=t.items[0],r=e.caption||(()=>"");return zv(e.role||Bv,e.desc,n.description||r(n))}catch(t){return null}}(t,Ov[t.role]):function(t){const e=t.marktype,n="group"===e||"text"===e||t.items.some((t=>null!=t.description&&!1!==t.aria));return zv(n?Tv:Bv,`${e} mark container`,t.description)}(t)}function qv(t){return V(t.text).join(" ")}function Pv(t){try{return V(F(t.items).items[0].text).join(" ")}catch(t){return null}}const jv=t=>(t+"").replace(/&/g,"&").replace(//g,">");function Iv(){let t="",e="",n="";const r=[],i=()=>e=n="",o=(t,n)=>{var r;return null!=n&&(e+=` ${t}="${r=n,jv(r).replace(/"/g,""").replace(/\t/g," ").replace(/\n/g," ").replace(/\r/g," ")}"`),a},a={open(s){(o=>{e&&(t+=`${e}>${n}`,i()),r.push(o)})(s),e="<"+s;for(var u=arguments.length,l=new Array(u>1?u-1:0),c=1;c${n}`:"/>"):``,i(),a},attr:o,text:t=>(n+=jv(t),a),toString:()=>t};return a}const Wv=t=>Hv(Iv(),t)+"";function Hv(t,e){if(t.open(e.tagName),e.hasAttributes()){const n=e.attributes,r=n.length;for(let e=0;e1&&t.previousSibling!=e}(a,n))&&e.insertBefore(a,n?n.nextSibling:e.firstChild),a}dt(Zv,Qy,{initialize(t,e,n,r,i){return 
this._defs={},this._clearDefs(),t&&(this._svg=Yy(t,0,"svg",Jv),this._svg.setAttributeNS(Xv,"xmlns",Jv),this._svg.setAttributeNS(Xv,"xmlns:xlink",Zm["xmlns:xlink"]),this._svg.setAttribute("version",Zm.version),this._svg.setAttribute("class","marks"),Gy(t,1),this._root=Yy(this._svg,0,"g",Jv),u_(this._root,Vv),Gy(this._svg,1)),this.background(this._bgcolor),Qv.initialize.call(this,t,e,n,r,i)},background(t){return arguments.length&&this._svg&&this._svg.style.setProperty("background-color",t),Qv.background.apply(this,arguments)},resize(t,e,n,r){return Qv.resize.call(this,t,e,n,r),this._svg&&(u_(this._svg,{width:this._width*this._scale,height:this._height*this._scale,viewBox:`0 0 ${this._width} ${this._height}`}),this._root.setAttribute("transform",`translate(${this._origin})`)),this._dirty=[],this},canvas(){return this._svg},svg(){const t=this._svg,e=this._bgcolor;if(!t)return null;let n;e&&(t.removeAttribute("style"),n=Yy(t,0,"rect",Jv),u_(n,{width:this._width,height:this._height,fill:e}));const r=Wv(t);return e&&(t.removeChild(n),this._svg.style.setProperty("background-color",e)),r},_render(t,e){return this._dirtyCheck()&&(this._dirtyAll&&this._clearDefs(),this.mark(this._root,t,void 0,e),Gy(this._root,1)),this.defs(),this._dirty=[],++this._dirtyID,this},dirty(t){t.dirty!==this._dirtyID&&(t.dirty=this._dirtyID,this._dirty.push(t))},isDirty(t){return this._dirtyAll||!t._svg||!t._svg.ownerSVGElement||t.dirty===this._dirtyID},_dirtyCheck(){this._dirtyAll=!0;const t=this._dirty;if(!t.length||!this._dirtyID)return!0;const e=++this._dirtyID;let n,r,i,o,a,s,u;for(a=0,s=t.length;a{t.dirty=e}))),r.zdirty||(n.exit?(o.nested&&r.items.length?(u=r.items[0],u._svg&&this._update(o,u._svg,u)):n._svg&&(u=n._svg.parentNode,u&&u.removeChild(n._svg)),n._svg=null):(n=o.nested?r.items[0]:n,n._update!==e&&(n._svg&&n._svg.ownerSVGElement?this._update(o,n._svg,n):(this._dirtyAll=!1,Kv(n,e)),n._update=e)));return!this._dirtyAll},mark(t,e,n,r){if(!this.isDirty(e))return e._svg;const 
i=this._svg,o=e.marktype,a=zy[o],s=!1===e.interactive?"none":null,u="g"===a.tag,l=n_(e,t,n,"g",i);if("group"!==o&&null!=r&&!r.includes(o))return Gy(l,0),e._svg;l.setAttribute("class",Vy(e));const c=Lv(e);for(const t in c)l_(l,t,c[t]);u||l_(l,"pointer-events",s),l_(l,"clip-path",e.clip?Og(this,e,e.group):null);let f=null,h=0;const d=t=>{const e=this.isDirty(t),n=n_(t,l,f,a.tag,i);e&&(this._update(a,n,t),u&&function(t,e,n,r){e=e.lastChild.previousSibling;let i,o=0;Cm(n,(n=>{i=t.mark(e,n,i,r),++o})),Gy(e,1+o)}(this,n,t,r)),f=n,++h};return a.nested?e.items.length&&d(e.items[0]):Cm(e,d),Gy(l,h),l},_update(t,e,n){r_=e,i_=e.__values__,Uv(a_,n),t.attr(a_,n,this);const r=o_[t.type];r&&r.call(this,t,e,n),r_&&this.style(r_,n)},style(t,e){if(null!=e){for(const n in Yv){let r="font"===n?My(e):e[n];if(r===i_[n])continue;const i=Yv[n];null==r?t.removeAttribute(i):(Up(r)&&(r=Lp(r,this._defs.gradient,c_())),t.setAttribute(i,r+"")),i_[n]=r}for(const n in Gv)s_(t,Gv[n],e[n])}},defs(){const t=this._svg,e=this._defs;let n=e.el,r=0;for(const i in e.gradient)n||(e.el=n=Yy(t,1,"defs",Jv)),r=t_(n,e.gradient[i],r);for(const i in e.clipping)n||(e.el=n=Yy(t,1,"defs",Jv)),r=e_(n,e.clipping[i],r);n&&(0===r?(t.removeChild(n),e.el=null):Gy(n,r))},_clearDefs(){const t=this._defs;t.gradient={},t.clipping={}}});let r_=null,i_=null;const o_={group(t,e,n){const r=r_=e.childNodes[2];i_=r.__values__,t.foreground(a_,n,this),i_=e.__values__,r_=e.childNodes[1],t.content(a_,n,this);const i=r_=e.childNodes[0];t.background(a_,n,this);const o=!1===n.mark.interactive?"none":null;if(o!==i_.events&&(l_(r,"pointer-events",o),l_(i,"pointer-events",o),i_.events=o),n.strokeForeground&&n.stroke){const t=n.fill;l_(r,"display",null),this.style(i,n),l_(i,"stroke",null),t&&(n.fill=null),i_=r.__values__,this.style(r,n),t&&(n.fill=t),r_=null}else l_(r,"display","none")},image(t,e,n){!1===n.smooth?(s_(e,"image-rendering","optimizeSpeed"),s_(e,"image-rendering","pixelated")):s_(e,"image-rendering",null)},text(t,e,n){const 
r=wy(n);let i,o,a,s;k(r)?(o=r.map((t=>Ay(n,t))),i=o.join("\n"),i!==i_.text&&(Gy(e,0),a=e.ownerDocument,s=by(n),o.forEach(((t,r)=>{const i=Wy(a,"tspan",Jv);i.__data__=n,i.textContent=t,r&&(i.setAttribute("x",0),i.setAttribute("dy",s)),e.appendChild(i)})),i_.text=i)):(o=Ay(n,r),o!==i_.text&&(e.textContent=o,i_.text=o)),l_(e,"font-family",My(n)),l_(e,"font-size",xy(n)+"px"),l_(e,"font-style",n.fontStyle),l_(e,"font-variant",n.fontVariant),l_(e,"font-weight",n.fontWeight)}};function a_(t,e,n){e!==i_[t]&&(n?function(t,e,n,r){null!=n?t.setAttributeNS(r,e,n):t.removeAttributeNS(r,e)}(r_,t,e,n):l_(r_,t,e),i_[t]=e)}function s_(t,e,n){n!==i_[e]&&(null==n?t.style.removeProperty(e):t.style.setProperty(e,n+""),i_[e]=n)}function u_(t,e){for(const n in e)l_(t,n,e[n])}function l_(t,e,n){null!=n?t.setAttribute(e,n):t.removeAttribute(e)}function c_(){let t;return"undefined"==typeof window?"":(t=window.location).hash?t.href.slice(0,-t.hash.length):t.href}function f_(t){Qy.call(this,t),this._text=null,this._defs={gradient:{},clipping:{}}}dt(f_,Qy,{svg(){return this._text},_render(t){const e=Iv();e.open("svg",ot({},Zm,{class:"marks",width:this._width*this._scale,height:this._height*this._scale,viewBox:`0 0 ${this._width} ${this._height}`}));const n=this._bgcolor;return n&&"transparent"!==n&&"none"!==n&&e.open("rect",{width:this._width,height:this._height,fill:n}).close(),e.open("g",Vv,{transform:"translate("+this._origin+")"}),this.mark(e,t),e.close(),this.defs(e),this._text=e.close()+"",this},mark(t,e){const n=zy[e.marktype],r=n.tag,i=[Uv,n.attr];t.open("g",{class:Vy(e),"clip-path":e.clip?Og(this,e,e.group):null},Lv(e),{"pointer-events":"g"!==r&&!1===e.interactive?"none":null});const o=o=>{const a=this.href(o);if(a&&t.open("a",a),t.open(r,this.attr(e,o,i,"g"!==r?r:null)),"text"===r){const e=wy(o);if(k(e)){const n={x:0,dy:by(o)};for(let 
r=0;rthis.mark(t,e))),t.close(),r&&a?(i&&(o.fill=null),o.stroke=a,t.open("path",this.attr(e,o,n.foreground,"bgrect")).close(),i&&(o.fill=i)):t.open("path",this.attr(e,o,n.foreground,"bgfore")).close()}t.close(),a&&t.close()};return n.nested?e.items&&e.items.length&&o(e.items[0]):Cm(e,o),t.close()},href(t){const e=t.href;let n;if(e){if(n=this._hrefs&&this._hrefs[e])return n;this.sanitizeURL(e).then((t=>{t["xlink:href"]=t.href,t.href=null,(this._hrefs||(this._hrefs={}))[e]=t}))}return null},attr(t,e,n,r){const i={},o=(t,e,n,r)=>{i[r||t]=e};return Array.isArray(n)?n.forEach((t=>t(o,e,this))):n(o,e,this),r&&function(t,e,n,r,i){let o;if(null==e)return t;"bgrect"===r&&!1===n.interactive&&(t["pointer-events"]="none");if("bgfore"===r&&(!1===n.interactive&&(t["pointer-events"]="none"),t.display="none",null!==e.fill))return t;"image"===r&&!1===e.smooth&&(o=["image-rendering: optimizeSpeed;","image-rendering: pixelated;"]);"text"===r&&(t["font-family"]=My(e),t["font-size"]=xy(e)+"px",t["font-style"]=e.fontStyle,t["font-variant"]=e.fontVariant,t["font-weight"]=e.fontWeight);for(const n in Yv){let r=e[n];const o=Yv[n];("transparent"!==r||"fill"!==o&&"stroke"!==o)&&null!=r&&(Up(r)&&(r=Lp(r,i.gradient,"")),t[o]=r)}for(const t in Gv){const n=e[t];null!=n&&(o=o||[],o.push(`${Gv[t]}: ${n};`))}o&&(t.style=o.join(" "))}(i,e,t,r,this._defs),i},defs(t){const e=this._defs.gradient,n=this._defs.clipping;if(0!==Object.keys(e).length+Object.keys(n).length){t.open("defs");for(const n in e){const r=e[n],i=r.stops;"radial"===r.gradient?(t.open("pattern",{id:Rp+n,viewBox:"0,0,1,1",width:"100%",height:"100%",preserveAspectRatio:"xMidYMid slice"}),t.open("rect",{width:"1",height:"1",fill:"url(#"+n+")"}).close(),t.close(),t.open("radialGradient",{id:n,fx:r.x1,fy:r.y1,fr:r.r1,cx:r.x2,cy:r.y2,r:r.r2})):t.open("linearGradient",{id:n,x1:r.x1,x2:r.x2,y1:r.y1,y2:r.y2});for(let 
e=0;e!h_.svgMarkTypes.includes(t)));this._svgRenderer.render(t,h_.svgMarkTypes),this._canvasRenderer.render(t,n)},resize(t,e,n,r){return p_.resize.call(this,t,e,n,r),this._svgRenderer.resize(t,e,n,r),this._canvasRenderer.resize(t,e,n,r),this},background(t){return h_.svgOnTop?this._canvasRenderer.background(t):this._svgRenderer.background(t),this}}),dt(g_,vv,{initialize(t,e,n){const r=Yy(Yy(t,0,"div"),h_.svgOnTop?0:1,"div");return vv.prototype.initialize.call(this,r,e,n)}});const m_="canvas",y_="hybrid",v_="none",__={Canvas:m_,PNG:"png",SVG:"svg",Hybrid:y_,None:v_},x_={};function b_(t,e){return t=String(t||"").toLowerCase(),arguments.length>1?(x_[t]=e,this):x_[t]}function w_(t,e,n){const r=[],i=(new Rg).union(e),o=t.marktype;return o?k_(t,i,n,r):"group"===o?A_(t,i,n,r):s("Intersect scene must be mark node or group item.")}function k_(t,e,n,r){if(function(t,e,n){return t.bounds&&e.intersects(t.bounds)&&("group"===t.marktype||!1!==t.interactive&&(!n||n(t)))}(t,e,n)){const i=t.items,o=t.marktype,a=i.length;let s=0;if("group"===o)for(;s=0;r--)if(i[r]!=o[r])return!1;for(r=i.length-1;r>=0;r--)if(!F_(t[n=i[r]],e[n],n))return!1;return typeof t==typeof e}(t,e):t==e)}function S_(t,e){return F_(Xp(t),Xp(e))}const $_="top",T_="left",B_="right",z_="bottom",N_="top-left",O_="top-right",R_="bottom-left",U_="bottom-right",L_="start",q_="middle",P_="end",j_="x",I_="y",W_="group",H_="axis",Y_="title",G_="frame",V_="scope",X_="legend",J_="row-header",Z_="row-footer",Q_="row-title",K_="column-header",tx="column-footer",ex="column-title",nx="padding",rx="symbol",ix="fit",ox="fit-x",ax="fit-y",sx="pad",ux="none",lx="all",cx="each",fx="flush",hx="column",dx="row";function px(t){Ja.call(this,null,t)}function gx(t,e,n){return e(t.bounds.clear(),t,n)}dt(px,Ja,{transform(t,e){const n=e.dataflow,r=t.mark,i=r.marktype,o=zy[i],a=o.bound;let s,u=r.bounds;if(o.nested)r.items.length&&n.dirty(r.items[0]),u=gx(r,a),r.items.forEach((t=>{t.bounds.clear().union(u)}));else 
if(i===W_||t.modified())switch(e.visit(e.MOD,(t=>n.dirty(t))),u.clear(),r.items.forEach((t=>u.union(gx(t,a)))),r.role){case H_:case X_:case Y_:e.reflow()}else s=e.changed(e.REM),e.visit(e.ADD,(t=>{u.union(gx(t,a))})),e.visit(e.MOD,(t=>{s=s||u.alignsWith(t.bounds),n.dirty(t),u.union(gx(t,a))})),s&&(u.clear(),r.items.forEach((t=>u.union(t.bounds))));return D_(r),e.modifies("bounds")}});const mx=":vega_identifier:";function yx(t){Ja.call(this,0,t)}function vx(t){Ja.call(this,null,t)}function _x(t){Ja.call(this,null,t)}yx.Definition={type:"Identifier",metadata:{modifies:!0},params:[{name:"as",type:"string",required:!0}]},dt(yx,Ja,{transform(t,e){const n=(i=e.dataflow)._signals[mx]||(i._signals[mx]=i.add(0)),r=t.as;var i;let o=n.value;return e.visit(e.ADD,(t=>t[r]=t[r]||++o)),n.set(this.value=o),e}}),dt(vx,Ja,{transform(t,e){let n=this.value;n||(n=e.dataflow.scenegraph().mark(t.markdef,function(t){const e=t.groups,n=t.parent;return e&&1===e.size?e.get(Object.keys(e.object)[0]):e&&n?e.lookup(n):null}(t),t.index),n.group.context=t.context,t.context.group||(t.context.group=n.group),n.source=this.source,n.clip=t.clip,n.interactive=t.interactive,this.value=n);const r=n.marktype===W_?Lg:Ug;return e.visit(e.ADD,(t=>r.call(t,n))),(t.modified("clip")||t.modified("interactive"))&&(n.clip=t.clip,n.interactive=!!t.interactive,n.zdirty=!0,e.reflow()),n.items=e.source,e}});const xx={parity:t=>t.filter(((t,e)=>e%2?t.opacity=0:1)),greedy:(t,e)=>{let n;return t.filter(((t,r)=>r&&bx(n.bounds,t.bounds,e)?t.opacity=0:(n=t,1)))}},bx=(t,e,n)=>n>Math.max(e.x1-t.x2,t.x1-e.x2,e.y1-t.y2,t.y1-e.y2),wx=(t,e)=>{for(var n,r=1,i=t.length,o=t[0].bounds;r{const e=t.bounds;return e.width()>1&&e.height()>1},Ax=t=>(t.forEach((t=>t.opacity=1)),t),Mx=(t,e)=>t.reflow(e.modified()).modifies("opacity");function Ex(t){Ja.call(this,null,t)}dt(_x,Ja,{transform(t,e){const n=xx[t.method]||xx.parity,r=t.separation||0;let i,o,a=e.materialize(e.SOURCE).source;if(!a||!a.length)return;if(!t.method)return 
t.modified("method")&&(Ax(a),e=Mx(e,t)),e;if(a=a.filter(kx),!a.length)return;if(t.sort&&(a=a.slice().sort(t.sort)),i=Ax(a),e=Mx(e,t),i.length>=3&&wx(i,r)){do{i=n(i,r)}while(i.length>=3&&wx(i,r));i.length<3&&!F(a).opacity&&(i.length>1&&(F(i).opacity=0),F(a).opacity=1)}t.boundScale&&t.boundTolerance>=0&&(o=((t,e,n)=>{var r=t.range(),i=new Rg;return e===$_||e===z_?i.set(r[0],-1/0,r[1],1/0):i.set(-1/0,r[0],1/0,r[1]),i.expand(n||1),t=>i.encloses(t.bounds)})(t.boundScale,t.boundOrient,+t.boundTolerance),a.forEach((t=>{o(t)||(t.opacity=0)})));const s=i[0].mark.bounds.clear();return a.forEach((t=>{t.opacity&&s.union(t.bounds)})),e}}),dt(Ex,Ja,{transform(t,e){const n=e.dataflow;if(e.visit(e.ALL,(t=>n.dirty(t))),e.fields&&e.fields.zindex){const t=e.source&&e.source[0];t&&(t.mark.zdirty=!0)}}});const Dx=new Rg;function Cx(t,e,n){return t[e]===n?0:(t[e]=n,1)}function Fx(t){var e=t.items[0].orient;return e===T_||e===B_}function Sx(t,e,n,r){var i,o,a=e.items[0],s=a.datum,u=null!=a.translate?a.translate:.5,l=a.orient,c=function(t){let e=+t.grid;return[t.ticks?e++:-1,t.labels?e++:-1,e+ +t.domain]}(s),f=a.range,h=a.offset,d=a.position,p=a.minExtent,g=a.maxExtent,m=s.title&&a.items[c[2]].items[0],y=a.titlePadding,v=a.bounds,_=m&&ky(m),x=0,b=0;switch(Dx.clear().union(v),v.clear(),(i=c[0])>-1&&v.union(a.items[i].bounds),(i=c[1])>-1&&v.union(a.items[i].bounds),l){case $_:x=d||0,b=-h,o=Math.max(p,Math.min(g,-v.y1)),v.add(0,-o).add(f,0),m&&$x(t,m,o,y,_,0,-1,v);break;case T_:x=-h,b=d||0,o=Math.max(p,Math.min(g,-v.x1)),v.add(-o,0).add(0,f),m&&$x(t,m,o,y,_,1,-1,v);break;case B_:x=n+h,b=d||0,o=Math.max(p,Math.min(g,v.x2)),v.add(0,0).add(o,f),m&&$x(t,m,o,y,_,1,1,v);break;case z_:x=d||0,b=r+h,o=Math.max(p,Math.min(g,v.y2)),v.add(0,0).add(f,o),m&&$x(t,m,o,y,0,0,1,v);break;default:x=a.x,b=a.y}return Ig(v.translate(x,b),a),Cx(a,"x",x+u)|Cx(a,"y",b+u)&&(a.bounds=Dx,t.dirty(a),a.bounds=v,t.dirty(a)),a.mark.bounds.clear().union(v)}function $x(t,e,n,r,i,o,a,s){const u=e.bounds;if(e.auto){const 
s=a*(n+i+r);let l=0,c=0;t.dirty(e),o?l=(e.x||0)-(e.x=s):c=(e.y||0)-(e.y=s),e.mark.bounds.clear().union(u.translate(-l,-c)),t.dirty(e)}s.union(u)}const Tx=(t,e)=>Math.floor(Math.min(t,e)),Bx=(t,e)=>Math.ceil(Math.max(t,e));function zx(t){return(new Rg).set(0,0,t.width||0,t.height||0)}function Nx(t){const e=t.bounds.clone();return e.empty()?e.set(0,0,0,0):e.translate(-(t.x||0),-(t.y||0))}function Ox(t,e,n){const r=A(t)?t[e]:t;return null!=r?r:void 0!==n?n:0}function Rx(t){return t<0?Math.ceil(-t):0}function Ux(t,e,n){var r,i,o,a,s,u,l,c,f,h,d,p=!n.nodirty,g=n.bounds===fx?zx:Nx,m=Dx.set(0,0,0,0),y=Ox(n.align,hx),v=Ox(n.align,dx),_=Ox(n.padding,hx),x=Ox(n.padding,dx),b=n.columns||e.length,w=b<=0?1:Math.ceil(e.length/b),k=e.length,A=Array(k),M=Array(b),E=0,D=Array(k),C=Array(w),F=0,S=Array(k),$=Array(k),T=Array(k);for(i=0;i1)for(i=0;i0&&(S[i]+=f/2);if(v&&Ox(n.center,dx)&&1!==b)for(i=0;i0&&($[i]+=h/2);for(i=0;ii&&(t.warn("Grid headers exceed limit: "+i),e=e.slice(0,i)),A+=o,g=0,y=e.length;g=0&&null==(x=n[m]);m-=h);s?(b=null==d?x.x:Math.round(x.bounds.x1+d*x.bounds.width()),w=A):(b=A,w=null==d?x.y:Math.round(x.bounds.y1+d*x.bounds.height())),v.union(_.bounds.translate(b-(_.x||0),w-(_.y||0))),_.x=b,_.y=w,t.dirty(_),M=a(M,v[l])}return M}function Ix(t,e,n,r,i,o){if(e){t.dirty(e);var a=n,s=n;r?a=Math.round(i.x1+o*i.width()):s=Math.round(i.y1+o*i.height()),e.bounds.translate(a-(e.x||0),s-(e.y||0)),e.mark.bounds.clear().union(e.bounds),e.x=a,e.y=s,t.dirty(e)}}function Wx(t,e,n,r,i,o,a){const s=function(t,e){const n=t[e]||{};return(e,r)=>null!=n[e]?n[e]:null!=t[e]?t[e]:r}(n,e),u=function(t,e){let n=-1/0;return t.forEach((t=>{null!=t.offset&&(n=Math.max(n,t.offset))})),n>-1/0?n:e}(t,s("offset",0)),l=s("anchor",L_),c=l===P_?1:l===q_?.5:0,f={align:cx,bounds:s("bounds",fx),columns:"vertical"===s("direction")?1:t.length,padding:s("margin",8),center:s("center"),nodirty:!0};switch(e){case T_:f.anchor={x:Math.floor(r.x1)-u,column:P_,y:c*(a||r.height()+2*r.y1),row:l};break;case 
B_:f.anchor={x:Math.ceil(r.x2)+u,y:c*(a||r.height()+2*r.y1),row:l};break;case $_:f.anchor={y:Math.floor(i.y1)-u,row:P_,x:c*(o||i.width()+2*i.x1),column:l};break;case z_:f.anchor={y:Math.ceil(i.y2)+u,x:c*(o||i.width()+2*i.x1),column:l};break;case N_:f.anchor={x:u,y:u};break;case O_:f.anchor={x:o-u,y:u,column:P_};break;case R_:f.anchor={x:u,y:a-u,row:P_};break;case U_:f.anchor={x:o-u,y:a-u,column:P_,row:P_}}return f}function Hx(t,e){var n,r,i=e.items[0],o=i.datum,a=i.orient,s=i.bounds,u=i.x,l=i.y;return i._bounds?i._bounds.clear().union(s):i._bounds=s.clone(),s.clear(),function(t,e,n){var r=e.padding,i=r-n.x,o=r-n.y;if(e.datum.title){var a=e.items[1].items[0],s=a.anchor,u=e.titlePadding||0,l=r-a.x,c=r-a.y;switch(a.orient){case T_:i+=Math.ceil(a.bounds.width())+u;break;case B_:case z_:break;default:o+=a.bounds.height()+u}switch((i||o)&&Gx(t,n,i,o),a.orient){case T_:c+=Yx(e,n,a,s,1,1);break;case B_:l+=Yx(e,n,a,P_,0,0)+u,c+=Yx(e,n,a,s,1,1);break;case z_:l+=Yx(e,n,a,s,0,0),c+=Yx(e,n,a,P_,-1,0,1)+u;break;default:l+=Yx(e,n,a,s,0,0)}(l||c)&&Gx(t,a,l,c),(l=Math.round(a.bounds.x1-r))<0&&(Gx(t,n,-l,0),Gx(t,a,-l,0))}else(i||o)&&Gx(t,n,i,o)}(t,i,i.items[0].items[0]),s=function(t,e){return t.items.forEach((t=>e.union(t.bounds))),e.x1=t.padding,e.y1=t.padding,e}(i,s),n=2*i.padding,r=2*i.padding,s.empty()||(n=Math.ceil(s.width()+n),r=Math.ceil(s.height()+r)),o.type===rx&&function(t){const e=t.reduce(((t,e)=>(t[e.column]=Math.max(e.bounds.x2-e.x,t[e.column]||0),t)),{});t.forEach((t=>{t.width=e[t.column],t.height=t.bounds.y2-t.y}))}(i.items[0].items[0].items[0].items),a!==ux&&(i.x=u=0,i.y=l=0),i.width=n,i.height=r,Ig(s.set(u,l,u+n,l+r),i),i.mark.bounds.clear().union(s),i}function Yx(t,e,n,r,i,o,a){const s="symbol"!==t.datum.type,u=n.datum.vgrad,l=(!s||!o&&u||a?e:e.items[0]).bounds[i?"y2":"x2"]-t.padding,c=u&&o?l:0,f=u&&o?0:l,h=i<=0?0:ky(n);return Math.round(r===L_?c:r===P_?f-h:.5*(l-h))}function 
Gx(t,e,n,r){e.x+=n,e.y+=r,e.bounds.translate(n,r),e.mark.bounds.translate(n,r),t.dirty(e)}function Vx(t){Ja.call(this,null,t)}dt(Vx,Ja,{transform(t,e){const n=e.dataflow;return t.mark.items.forEach((e=>{t.layout&&Lx(n,e,t.layout),function(t,e,n){var r,i,o,a,s,u=e.items,l=Math.max(0,e.width||0),c=Math.max(0,e.height||0),f=(new Rg).set(0,0,l,c),h=f.clone(),d=f.clone(),p=[];for(a=0,s=u.length;a{(o=t.orient||B_)!==ux&&(e[o]||(e[o]=[])).push(t)}));for(const r in e){const i=e[r];Ux(t,i,Wx(i,r,n.legends,h,d,l,c))}p.forEach((e=>{const r=e.bounds;if(r.equals(e._bounds)||(e.bounds=e._bounds,t.dirty(e),e.bounds=r,t.dirty(e)),!n.autosize||n.autosize.type!==ix&&n.autosize.type!==ox&&n.autosize.type!==ax)f.union(r);else switch(e.orient){case T_:case B_:f.add(r.x1,0).add(r.x2,0);break;case $_:case z_:f.add(0,r.y1).add(0,r.y2)}}))}f.union(h).union(d),r&&f.union(function(t,e,n,r,i){var o,a=e.items[0],s=a.frame,u=a.orient,l=a.anchor,c=a.offset,f=a.padding,h=a.items[0].items[0],d=a.items[1]&&a.items[1].items[0],p=u===T_||u===B_?r:n,g=0,m=0,y=0,v=0,_=0;if(s!==W_?u===T_?(g=i.y2,p=i.y1):u===B_?(g=i.y1,p=i.y2):(g=i.x1,p=i.x2):u===T_&&(g=r,p=0),o=l===L_?g:l===P_?p:(g+p)/2,d&&d.text){switch(u){case $_:case z_:_=h.bounds.height()+f;break;case T_:v=h.bounds.width()+f;break;case B_:v=-h.bounds.width()-f}Dx.clear().union(d.bounds),Dx.translate(v-(d.x||0),_-(d.y||0)),Cx(d,"x",v)|Cx(d,"y",_)&&(t.dirty(d),d.bounds.clear().union(Dx),d.mark.bounds.clear().union(Dx),t.dirty(d)),Dx.clear().union(d.bounds)}else Dx.clear();switch(Dx.union(h.bounds),u){case $_:m=o,y=i.y1-Dx.height()-c;break;case T_:m=i.x1-Dx.width()-c,y=o;break;case B_:m=i.x2+Dx.width()+c,y=o;break;case z_:m=o,y=i.y2+c;break;default:m=a.x,y=a.y}return Cx(a,"x",m)|Cx(a,"y",y)&&(Dx.translate(m,y),t.dirty(a),a.bounds.clear().union(Dx),e.bounds.clear().union(Dx),t.dirty(a)),a.bounds}(t,r,l,c,f));e.clip&&f.set(0,0,e.width||0,e.height||0);!function(t,e,n,r){const i=r.autosize||{},o=i.type;if(t._autosize<1||!o)return;let 
a=t._width,s=t._height,u=Math.max(0,e.width||0),l=Math.max(0,Math.ceil(-n.x1)),c=Math.max(0,e.height||0),f=Math.max(0,Math.ceil(-n.y1));const h=Math.max(0,Math.ceil(n.x2-u)),d=Math.max(0,Math.ceil(n.y2-c));if(i.contains===nx){const e=t.padding();a-=e.left+e.right,s-=e.top+e.bottom}o===ux?(l=0,f=0,u=a,c=s):o===ix?(u=Math.max(0,a-l-h),c=Math.max(0,s-f-d)):o===ox?(u=Math.max(0,a-l-h),s=c+f+d):o===ax?(a=u+l+h,c=Math.max(0,s-f-d)):o===sx&&(a=u+l+h,s=c+f+d);t._resizeView(a,s,u,c,[l,f],i.resize)}(t,e,f,n)}(n,e,t)})),function(t){return t&&"legend-entry"!==t.mark.role}(t.mark.group)?e.reflow():e}});var Xx=Object.freeze({__proto__:null,bound:px,identifier:yx,mark:vx,overlap:_x,render:Ex,viewlayout:Vx});function Jx(t){Ja.call(this,null,t)}function Zx(t){Ja.call(this,null,t)}function Qx(){return _a({})}function Kx(t){Ja.call(this,null,t)}function tb(t){Ja.call(this,[],t)}dt(Jx,Ja,{transform(t,e){if(this.value&&!t.modified())return e.StopPropagation;var n=e.dataflow.locale(),r=e.fork(e.NO_SOURCE|e.NO_FIELDS),i=this.value,o=t.scale,a=_p(o,null==t.count?t.values?t.values.length:10:t.count,t.minstep),s=t.format||wp(n,o,a,t.formatSpecifier,t.formatType,!!t.values),u=t.values?xp(o,t.values,a):bp(o,a);return i&&(r.rem=i),i=u.map(((t,e)=>_a({index:e/(u.length-1||1),value:t,label:s(t)}))),t.extra&&i.length&&i.push(_a({index:-1,extra:{value:i[0].value},label:""})),r.source=i,r.add=i,this.value=i,r}}),dt(Zx,Ja,{transform(t,e){var n=e.dataflow,r=e.fork(e.NO_SOURCE|e.NO_FIELDS),i=t.item||Qx,o=t.key||ya,a=this.value;return k(r.encode)&&(r.encode=null),a&&(t.modified("key")||e.modified(o))&&s("DataJoin does not support modified key function or fields."),a||(e=e.addAll(),this.value=a=function(t){const e=ft().test((t=>t.exit));return e.lookup=n=>e.get(t(n)),e}(o)),e.visit(e.ADD,(t=>{const e=o(t);let n=a.get(e);n?n.exit?(a.empty--,r.add.push(n)):r.mod.push(n):(n=i(t),a.set(e,n),r.add.push(n)),n.datum=t,n.exit=!1})),e.visit(e.MOD,(t=>{const 
e=o(t),n=a.get(e);n&&(n.datum=t,r.mod.push(n))})),e.visit(e.REM,(t=>{const e=o(t),n=a.get(e);t!==n.datum||n.exit||(r.rem.push(n),n.exit=!0,++a.empty)})),e.changed(e.ADD_MOD)&&r.modifies("datum"),(e.clean()||t.clean&&a.empty>n.cleanThreshold)&&n.runAfter(a.clean),r}}),dt(Kx,Ja,{transform(t,e){var n=e.fork(e.ADD_REM),r=t.mod||!1,i=t.encoders,o=e.encode;if(k(o)){if(!n.changed()&&!o.every((t=>i[t])))return e.StopPropagation;o=o[0],n.encode=null}var a="enter"===o,s=i.update||g,u=i.enter||g,l=i.exit||g,c=(o&&!a?i[o]:s)||g;if(e.changed(e.ADD)&&(e.visit(e.ADD,(e=>{u(e,t),s(e,t)})),n.modifies(u.output),n.modifies(s.output),c!==g&&c!==s&&(e.visit(e.ADD,(e=>{c(e,t)})),n.modifies(c.output))),e.changed(e.REM)&&l!==g&&(e.visit(e.REM,(e=>{l(e,t)})),n.modifies(l.output)),a||c!==g){const i=e.MOD|(t.modified()?e.REFLOW:0);a?(e.visit(i,(e=>{const i=u(e,t)||r;(c(e,t)||i)&&n.mod.push(e)})),n.mod.length&&n.modifies(u.output)):e.visit(i,(e=>{(c(e,t)||r)&&n.mod.push(e)})),n.mod.length&&n.modifies(c.output)}return n.changed()?n:e.StopPropagation}}),dt(tb,Ja,{transform(t,e){if(null!=this.value&&!t.modified())return e.StopPropagation;var n,r,i,o,a,s=e.dataflow.locale(),u=e.fork(e.NO_SOURCE|e.NO_FIELDS),l=this.value,c=t.type||pp,f=t.scale,h=+t.limit,d=_p(f,null==t.count?5:t.count,t.minstep),p=!!t.values||c===pp,g=t.format||Cp(s,f,d,c,t.formatSpecifier,t.formatType,p),m=t.values||Ep(f,d);return l&&(u.rem=l),c===pp?(h&&m.length>h?(e.dataflow.warn("Symbol legend count exceeds limit, filtering items."),l=m.slice(0,h-1),a=!0):l=m,J(i=t.size)?(t.values||0!==f(l[0])||(l=l.slice(1)),o=l.reduce(((e,n)=>Math.max(e,i(n,t))),0)):i=rt(o=i||8),l=l.map(((e,n)=>_a({index:n,label:g(e,n,l),value:e,offset:o,size:i(e,t)}))),a&&(a=m[l.length],l.push(_a({index:l.length,label:`…${m.length-l.length} 
entries`,value:a,offset:o,size:i(a,t)})))):"gradient"===c?(n=f.domain(),r=up(f,n[0],F(n)),m.length<3&&!t.values&&n[0]!==F(n)&&(m=[n[0],F(n)]),l=m.map(((t,e)=>_a({index:e,label:g(t,e,m),value:t,perc:r(t)})))):(i=m.length-1,r=function(t){const e=t.domain(),n=e.length-1;let r=+e[0],i=+F(e),o=i-r;if(t.type===Td){const t=n?o/n:.1;r-=t,i+=t,o=i-r}return t=>(t-r)/o}(f),l=m.map(((t,e)=>_a({index:e,label:g(t,e,m),value:t,perc:e?r(t):0,perc2:e===i?1:r(m[e+1])})))),u.source=l,u.add=l,this.value=l,u}});const eb=t=>t.source.x,nb=t=>t.source.y,rb=t=>t.target.x,ib=t=>t.target.y;function ob(t){Ja.call(this,{},t)}ob.Definition={type:"LinkPath",metadata:{modifies:!0},params:[{name:"sourceX",type:"field",default:"source.x"},{name:"sourceY",type:"field",default:"source.y"},{name:"targetX",type:"field",default:"target.x"},{name:"targetY",type:"field",default:"target.y"},{name:"orient",type:"enum",default:"vertical",values:["horizontal","vertical","radial"]},{name:"shape",type:"enum",default:"line",values:["line","arc","curve","diagonal","orthogonal"]},{name:"require",type:"signal"},{name:"as",type:"string",default:"path"}]},dt(ob,Ja,{transform(t,e){var n=t.sourceX||eb,r=t.sourceY||nb,i=t.targetX||rb,o=t.targetY||ib,a=t.as||"path",u=t.orient||"vertical",l=t.shape||"line",c=lb.get(l+"-"+u)||lb.get(l);return c||s("LinkPath unsupported type: "+t.shape+(t.orient?"-"+t.orient:"")),e.visit(e.SOURCE,(t=>{t[a]=c(n(t),r(t),i(t),o(t))})),e.reflow(t.modified()).modifies(a)}});const ab=(t,e,n,r)=>"M"+t+","+e+"L"+n+","+r,sb=(t,e,n,r)=>{var i=n-t,o=r-e,a=Math.hypot(i,o)/2;return"M"+t+","+e+"A"+a+","+a+" "+180*Math.atan2(o,i)/Math.PI+" 0 1 "+n+","+r},ub=(t,e,n,r)=>{const i=n-t,o=r-e,a=.2*(i+o),s=.2*(o-i);return"M"+t+","+e+"C"+(t+a)+","+(e+s)+" "+(n+s)+","+(r-a)+" 
"+n+","+r},lb=ft({line:ab,"line-radial":(t,e,n,r)=>ab(e*Math.cos(t),e*Math.sin(t),r*Math.cos(n),r*Math.sin(n)),arc:sb,"arc-radial":(t,e,n,r)=>sb(e*Math.cos(t),e*Math.sin(t),r*Math.cos(n),r*Math.sin(n)),curve:ub,"curve-radial":(t,e,n,r)=>ub(e*Math.cos(t),e*Math.sin(t),r*Math.cos(n),r*Math.sin(n)),"orthogonal-horizontal":(t,e,n,r)=>"M"+t+","+e+"V"+r+"H"+n,"orthogonal-vertical":(t,e,n,r)=>"M"+t+","+e+"H"+n+"V"+r,"orthogonal-radial":(t,e,n,r)=>{const i=Math.cos(t),o=Math.sin(t),a=Math.cos(n),s=Math.sin(n);return"M"+e*i+","+e*o+"A"+e+","+e+" 0 0,"+((Math.abs(n-t)>Math.PI?n<=t:n>t)?1:0)+" "+e*a+","+e*s+"L"+r*a+","+r*s},"diagonal-horizontal":(t,e,n,r)=>{const i=(t+n)/2;return"M"+t+","+e+"C"+i+","+e+" "+i+","+r+" "+n+","+r},"diagonal-vertical":(t,e,n,r)=>{const i=(e+r)/2;return"M"+t+","+e+"C"+t+","+i+" "+n+","+i+" "+n+","+r},"diagonal-radial":(t,e,n,r)=>{const i=Math.cos(t),o=Math.sin(t),a=Math.cos(n),s=Math.sin(n),u=(e+r)/2;return"M"+e*i+","+e*o+"C"+u*i+","+u*o+" "+u*a+","+u*s+" "+r*a+","+r*s}});function cb(t){Ja.call(this,null,t)}cb.Definition={type:"Pie",metadata:{modifies:!0},params:[{name:"field",type:"field"},{name:"startAngle",type:"number",default:0},{name:"endAngle",type:"number",default:6.283185307179586},{name:"sort",type:"boolean",default:!1},{name:"as",type:"string",array:!0,length:2,default:["startAngle","endAngle"]}]},dt(cb,Ja,{transform(t,e){var n,r,i,o=t.as||["startAngle","endAngle"],a=o[0],s=o[1],u=t.field||d,l=t.startAngle||0,c=null!=t.endAngle?t.endAngle:2*Math.PI,f=e.source,h=f.map(u),p=h.length,g=l,m=(c-l)/$e(h),y=Se(p);for(t.sort&&y.sort(((t,e)=>h[t]-h[e])),n=0;nt+(e<0?-1:e>0?1:0)),0))!==e.length&&n.warn("Log scale domain includes zero: "+Ct(e)));return e}function mb(t,e,n){return J(t)&&(e||n)?op(t,yb(e||[0,1],n)):t}function yb(t,e){return e?t.slice().reverse():t}function vb(t){Ja.call(this,null,t)}dt(pb,Ja,{transform(t,e){var n=e.dataflow,r=this.value,i=function(t){var e,n=t.type,r="";if(n===Cd)return Cd+"-"+bd;(function(t){const e=t.type;return 
Qd(e)&&e!==Ed&&e!==Dd&&(t.scheme||t.range&&t.range.length&&t.range.every(xt))})(t)&&(r=2===(e=t.rawDomain?t.rawDomain.length:t.domain?t.domain.length+ +(null!=t.domainMid):0)?Cd+"-":3===e?Fd+"-":"");return(r+n||bd).toLowerCase()}(t);for(i in r&&i===r.type||(this.value=r=Xd(i)()),t)if(!db[i]){if("padding"===i&&hb(r.type))continue;J(r[i])?r[i](t[i]):n.warn("Unsupported scale property: "+i)}return function(t,e,n){var r=t.type,i=e.round||!1,o=e.range;if(null!=e.rangeStep)o=function(t,e,n){t!==Nd&&t!==zd&&s("Only band and point scales support rangeStep.");var r=(null!=e.paddingOuter?e.paddingOuter:e.padding)||0,i=t===zd?1:(null!=e.paddingInner?e.paddingInner:e.padding)||0;return[0,e.rangeStep*xd(n,i,r)]}(r,e,n);else if(e.scheme&&(o=function(t,e,n){var r,i=e.schemeExtent;k(e.scheme)?r=ap(e.scheme,e.interpolate,e.interpolateGamma):(r=dp(e.scheme.toLowerCase()))||s(`Unrecognized scheme name: ${e.scheme}`);return n=t===Td?n+1:t===Od?n-1:t===Sd||t===$d?+e.schemeCount||fb:n,np(t)?mb(r,i,e.reverse):J(r)?sp(mb(r,i),n):t===Bd?r:r.slice(0,n)}(r,e,n),J(o))){if(t.interpolator)return t.interpolator(o);s(`Scale type ${r} does not support interpolating color schemes.`)}if(o&&np(r))return t.interpolator(ap(yb(o,e.reverse),e.interpolate,e.interpolateGamma));o&&e.interpolate&&t.interpolate?t.interpolate(lp(e.interpolate,e.interpolateGamma)):J(t.round)?t.round(i):J(t.rangeRound)&&t.interpolate(i?yh:mh);o&&t.range(yb(o,e.reverse))}(r,t,function(t,e,n){let r=e.bins;if(r&&!k(r)){const e=t.domain(),n=e[0],i=F(e),o=r.step;let a=null==r.start?n:r.start,u=null==r.stop?i:r.stop;o||s("Scale bins parameter missing step property."),ai&&(u=o*Math.floor(i/o)),r=Se(a,u+o/2,o)}r?t.bins=r:t.bins&&delete t.bins;t.type===Od&&(r?e.domain||e.domainRaw||(t.domain(r),n=r.length):t.bins=t.domain());return n}(r,t,function(t,e,n){const r=function(t,e,n){return e?(t.domain(gb(t.type,e,n)),e.length):-1}(t,e.domainRaw,n);if(r>-1)return r;var i,o,a=e.domain,s=t.type,u=e.zero||void 0===e.zero&&function(t){const 
e=t.type;return!t.bins&&(e===bd||e===kd||e===Ad)}(t);if(!a)return 0;hb(s)&&e.padding&&a[0]!==F(a)&&(a=function(t,e,n,r,i,o){var a=Math.abs(F(n)-n[0]),s=a/(a-2*r),u=t===wd?I(e,null,s):t===Ad?W(e,null,s,.5):t===kd?W(e,null,s,i||1):t===Md?H(e,null,s,o||1):j(e,null,s);return e=e.slice(),e[0]=u[0],e[e.length-1]=u[1],e}(s,a,e.range,e.padding,e.exponent,e.constant));if((u||null!=e.domainMin||null!=e.domainMax||null!=e.domainMid)&&(i=(a=a.slice()).length-1||1,u&&(a[0]>0&&(a[0]=0),a[i]<0&&(a[i]=0)),null!=e.domainMin&&(a[0]=e.domainMin),null!=e.domainMax&&(a[i]=e.domainMax),null!=e.domainMid)){const t=(o=e.domainMid)>a[i]?i+1:ot(u);if(null==e)d.push(t.slice());else for(i={},o=0,a=t.length;oh&&(h=f),n&&c.sort(n)}return d.max=h,d}(e.source,t.groupby,l,c),r=0,i=n.length,o=n.max;r0?1:t<0?-1:0},Gb=Math.sqrt,Vb=Math.tan;function Xb(t){return t>1?0:t<-1?Sb:Math.acos(t)}function Jb(t){return t>1?$b:t<-1?-$b:Math.asin(t)}function Zb(){}function Qb(t,e){t&&tw.hasOwnProperty(t.type)&&tw[t.type](t,e)}var Kb={Feature:function(t,e){Qb(t.geometry,e)},FeatureCollection:function(t,e){for(var n=t.features,r=-1,i=n.length;++r=0?1:-1,i=r*n,o=Lb(e=(e*=Nb)/2+Tb),a=Hb(e),s=uw*a,u=sw*o+s*Lb(i),l=s*r*Hb(i);xw.add(Ub(l,u)),aw=t,sw=o,uw=a}function Dw(t){return[Ub(t[1],t[0]),Jb(t[2])]}function Cw(t){var e=t[0],n=t[1],r=Lb(n);return[r*Lb(e),r*Hb(e),Hb(n)]}function Fw(t,e){return t[0]*e[0]+t[1]*e[1]+t[2]*e[2]}function Sw(t,e){return[t[1]*e[2]-t[2]*e[1],t[2]*e[0]-t[0]*e[2],t[0]*e[1]-t[1]*e[0]]}function $w(t,e){t[0]+=e[0],t[1]+=e[1],t[2]+=e[2]}function Tw(t,e){return[t[0]*e,t[1]*e,t[2]*e]}function Bw(t){var e=Gb(t[0]*t[0]+t[1]*t[1]+t[2]*t[2]);t[0]/=e,t[1]/=e,t[2]/=e}var zw,Nw,Ow,Rw,Uw,Lw,qw,Pw,jw,Iw,Ww,Hw,Yw,Gw,Vw,Xw,Jw={point:Zw,lineStart:Kw,lineEnd:tk,polygonStart:function(){Jw.point=ek,Jw.lineStart=nk,Jw.lineEnd=rk,yw=new 
se,ww.polygonStart()},polygonEnd:function(){ww.polygonEnd(),Jw.point=Zw,Jw.lineStart=Kw,Jw.lineEnd=tk,xw<0?(lw=-(fw=180),cw=-(hw=90)):yw>Cb?hw=90:yw<-Cb&&(cw=-90),_w[0]=lw,_w[1]=fw},sphere:function(){lw=-(fw=180),cw=-(hw=90)}};function Zw(t,e){vw.push(_w=[lw=t,fw=t]),ehw&&(hw=e)}function Qw(t,e){var n=Cw([t*Nb,e*Nb]);if(mw){var r=Sw(mw,n),i=Sw([r[1],-r[0],0],r);Bw(i),i=Dw(i);var o,a=t-dw,s=a>0?1:-1,u=i[0]*zb*s,l=Ob(a)>180;l^(s*dwhw&&(hw=o):l^(s*dw<(u=(u+360)%360-180)&&uhw&&(hw=e)),l?tik(lw,fw)&&(fw=t):ik(t,fw)>ik(lw,fw)&&(lw=t):fw>=lw?(tfw&&(fw=t)):t>dw?ik(lw,t)>ik(lw,fw)&&(fw=t):ik(t,fw)>ik(lw,fw)&&(lw=t)}else vw.push(_w=[lw=t,fw=t]);ehw&&(hw=e),mw=n,dw=t}function Kw(){Jw.point=Qw}function tk(){_w[0]=lw,_w[1]=fw,Jw.point=Zw,mw=null}function ek(t,e){if(mw){var n=t-dw;yw.add(Ob(n)>180?n+(n>0?360:-360):n)}else pw=t,gw=e;ww.point(t,e),Qw(t,e)}function nk(){ww.lineStart()}function rk(){ek(pw,gw),ww.lineEnd(),Ob(yw)>Cb&&(lw=-(fw=180)),_w[0]=lw,_w[1]=fw,mw=null}function ik(t,e){return(e-=t)<0?e+360:e}function ok(t,e){return t[0]-e[0]}function ak(t,e){return t[0]<=t[1]?t[0]<=e&&e<=t[1]:eSb&&(t-=Math.round(t/Bb)*Bb),[t,e]}function xk(t,e,n){return(t%=Bb)?e||n?vk(wk(t),kk(e,n)):wk(t):e||n?kk(e,n):_k}function bk(t){return function(e,n){return Ob(e+=t)>Sb&&(e-=Math.round(e/Bb)*Bb),[e,n]}}function wk(t){var e=bk(t);return e.invert=bk(-t),e}function kk(t,e){var n=Lb(t),r=Hb(t),i=Lb(e),o=Hb(e);function a(t,e){var a=Lb(e),s=Lb(t)*a,u=Hb(t)*a,l=Hb(e),c=l*n+s*r;return[Ub(u*i-c*o,s*n-l*r),Jb(c*i+u*o)]}return a.invert=function(t,e){var a=Lb(e),s=Lb(t)*a,u=Hb(t)*a,l=Hb(e),c=l*i-u*o;return[Ub(u*i+l*o,s*n+c*r),Jb(c*n-s*r)]},a}function Ak(t,e){(e=Cw(e))[0]-=t,Bw(e);var n=Xb(-e[1]);return((-e[2]<0?-n:n)+Bb-Cb)%Bb}function Mk(){var t,e=[];return{point:function(e,n,r){t.push([e,n,r])},lineStart:function(){e.push(t=[])},lineEnd:Zb,rejoin:function(){e.length>1&&e.push(e.pop().concat(e.shift()))},result:function(){var n=e;return e=[],t=null,n}}}function Ek(t,e){return 
Ob(t[0]-e[0])=0;--o)i.point((c=l[o])[0],c[1]);else r(h.x,h.p.x,-1,i);h=h.p}l=(h=h.o).z,d=!d}while(!h.v);i.lineEnd()}}}function Fk(t){if(e=t.length){for(var e,n,r=0,i=t[0];++r=0?1:-1,E=M*A,D=E>Sb,C=m*w;if(u.add(Ub(C*M*Hb(E),y*k+C*Lb(E))),a+=D?A+M*Bb:A,D^p>=n^x>=n){var F=Sw(Cw(d),Cw(_));Bw(F);var S=Sw(o,F);Bw(S);var $=(D^A>=0?-1:1)*Jb(S[2]);(r>$||r===$&&(F[0]||F[1]))&&(s+=D^A>=0?1:-1)}}return(a<-Cb||a0){for(f||(i.polygonStart(),f=!0),i.lineStart(),t=0;t1&&2&u&&h.push(h.pop().concat(h.shift())),a.push(h.filter(Tk))}return h}}function Tk(t){return t.length>1}function Bk(t,e){return((t=t.x)[0]<0?t[1]-$b-Cb:$b-t[1])-((e=e.x)[0]<0?e[1]-$b-Cb:$b-e[1])}_k.invert=_k;var zk=$k((function(){return!0}),(function(t){var e,n=NaN,r=NaN,i=NaN;return{lineStart:function(){t.lineStart(),e=1},point:function(o,a){var s=o>0?Sb:-Sb,u=Ob(o-n);Ob(u-Sb)0?$b:-$b),t.point(i,r),t.lineEnd(),t.lineStart(),t.point(s,r),t.point(o,r),e=0):i!==s&&u>=Sb&&(Ob(n-i)Cb?Rb((Hb(e)*(o=Lb(r))*Hb(n)-Hb(r)*(i=Lb(e))*Hb(t))/(i*o*a)):(e+r)/2}(n,r,o,a),t.point(i,r),t.lineEnd(),t.lineStart(),t.point(s,r),e=0),t.point(n=o,r=a),i=s},lineEnd:function(){t.lineEnd(),n=r=NaN},clean:function(){return 2-e}}}),(function(t,e,n,r){var i;if(null==t)i=n*$b,r.point(-Sb,i),r.point(0,i),r.point(Sb,i),r.point(Sb,0),r.point(Sb,-i),r.point(0,-i),r.point(-Sb,-i),r.point(-Sb,0),r.point(-Sb,i);else if(Ob(t[0]-e[0])>Cb){var o=t[0]0,i=Ob(e)>Cb;function o(t,n){return Lb(t)*Lb(n)>e}function a(t,n,r){var i=[1,0,0],o=Sw(Cw(t),Cw(n)),a=Fw(o,o),s=o[0],u=a-s*s;if(!u)return!r&&t;var l=e*a/u,c=-e*s/u,f=Sw(i,o),h=Tw(i,l);$w(h,Tw(o,c));var d=f,p=Fw(h,d),g=Fw(d,d),m=p*p-g*(Fw(h,h)-1);if(!(m<0)){var y=Gb(m),v=Tw(d,(-p-y)/g);if($w(v,h),v=Dw(v),!r)return v;var _,x=t[0],b=n[0],w=t[1],k=n[1];b0^v[1]<(Ob(v[0]-x)Sb^(x<=v[0]&&v[0]<=b)){var E=Tw(d,(-p+y)/g);return $w(E,h),[v,Dw(E)]}}}function s(e,n){var i=r?t:Sb-t,o=0;return e<-i?o|=1:e>i&&(o|=2),n<-i?o|=4:n>i&&(o|=8),o}return $k(o,(function(t){var 
e,n,u,l,c;return{lineStart:function(){l=u=!1,c=1},point:function(f,h){var d,p=[f,h],g=o(f,h),m=r?g?0:s(f,h):g?s(f+(f<0?Sb:-Sb),h):0;if(!e&&(l=u=g)&&t.lineStart(),g!==u&&(!(d=a(e,p))||Ek(e,d)||Ek(p,d))&&(p[2]=1),g!==u)c=0,g?(t.lineStart(),d=a(p,e),t.point(d[0],d[1])):(d=a(e,p),t.point(d[0],d[1],2),t.lineEnd()),e=d;else if(i&&e&&r^g){var y;m&n||!(y=a(p,e,!0))||(c=0,r?(t.lineStart(),t.point(y[0][0],y[0][1]),t.point(y[1][0],y[1][1]),t.lineEnd()):(t.point(y[1][0],y[1][1]),t.lineEnd(),t.lineStart(),t.point(y[0][0],y[0][1],3)))}!g||e&&Ek(e,p)||t.point(p[0],p[1]),e=p,u=g,n=m},lineEnd:function(){u&&t.lineEnd(),e=null},clean:function(){return c|(l&&u)<<1}}}),(function(e,r,i,o){!function(t,e,n,r,i,o){if(n){var a=Lb(e),s=Hb(e),u=r*n;null==i?(i=e+r*Bb,o=e-u/2):(i=Ak(a,i),o=Ak(a,o),(r>0?io)&&(i+=r*Bb));for(var l,c=i;r>0?c>o:c0)do{l.point(0===c||3===c?t:n,c>1?r:e)}while((c=(c+s+4)%4)!==f);else l.point(o[0],o[1])}function a(r,i){return Ob(r[0]-t)0?0:3:Ob(r[0]-n)0?2:1:Ob(r[1]-e)0?1:0:i>0?3:2}function s(t,e){return u(t.x,e.x)}function u(t,e){var n=a(t,1),r=a(e,1);return n!==r?n-r:0===n?e[1]-t[1]:1===n?t[0]-e[0]:2===n?t[1]-e[1]:e[0]-t[0]}return function(a){var u,l,c,f,h,d,p,g,m,y,v,_=a,x=Mk(),b={point:w,lineStart:function(){b.point=k,l&&l.push(c=[]);y=!0,m=!1,p=g=NaN},lineEnd:function(){u&&(k(f,h),d&&m&&x.rejoin(),u.push(x.result()));b.point=w,m&&_.lineEnd()},polygonStart:function(){_=x,u=[],l=[],v=!0},polygonEnd:function(){var e=function(){for(var e=0,n=0,i=l.length;nr&&(h-o)*(r-a)>(d-a)*(t-o)&&++e:d<=r&&(h-o)*(r-a)<(d-a)*(t-o)&&--e;return e}(),n=v&&e,i=(u=Fe(u)).length;(n||i)&&(a.polygonStart(),n&&(a.lineStart(),o(null,null,1,a),a.lineEnd()),i&&Ck(u,s,e,o,a),a.polygonEnd());_=a,u=l=c=null}};function w(t,e){i(t,e)&&_.point(t,e)}function k(o,a){var s=i(o,a);if(l&&c.push([o,a]),y)f=o,h=a,d=s,y=!1,s&&(_.lineStart(),_.point(o,a));else if(s&&m)_.point(o,a);else{var 
u=[p=Math.max(Rk,Math.min(Ok,p)),g=Math.max(Rk,Math.min(Ok,g))],x=[o=Math.max(Rk,Math.min(Ok,o)),a=Math.max(Rk,Math.min(Ok,a))];!function(t,e,n,r,i,o){var a,s=t[0],u=t[1],l=0,c=1,f=e[0]-s,h=e[1]-u;if(a=n-s,f||!(a>0)){if(a/=f,f<0){if(a0){if(a>c)return;a>l&&(l=a)}if(a=i-s,f||!(a<0)){if(a/=f,f<0){if(a>c)return;a>l&&(l=a)}else if(f>0){if(a0)){if(a/=h,h<0){if(a0){if(a>c)return;a>l&&(l=a)}if(a=o-u,h||!(a<0)){if(a/=h,h<0){if(a>c)return;a>l&&(l=a)}else if(h>0){if(a0&&(t[0]=s+l*f,t[1]=u+l*h),c<1&&(e[0]=s+c*f,e[1]=u+c*h),!0}}}}}(u,x,t,e,n,r)?s&&(_.lineStart(),_.point(o,a),v=!1):(m||(_.lineStart(),_.point(u[0],u[1])),_.point(x[0],x[1]),s||_.lineEnd(),v=!1)}p=o,g=a,m=s}return b}}function Lk(t,e,n){var r=Se(t,e-Cb,n).concat(e);return function(t){return r.map((function(e){return[t,e]}))}}function qk(t,e,n){var r=Se(t,e-Cb,n).concat(e);return function(t){return r.map((function(e){return[e,t]}))}}var Pk,jk,Ik,Wk,Hk=t=>t,Yk=new se,Gk=new se,Vk={point:Zb,lineStart:Zb,lineEnd:Zb,polygonStart:function(){Vk.lineStart=Xk,Vk.lineEnd=Qk},polygonEnd:function(){Vk.lineStart=Vk.lineEnd=Vk.point=Zb,Yk.add(Ob(Gk)),Gk=new se},result:function(){var t=Yk/2;return Yk=new se,t}};function Xk(){Vk.point=Jk}function Jk(t,e){Vk.point=Zk,Pk=Ik=t,jk=Wk=e}function Zk(t,e){Gk.add(Wk*t-Ik*e),Ik=t,Wk=e}function Qk(){Zk(Pk,jk)}var Kk=1/0,tA=Kk,eA=-Kk,nA=eA,rA={point:function(t,e){teA&&(eA=t);enA&&(nA=e)},lineStart:Zb,lineEnd:Zb,polygonStart:Zb,polygonEnd:Zb,result:function(){var t=[[Kk,tA],[eA,nA]];return eA=nA=-(tA=Kk=1/0),t}};var iA,oA,aA,sA,uA=0,lA=0,cA=0,fA=0,hA=0,dA=0,pA=0,gA=0,mA=0,yA={point:vA,lineStart:_A,lineEnd:wA,polygonStart:function(){yA.lineStart=kA,yA.lineEnd=AA},polygonEnd:function(){yA.point=vA,yA.lineStart=_A,yA.lineEnd=wA},result:function(){var t=mA?[pA/mA,gA/mA]:dA?[fA/dA,hA/dA]:cA?[uA/cA,lA/cA]:[NaN,NaN];return uA=lA=cA=fA=hA=dA=pA=gA=mA=0,t}};function vA(t,e){uA+=t,lA+=e,++cA}function _A(){yA.point=xA}function xA(t,e){yA.point=bA,vA(aA=t,sA=e)}function bA(t,e){var 
n=t-aA,r=e-sA,i=Gb(n*n+r*r);fA+=i*(aA+t)/2,hA+=i*(sA+e)/2,dA+=i,vA(aA=t,sA=e)}function wA(){yA.point=vA}function kA(){yA.point=MA}function AA(){EA(iA,oA)}function MA(t,e){yA.point=EA,vA(iA=aA=t,oA=sA=e)}function EA(t,e){var n=t-aA,r=e-sA,i=Gb(n*n+r*r);fA+=i*(aA+t)/2,hA+=i*(sA+e)/2,dA+=i,pA+=(i=sA*t-aA*e)*(aA+t),gA+=i*(sA+e),mA+=3*i,vA(aA=t,sA=e)}function DA(t){this._context=t}DA.prototype={_radius:4.5,pointRadius:function(t){return this._radius=t,this},polygonStart:function(){this._line=0},polygonEnd:function(){this._line=NaN},lineStart:function(){this._point=0},lineEnd:function(){0===this._line&&this._context.closePath(),this._point=NaN},point:function(t,e){switch(this._point){case 0:this._context.moveTo(t,e),this._point=1;break;case 1:this._context.lineTo(t,e);break;default:this._context.moveTo(t+this._radius,e),this._context.arc(t,e,this._radius,0,Bb)}},result:Zb};var CA,FA,SA,$A,TA,BA=new se,zA={point:Zb,lineStart:function(){zA.point=NA},lineEnd:function(){CA&&OA(FA,SA),zA.point=Zb},polygonStart:function(){CA=!0},polygonEnd:function(){CA=null},result:function(){var t=+BA;return BA=new se,t}};function NA(t,e){zA.point=OA,FA=$A=t,SA=TA=e}function OA(t,e){$A-=t,TA-=e,BA.add(Gb($A*$A+TA*TA)),$A=t,TA=e}let RA,UA,LA,qA;class PA{constructor(t){this._append=null==t?jA:function(t){const e=Math.floor(t);if(!(e>=0))throw new RangeError(`invalid digits: ${t}`);if(e>15)return jA;if(e!==RA){const t=10**e;RA=e,UA=function(e){let n=1;this._+=e[0];for(const r=e.length;n=0))throw new RangeError(`invalid digits: ${t}`);i=e}return null===e&&(r=new PA(i)),a},a.projection(t).digits(i).context(e)}function WA(t){return function(e){var n=new HA;for(var r in t)n[r]=t[r];return n.stream=e,n}}function HA(){}function YA(t,e,n){var r=t.clipExtent&&t.clipExtent();return t.scale(150).translate([0,0]),null!=r&&t.clipExtent(null),rw(n,t.stream(rA)),e(rA.result()),null!=r&&t.clipExtent(r),t}function GA(t,e,n){return YA(t,(function(n){var 
r=e[1][0]-e[0][0],i=e[1][1]-e[0][1],o=Math.min(r/(n[1][0]-n[0][0]),i/(n[1][1]-n[0][1])),a=+e[0][0]+(r-o*(n[1][0]+n[0][0]))/2,s=+e[0][1]+(i-o*(n[1][1]+n[0][1]))/2;t.scale(150*o).translate([a,s])}),n)}function VA(t,e,n){return GA(t,[[0,0],e],n)}function XA(t,e,n){return YA(t,(function(n){var r=+e,i=r/(n[1][0]-n[0][0]),o=(r-i*(n[1][0]+n[0][0]))/2,a=-i*n[0][1];t.scale(150*i).translate([o,a])}),n)}function JA(t,e,n){return YA(t,(function(n){var r=+e,i=r/(n[1][1]-n[0][1]),o=-i*n[0][0],a=(r-i*(n[1][1]+n[0][1]))/2;t.scale(150*i).translate([o,a])}),n)}HA.prototype={constructor:HA,point:function(t,e){this.stream.point(t,e)},sphere:function(){this.stream.sphere()},lineStart:function(){this.stream.lineStart()},lineEnd:function(){this.stream.lineEnd()},polygonStart:function(){this.stream.polygonStart()},polygonEnd:function(){this.stream.polygonEnd()}};var ZA=16,QA=Lb(30*Nb);function KA(t,e){return+e?function(t,e){function n(r,i,o,a,s,u,l,c,f,h,d,p,g,m){var y=l-r,v=c-i,_=y*y+v*v;if(_>4*e&&g--){var x=a+h,b=s+d,w=u+p,k=Gb(x*x+b*b+w*w),A=Jb(w/=k),M=Ob(Ob(w)-1)e||Ob((y*F+v*S)/_-.5)>.3||a*h+s*d+u*p2?t[2]%360*Nb:0,F()):[m*zb,y*zb,v*zb]},D.angle=function(t){return arguments.length?(_=t%360*Nb,F()):_*zb},D.reflectX=function(t){return arguments.length?(x=t?-1:1,F()):x<0},D.reflectY=function(t){return arguments.length?(b=t?-1:1,F()):b<0},D.precision=function(t){return arguments.length?(a=KA(s,E=t*t),S()):Gb(E)},D.fitExtent=function(t,e){return GA(D,t,e)},D.fitSize=function(t,e){return VA(D,t,e)},D.fitWidth=function(t,e){return XA(D,t,e)},D.fitHeight=function(t,e){return JA(D,t,e)},function(){return e=t.apply(this,arguments),D.invert=e.invert&&C,F()}}function iM(t){var e=0,n=Sb/3,r=rM(t),i=r(e,n);return i.parallels=function(t){return arguments.length?r(e=t[0]*Nb,n=t[1]*Nb):[e*zb,n*zb]},i}function oM(t,e){var n=Hb(t),r=(n+Hb(e))/2;if(Ob(r)2?t[2]*Nb:0),e.invert=function(e){return(e=t.invert(e[0]*Nb,e[1]*Nb))[0]*=zb,e[1]*=zb,e},e}(i.rotate()).invert([0,0]));return 
u(null==l?[[s[0]-o,s[1]-o],[s[0]+o,s[1]+o]]:t===hM?[[Math.max(s[0]-o,l),e],[Math.min(s[0]+o,n),r]]:[[l,Math.max(s[1]-o,e)],[n,Math.min(s[1]+o,r)]])}return i.scale=function(t){return arguments.length?(a(t),c()):a()},i.translate=function(t){return arguments.length?(s(t),c()):s()},i.center=function(t){return arguments.length?(o(t),c()):o()},i.clipExtent=function(t){return arguments.length?(null==t?l=e=n=r=null:(l=+t[0][0],e=+t[0][1],n=+t[1][0],r=+t[1][1]),c()):null==l?null:[[l,e],[n,r]]},c()}function pM(t){return Vb(($b+t)/2)}function gM(t,e){var n=Lb(t),r=t===e?Hb(t):Ib(n/Lb(e))/Ib(pM(e)/pM(t)),i=n*Wb(pM(t),r)/r;if(!r)return hM;function o(t,e){i>0?e<-$b+Cb&&(e=-$b+Cb):e>$b-Cb&&(e=$b-Cb);var n=i/Wb(pM(e),r);return[n*Hb(r*t),i-n*Lb(r*t)]}return o.invert=function(t,e){var n=i-e,o=Yb(r)*Gb(t*t+n*n),a=Ub(t,Ob(n))*Yb(n);return n*r<0&&(a-=Sb*Yb(t)*Yb(n)),[a/r,2*Rb(Wb(i/o,1/r))-$b]},o}function mM(t,e){return[t,e]}function yM(t,e){var n=Lb(t),r=t===e?Hb(t):(n-Lb(e))/(e-t),i=n/r+t;if(Ob(r)Cb&&--i>0);return[t/(.8707+(o=r*r)*(o*(o*o*o*(.003971-.001529*o)-.013791)-.131979)),r]},EM.invert=lM(Jb),DM.invert=lM((function(t){return 2*Rb(t)})),CM.invert=function(t,e){return[-e,2*Rb(Pb(t))-$b]};var FM=Math.abs,SM=Math.cos,$M=Math.sin,TM=1e-6,BM=Math.PI,zM=BM/2,NM=function(t){return t>0?Math.sqrt(t):0}(2);function OM(t){return t>1?zM:t<-1?-zM:Math.asin(t)}function RM(t,e){var n,r=t*$M(e),i=30;do{e-=n=(e+$M(e)-r)/(1+SM(e))}while(FM(n)>TM&&--i>0);return e/2}var UM=function(t,e,n){function r(r,i){return[t*r*SM(i=RM(n,i)),e*$M(i)]}return r.invert=function(r,i){return i=OM(i/e),[r/(t*SM(i)),OM((2*i+$M(2*i))/n)]},r}(NM/zM,NM,BM);const LM=IA(),qM=["clipAngle","clipExtent","scale","translate","center","rotate","parallels","precision","reflectX","reflectY","coefficient","distance","fraction","lobes","parallel","radius","ratio","spacing","tilt"];function PM(t,e){if(!t||"string"!=typeof t)throw new Error("Projection type must be a name string.");return 
t=t.toLowerCase(),arguments.length>1?(IM[t]=function(t,e){return function n(){const r=e();return r.type=t,r.path=IA().projection(r),r.copy=r.copy||function(){const t=n();return qM.forEach((e=>{r[e]&&t[e](r[e]())})),t.path.pointRadius(r.path.pointRadius()),t},Vd(r)}}(t,e),this):IM[t]||null}function jM(t){return t&&t.path||LM}const IM={albers:sM,albersusa:function(){var t,e,n,r,i,o,a=sM(),s=aM().rotate([154,0]).center([-2,58.5]).parallels([55,65]),u=aM().rotate([157,0]).center([-3,19.9]).parallels([8,18]),l={point:function(t,e){o=[t,e]}};function c(t){var e=t[0],a=t[1];return o=null,n.point(e,a),o||(r.point(e,a),o)||(i.point(e,a),o)}function f(){return t=e=null,c}return c.invert=function(t){var e=a.scale(),n=a.translate(),r=(t[0]-n[0])/e,i=(t[1]-n[1])/e;return(i>=.12&&i<.234&&r>=-.425&&r<-.214?s:i>=.166&&i<.234&&r>=-.214&&r<-.115?u:a).invert(t)},c.stream=function(n){return t&&e===n?t:(r=[a.stream(e=n),s.stream(n),u.stream(n)],i=r.length,t={point:function(t,e){for(var n=-1;++n2?t[2]+90:90]):[(t=n())[0],t[1],t[2]-90]},n([0,0,90]).scale(159.155)}};for(const t in IM)PM(t,IM[t]);function WM(){}const HM=[[],[[[1,1.5],[.5,1]]],[[[1.5,1],[1,1.5]]],[[[1.5,1],[.5,1]]],[[[1,.5],[1.5,1]]],[[[1,1.5],[.5,1]],[[1,.5],[1.5,1]]],[[[1,.5],[1,1.5]]],[[[1,.5],[.5,1]]],[[[.5,1],[1,.5]]],[[[1,1.5],[1,.5]]],[[[.5,1],[1,.5]],[[1.5,1],[1,1.5]]],[[[1.5,1],[1,.5]]],[[[.5,1],[1.5,1]]],[[[1,1.5],[1.5,1]]],[[[.5,1],[1,1.5]]],[]];function YM(){var t=1,e=1,n=a;function r(t,e){return e.map((e=>i(t,e)))}function i(r,i){var a=[],s=[];return function(n,r,i){var a,s,u,l,c,f,h=new Array,d=new Array;a=s=-1,l=n[0]>=r,HM[l<<1].forEach(p);for(;++a=r,HM[u|l<<1].forEach(p);HM[l<<0].forEach(p);for(;++s=r,c=n[s*t]>=r,HM[l<<1|c<<2].forEach(p);++a=r,f=c,c=n[s*t+a+1]>=r,HM[u|l<<1|c<<2|f<<3].forEach(p);HM[l|c<<3].forEach(p)}a=-1,c=n[s*t]>=r,HM[c<<2].forEach(p);for(;++a=r,HM[c<<2|f<<3].forEach(p);function p(t){var e,n,r=[t[0][0]+a,t[0][1]+s],u=[t[1][0]+a,t[1][1]+s],l=o(r),c=o(u);(e=d[l])?(n=h[c])?(delete 
d[e.end],delete h[n.start],e===n?(e.ring.push(u),i(e.ring)):h[e.start]=d[n.end]={start:e.start,end:n.end,ring:e.ring.concat(n.ring)}):(delete d[e.end],e.ring.push(u),d[e.end=c]=e):(e=h[c])?(n=d[l])?(delete h[e.start],delete d[n.end],e===n?(e.ring.push(u),i(e.ring)):h[n.start]=d[e.end]={start:n.start,end:e.end,ring:n.ring.concat(e.ring)}):(delete h[e.start],e.ring.unshift(r),h[e.start=l]=e):h[l]=d[c]={start:l,end:c,ring:[r,u]}}HM[c<<3].forEach(p)}(r,i,(t=>{n(t,r,i),function(t){var e=0,n=t.length,r=t[n-1][1]*t[0][0]-t[n-1][0]*t[0][1];for(;++e0?a.push([t]):s.push(t)})),s.forEach((t=>{for(var e,n=0,r=a.length;n{var o,a=n[0],s=n[1],u=0|a,l=0|s,c=r[l*t+u];a>0&&a0&&s=0&&o>=0||s("invalid size"),t=i,e=o,r},r.smooth=function(t){return arguments.length?(n=t?a:WM,r):n===a},r}function GM(t,e){for(var n,r=-1,i=e.length;++rr!=d>r&&n<(h-l)*(r-c)/(d-c)+l&&(i=-i)}return i}function XM(t,e,n){var r,i,o,a;return function(t,e,n){return(e[0]-t[0])*(n[1]-t[1])==(n[0]-t[0])*(e[1]-t[1])}(t,e,n)&&(i=t[r=+(t[0]===e[0])],o=n[r],a=e[r],i<=o&&o<=a||a<=o&&o<=i)}function JM(t,e,n){return function(r){var i=at(r),o=n?Math.min(i[0],0):i[0],a=i[1],s=a-o,u=e?be(o,a,t):s/(t+1);return Se(o+u,a,u)}}function ZM(t){Ja.call(this,null,t)}function QM(t,e,n,r,i){const o=t.x1||0,a=t.y1||0,s=e*n<0;function u(t){t.forEach(l)}function l(t){s&&t.reverse(),t.forEach(c)}function c(t){t[0]=(t[0]-o)*e+r,t[1]=(t[1]-a)*n+i}return function(t){return t.coordinates.forEach(u),t}}function KM(t,e,n){const r=t>=0?t:rs(e,n);return Math.round((Math.sqrt(4*r*r+1)-1)/2)}function tE(t){return J(t)?t:rt(+t)}function eE(){var t=t=>t[0],e=t=>t[1],n=d,r=[-1,-1],i=960,o=500,a=2;function u(s,u){const l=KM(r[0],s,t)>>a,c=KM(r[1],s,e)>>a,f=l?l+2:0,h=c?c+2:0,d=2*f+(i>>a),p=2*h+(o>>a),g=new Float32Array(d*p),m=new Float32Array(d*p);let y=g;s.forEach((r=>{const 
i=f+(+t(r)>>a),o=h+(+e(r)>>a);i>=0&&i=0&&o0&&c>0?(nE(d,p,g,m,l),rE(d,p,m,g,c),nE(d,p,g,m,l),rE(d,p,m,g,c),nE(d,p,g,m,l),rE(d,p,m,g,c)):l>0?(nE(d,p,g,m,l),nE(d,p,m,g,l),nE(d,p,g,m,l),y=m):c>0&&(rE(d,p,g,m,c),rE(d,p,m,g,c),rE(d,p,g,m,c),y=m);const v=u?Math.pow(2,-2*a):1/$e(y);for(let t=0,e=d*p;t>a),y2:h+(o>>a)}}return u.x=function(e){return arguments.length?(t=tE(e),u):t},u.y=function(t){return arguments.length?(e=tE(t),u):e},u.weight=function(t){return arguments.length?(n=tE(t),u):n},u.size=function(t){if(!arguments.length)return[i,o];var e=+t[0],n=+t[1];return e>=0&&n>=0||s("invalid size"),i=e,o=n,u},u.cellSize=function(t){return arguments.length?((t=+t)>=1||s("invalid cell size"),a=Math.floor(Math.log(t)/Math.LN2),u):1<=i&&(e>=o&&(s-=n[e-o+a*t]),r[e-i+a*t]=s/Math.min(e+1,t-1+o-e,o))}function rE(t,e,n,r,i){const o=1+(i<<1);for(let a=0;a=i&&(s>=o&&(u-=n[a+(s-o)*t]),r[a+(s-i)*t]=u/Math.min(s+1,e-1+o-s,o))}function iE(t){Ja.call(this,null,t)}ZM.Definition={type:"Isocontour",metadata:{generates:!0},params:[{name:"field",type:"field"},{name:"thresholds",type:"number",array:!0},{name:"levels",type:"number"},{name:"nice",type:"boolean",default:!1},{name:"resolve",type:"enum",values:["shared","independent"],default:"independent"},{name:"zero",type:"boolean",default:!0},{name:"smooth",type:"boolean",default:!0},{name:"scale",type:"number",expr:!0},{name:"translate",type:"number",array:!0,expr:!0},{name:"as",type:"string",null:!0,default:"contour"}]},dt(ZM,Ja,{transform(t,e){if(this.value&&!e.changed()&&!t.modified())return e.StopPropagation;var n=e.fork(e.NO_SOURCE|e.NO_FIELDS),r=e.materialize(e.SOURCE).source,i=t.field||f,o=YM().smooth(!1!==t.smooth),a=t.thresholds||function(t,e,n){const r=JM(n.levels||10,n.nice,!1!==n.zero);return"shared"!==n.resolve?r:r(t.map((t=>we(e(t).values))))}(r,i,t),s=null===t.as?null:t.as||"contour",u=[];return r.forEach((e=>{const n=i(e),r=o.size([n.width,n.height])(n.values,k(a)?a:a(n.values));!function(t,e,n,r){let 
i=r.scale||e.scale,o=r.translate||e.translate;J(i)&&(i=i(n,r));J(o)&&(o=o(n,r));if((1===i||null==i)&&!o)return;const a=(vt(i)?i:i[0])||1,s=(vt(i)?i:i[1])||1,u=o&&o[0]||0,l=o&&o[1]||0;t.forEach(QM(e,a,s,u,l))}(r,n,e,t),r.forEach((t=>{u.push(ba(e,_a(null!=s?{[s]:t}:t)))}))})),this.value&&(n.rem=this.value),this.value=n.source=n.add=u,n}}),iE.Definition={type:"KDE2D",metadata:{generates:!0},params:[{name:"size",type:"number",array:!0,length:2,required:!0},{name:"x",type:"field",required:!0},{name:"y",type:"field",required:!0},{name:"weight",type:"field"},{name:"groupby",type:"field",array:!0},{name:"cellSize",type:"number"},{name:"bandwidth",type:"number",array:!0,length:2},{name:"counts",type:"boolean",default:!1},{name:"as",type:"string",default:"grid"}]};const oE=["x","y","weight","size","cellSize","bandwidth"];function aE(t,e){return oE.forEach((n=>null!=e[n]?t[n](e[n]):0)),t}function sE(t){Ja.call(this,null,t)}dt(iE,Ja,{transform(t,e){if(this.value&&!e.changed()&&!t.modified())return e.StopPropagation;var r,i=e.fork(e.NO_SOURCE|e.NO_FIELDS),o=function(t,e){var n,r,i,o,a,s,u=[],l=t=>t(o);if(null==e)u.push(t);else for(n={},r=0,i=t.length;r_a(function(t,e){for(let n=0;nCb})).map(u)).concat(Se(qb(o/d)*d,i,d).filter((function(t){return Ob(t%g)>Cb})).map(l))}return y.lines=function(){return v().map((function(t){return{type:"LineString",coordinates:t}}))},y.outline=function(){return{type:"Polygon",coordinates:[c(r).concat(f(a).slice(1),c(n).reverse().slice(1),f(s).reverse().slice(1))]}},y.extent=function(t){return arguments.length?y.extentMajor(t).extentMinor(t):y.extentMinor()},y.extentMajor=function(t){return arguments.length?(r=+t[0][0],n=+t[1][0],s=+t[0][1],a=+t[1][1],r>n&&(t=r,r=n,n=t),s>a&&(t=s,s=a,a=t),y.precision(m)):[[r,s],[n,a]]},y.extentMinor=function(n){return arguments.length?(e=+n[0][0],t=+n[1][0],o=+n[0][1],i=+n[1][1],e>t&&(n=e,e=t,t=n),o>i&&(n=o,o=i,i=n),y.precision(m)):[[e,o],[t,i]]},y.step=function(t){return 
arguments.length?y.stepMajor(t).stepMinor(t):y.stepMinor()},y.stepMajor=function(t){return arguments.length?(p=+t[0],g=+t[1],y):[p,g]},y.stepMinor=function(t){return arguments.length?(h=+t[0],d=+t[1],y):[h,d]},y.precision=function(h){return arguments.length?(m=+h,u=Lk(o,i,90),l=qk(e,t,m),c=Lk(s,a,90),f=qk(r,n,m),y):m},y.extentMajor([[-180,-90+Cb],[180,90-Cb]]).extentMinor([[-180,-80-Cb],[180,80+Cb]])}()}function gE(t){Ja.call(this,null,t)}function mE(t){if(!J(t))return!1;const e=Bt(r(t));return e.$x||e.$y||e.$value||e.$max}function yE(t){Ja.call(this,null,t),this.modified(!0)}function vE(t,e,n){J(t[e])&&t[e](n)}cE.Definition={type:"GeoJSON",metadata:{},params:[{name:"fields",type:"field",array:!0,length:2},{name:"geojson",type:"field"}]},dt(cE,Ja,{transform(t,e){var n,i=this._features,o=this._points,a=t.fields,s=a&&a[0],u=a&&a[1],l=t.geojson||!a&&f,c=e.ADD;n=t.modified()||e.changed(e.REM)||e.modified(r(l))||s&&e.modified(r(s))||u&&e.modified(r(u)),this.value&&!n||(c=e.SOURCE,this._features=i=[],this._points=o=[]),l&&e.visit(c,(t=>i.push(l(t)))),s&&u&&(e.visit(c,(t=>{var e=s(t),n=u(t);null!=e&&null!=n&&(e=+e)===e&&(n=+n)===n&&o.push([e,n])})),i=i.concat({type:uE,geometry:{type:"MultiPoint",coordinates:o}})),this.value={type:lE,features:i}}}),fE.Definition={type:"GeoPath",metadata:{modifies:!0},params:[{name:"projection",type:"projection"},{name:"field",type:"field"},{name:"pointRadius",type:"number",expr:!0},{name:"as",type:"string",default:"path"}]},dt(fE,Ja,{transform(t,e){var n=e.fork(e.ALL),r=this.value,i=t.field||f,o=t.as||"path",a=n.SOURCE;!r||t.modified()?(this.value=r=jM(t.projection),n.materialize().reflow()):a=i===f||e.modified(i.fields)?n.ADD_MOD:n.ADD;const s=function(t,e){const n=t.pointRadius();t.context(null),null!=e&&t.pointRadius(e);return n}(r,t.pointRadius);return 
n.visit(a,(t=>t[o]=r(i(t)))),r.pointRadius(s),n.modifies(o)}}),hE.Definition={type:"GeoPoint",metadata:{modifies:!0},params:[{name:"projection",type:"projection",required:!0},{name:"fields",type:"field",array:!0,required:!0,length:2},{name:"as",type:"string",array:!0,length:2,default:["x","y"]}]},dt(hE,Ja,{transform(t,e){var n,r=t.projection,i=t.fields[0],o=t.fields[1],a=t.as||["x","y"],s=a[0],u=a[1];function l(t){const e=r([i(t),o(t)]);e?(t[s]=e[0],t[u]=e[1]):(t[s]=void 0,t[u]=void 0)}return t.modified()?e=e.materialize().reflow(!0).visit(e.SOURCE,l):(n=e.modified(i.fields)||e.modified(o.fields),e.visit(n?e.ADD_MOD:e.ADD,l)),e.modifies(a)}}),dE.Definition={type:"GeoShape",metadata:{modifies:!0,nomod:!0},params:[{name:"projection",type:"projection"},{name:"field",type:"field",default:"datum"},{name:"pointRadius",type:"number",expr:!0},{name:"as",type:"string",default:"shape"}]},dt(dE,Ja,{transform(t,e){var n=e.fork(e.ALL),r=this.value,i=t.as||"shape",o=n.ADD;return r&&!t.modified()||(this.value=r=function(t,e,n){const r=null==n?n=>t(e(n)):r=>{var i=t.pointRadius(),o=t.pointRadius(n)(e(r));return t.pointRadius(i),o};return r.context=e=>(t.context(e),r),r}(jM(t.projection),t.field||l("datum"),t.pointRadius),n.materialize().reflow(),o=n.SOURCE),n.visit(o,(t=>t[i]=r)),n.modifies(i)}}),pE.Definition={type:"Graticule",metadata:{changes:!0,generates:!0},params:[{name:"extent",type:"array",array:!0,length:2,content:{type:"number",array:!0,length:2}},{name:"extentMajor",type:"array",array:!0,length:2,content:{type:"number",array:!0,length:2}},{name:"extentMinor",type:"array",array:!0,length:2,content:{type:"number",array:!0,length:2}},{name:"step",type:"number",array:!0,length:2},{name:"stepMajor",type:"number",array:!0,length:2,default:[90,360]},{name:"stepMinor",type:"number",array:!0,length:2,default:[10,10]},{name:"precision",type:"number",default:2.5}]},dt(pE,Ja,{transform(t,e){var n,r=this.value,i=this.generator;if(!r.length||t.modified())for(const e in 
t)J(i[e])&&i[e](t[e]);return n=i(),r.length?e.mod.push(wa(r[0],n)):e.add.push(_a(n)),r[0]=n,e}}),gE.Definition={type:"heatmap",metadata:{modifies:!0},params:[{name:"field",type:"field"},{name:"color",type:"string",expr:!0},{name:"opacity",type:"number",expr:!0},{name:"resolve",type:"enum",values:["shared","independent"],default:"independent"},{name:"as",type:"string",default:"image"}]},dt(gE,Ja,{transform(t,e){if(!e.changed()&&!t.modified())return e.StopPropagation;var n=e.materialize(e.SOURCE).source,r="shared"===t.resolve,i=t.field||f,o=function(t,e){let n;J(t)?(n=n=>t(n,e),n.dep=mE(t)):t?n=rt(t):(n=t=>t.$value/t.$max||0,n.dep=!0);return n}(t.opacity,t),a=function(t,e){let n;J(t)?(n=n=>af(t(n,e)),n.dep=mE(t)):n=rt(af(t||"#888"));return n}(t.color,t),s=t.as||"image",u={$x:0,$y:0,$value:0,$max:r?we(n.map((t=>we(i(t).values)))):0};return n.forEach((t=>{const e=i(t),n=ot({},t,u);r||(n.$max=we(e.values||[])),t[s]=function(t,e,n,r){const i=t.width,o=t.height,a=t.x1||0,s=t.y1||0,u=t.x2||i,l=t.y2||o,c=t.values,f=c?t=>c[t]:h,d=$c(u-a,l-s),p=d.getContext("2d"),g=p.getImageData(0,0,u-a,l-s),m=g.data;for(let t=s,o=0;t{null!=t[e]&&vE(n,e,t[e])}))):qM.forEach((e=>{t.modified(e)&&vE(n,e,t[e])})),null!=t.pointRadius&&n.path.pointRadius(t.pointRadius),t.fit&&function(t,e){const n=function(t){return t=V(t),1===t.length?t[0]:{type:lE,features:t.reduce(((t,e)=>t.concat(function(t){return t.type===lE?t.features:V(t).filter((t=>null!=t)).map((t=>t.type===uE?t:{type:uE,geometry:t}))}(e))),[])}}(e.fit);e.extent?t.fitExtent(e.extent,n):e.size&&t.fitSize(e.size,n)}(n,t),e.fork(e.NO_SOURCE|e.NO_FIELDS)}});var _E=Object.freeze({__proto__:null,contour:sE,geojson:cE,geopath:fE,geopoint:hE,geoshape:dE,graticule:pE,heatmap:gE,isocontour:ZM,kde2d:iE,projection:yE});function xE(t,e,n,r){if(isNaN(e)||isNaN(n))return t;var i,o,a,s,u,l,c,f,h,d=t._root,p={data:r},g=t._x0,m=t._y0,y=t._x1,v=t._y1;if(!d)return 
t._root=p,t;for(;d.length;)if((l=e>=(o=(g+y)/2))?g=o:y=o,(c=n>=(a=(m+v)/2))?m=a:v=a,i=d,!(d=d[f=c<<1|l]))return i[f]=p,t;if(s=+t._x.call(null,d.data),u=+t._y.call(null,d.data),e===s&&n===u)return p.next=d,i?i[f]=p:t._root=p,t;do{i=i?i[f]=new Array(4):t._root=new Array(4),(l=e>=(o=(g+y)/2))?g=o:y=o,(c=n>=(a=(m+v)/2))?m=a:v=a}while((f=c<<1|l)==(h=(u>=a)<<1|s>=o));return i[h]=d,i[f]=p,t}function bE(t,e,n,r,i){this.node=t,this.x0=e,this.y0=n,this.x1=r,this.y1=i}function wE(t){return t[0]}function kE(t){return t[1]}function AE(t,e,n){var r=new ME(null==e?wE:e,null==n?kE:n,NaN,NaN,NaN,NaN);return null==t?r:r.addAll(t)}function ME(t,e,n,r,i,o){this._x=t,this._y=e,this._x0=n,this._y0=r,this._x1=i,this._y1=o,this._root=void 0}function EE(t){for(var e={data:t.data},n=e;t=t.next;)n=n.next={data:t.data};return e}var DE=AE.prototype=ME.prototype;function CE(t){return function(){return t}}function FE(t){return 1e-6*(t()-.5)}function SE(t){return t.x+t.vx}function $E(t){return t.y+t.vy}function TE(t){return t.index}function BE(t,e){var n=t.get(e);if(!n)throw new Error("node not found: "+e);return n}DE.copy=function(){var t,e,n=new ME(this._x,this._y,this._x0,this._y0,this._x1,this._y1),r=this._root;if(!r)return n;if(!r.length)return n._root=EE(r),n;for(t=[{source:r,target:n._root=new Array(4)}];r=t.pop();)for(var i=0;i<4;++i)(e=r.source[i])&&(e.length?t.push({source:e,target:r.target[i]=new Array(4)}):r.target[i]=EE(e));return n},DE.add=function(t){const e=+this._x.call(null,t),n=+this._y.call(null,t);return xE(this.cover(e,n),e,n,t)},DE.addAll=function(t){var e,n,r,i,o=t.length,a=new Array(o),s=new Array(o),u=1/0,l=1/0,c=-1/0,f=-1/0;for(n=0;nc&&(c=r),if&&(f=i));if(u>c||l>f)return this;for(this.cover(u,l).cover(c,f),n=0;nt||t>=i||r>e||e>=o;)switch(s=(eh||(o=u.y0)>d||(a=u.x1)=y)<<1|t>=m)&&(u=p[p.length-1],p[p.length-1]=p[p.length-1-l],p[p.length-1-l]=u)}else{var 
v=t-+this._x.call(null,g.data),_=e-+this._y.call(null,g.data),x=v*v+_*_;if(x=(s=(p+m)/2))?p=s:m=s,(c=a>=(u=(g+y)/2))?g=u:y=u,e=d,!(d=d[f=c<<1|l]))return this;if(!d.length)break;(e[f+1&3]||e[f+2&3]||e[f+3&3])&&(n=e,h=f)}for(;d.data!==t;)if(r=d,!(d=d.next))return this;return(i=d.next)&&delete d.next,r?(i?r.next=i:delete r.next,this):e?(i?e[f]=i:delete e[f],(d=e[0]||e[1]||e[2]||e[3])&&d===(e[3]||e[2]||e[1]||e[0])&&!d.length&&(n?n[h]=d:this._root=d),this):(this._root=i,this)},DE.removeAll=function(t){for(var e=0,n=t.length;e{}};function NE(){for(var t,e=0,n=arguments.length,r={};e=0&&(e=t.slice(n+1),t=t.slice(0,n)),t&&!r.hasOwnProperty(t))throw new Error("unknown type: "+t);return{type:t,name:e}}))),a=-1,s=o.length;if(!(arguments.length<2)){if(null!=e&&"function"!=typeof e)throw new Error("invalid callback: "+e);for(;++a0)for(var n,r,i=new Array(n),o=0;o=0&&e._call.call(void 0,t),e=e._next;--PE}()}finally{PE=0,function(){var t,e,n=LE,r=1/0;for(;n;)n._call?(r>n._time&&(r=n._time),t=n,n=n._next):(e=n._next,n._next=null,n=t?t._next=e:LE=e);qE=t,nD(r)}(),YE=0}}function eD(){var t=VE.now(),e=t-HE;e>WE&&(GE-=e,HE=t)}function nD(t){PE||(jE&&(jE=clearTimeout(jE)),t-YE>24?(t<1/0&&(jE=setTimeout(tD,t-VE.now()-GE)),IE&&(IE=clearInterval(IE))):(IE||(HE=VE.now(),IE=setInterval(eD,WE)),PE=1,XE(tD)))}QE.prototype=KE.prototype={constructor:QE,restart:function(t,e,n){if("function"!=typeof t)throw new TypeError("callback is not a function");n=(null==n?JE():+n)+(null==e?0:+e),this._next||qE===this||(qE?qE._next=this:LE=this,qE=this),this._call=t,this._time=n,nD()},stop:function(){this._call&&(this._call=null,this._time=1/0,nD())}};const rD=1664525,iD=1013904223,oD=4294967296;function aD(t){return t.x}function sD(t){return t.y}var uD=10,lD=Math.PI*(3-Math.sqrt(5));function cD(t){var e,n=1,r=.001,i=1-Math.pow(r,1/300),o=0,a=.6,s=new Map,u=KE(f),l=NE("tick","end"),c=function(){let t=1;return()=>(t=(rD*t+iD)%oD)/oD}();function 
f(){h(),l.call("tick",e),n1?(null==n?s.delete(t):s.set(t,p(n)),e):s.get(t)},find:function(e,n,r){var i,o,a,s,u,l=0,c=t.length;for(null==r?r=1/0:r*=r,l=0;l1?(l.on(t,n),e):l.on(t)}}}const fD={center:function(t,e){var n,r=1;function i(){var i,o,a=n.length,s=0,u=0;for(i=0;il+p||oc+p||au.index){var g=l-s.x-s.vx,m=c-s.y-s.vy,y=g*g+m*m;yt.r&&(t.r=t[e].r)}function u(){if(e){var r,i,o=e.length;for(n=new Array(o),r=0;r=s)){(t.data!==e||t.next)&&(0===f&&(p+=(f=FE(n))*f),0===h&&(p+=(h=FE(n))*h),p[s(t,e,r),t])));for(a=0,i=new Array(l);ae(t,n):e)}mD.Definition={type:"Force",metadata:{modifies:!0},params:[{name:"static",type:"boolean",default:!1},{name:"restart",type:"boolean",default:!1},{name:"iterations",type:"number",default:300},{name:"alpha",type:"number",default:1},{name:"alphaMin",type:"number",default:.001},{name:"alphaTarget",type:"number",default:0},{name:"velocityDecay",type:"number",default:.4},{name:"forces",type:"param",array:!0,params:[{key:{force:"center"},params:[{name:"x",type:"number",default:0},{name:"y",type:"number",default:0}]},{key:{force:"collide"},params:[{name:"radius",type:"number",expr:!0},{name:"strength",type:"number",default:.7},{name:"iterations",type:"number",default:1}]},{key:{force:"nbody"},params:[{name:"strength",type:"number",default:-30,expr:!0},{name:"theta",type:"number",default:.9},{name:"distanceMin",type:"number",default:1},{name:"distanceMax",type:"number"}]},{key:{force:"link"},params:[{name:"links",type:"data"},{name:"id",type:"field"},{name:"distance",type:"number",default:30,expr:!0},{name:"strength",type:"number",expr:!0},{name:"iterations",type:"number",default:1}]},{key:{force:"x"},params:[{name:"strength",type:"number",default:.1},{name:"x",type:"field"}]},{key:{force:"y"},params:[{name:"strength",type:"number",default:.1},{name:"y",type:"field"}]}]},{name:"as",type:"string",array:!0,modify:!1,default:gD}]},dt(mD,Ja,{transform(t,e){var 
n,r,i=this.value,o=e.changed(e.ADD_REM),a=t.modified(dD),s=t.iterations||300;if(i?(o&&(e.modifies("index"),i.nodes(e.source)),(a||e.changed(e.MOD))&&yD(i,t,0,e)):(this.value=i=function(t,e){const n=cD(t),r=n.stop,i=n.restart;let o=!1;return n.stopped=()=>o,n.restart=()=>(o=!1,i()),n.stop=()=>(o=!0,r()),yD(n,e,!0).on("end",(()=>o=!0))}(e.source,t),i.on("tick",(n=e.dataflow,r=this,()=>n.touch(r).run())),t.static||(o=!0,i.tick()),e.modifies("index")),a||o||t.modified(pD)||e.changed()&&t.restart)if(i.alpha(Math.max(i.alpha(),t.alpha||1)).alphaDecay(1-Math.pow(i.alphaMin(),1/s)),t.static)for(i.stop();--s>=0;)i.tick();else if(i.stopped()&&i.restart(),!o)return e.StopPropagation;return this.finish(t,e)},finish(t,e){const n=e.dataflow;for(let t,e=this._argops,s=0,u=e.length;s=0;)e+=n[r].value;else e=1;t.value=e}function ED(t,e){t instanceof Map?(t=[void 0,t],void 0===e&&(e=CD)):void 0===e&&(e=DD);for(var n,r,i,o,a,s=new $D(t),u=[s];n=u.pop();)if((i=e(n.data))&&(a=(i=Array.from(i)).length))for(n.children=i,o=a-1;o>=0;--o)u.push(r=i[o]=new $D(i[o])),r.parent=n,r.depth=n.depth+1;return s.eachBefore(SD)}function DD(t){return t.children}function CD(t){return Array.isArray(t)?t[1]:null}function FD(t){void 0!==t.data.value&&(t.value=t.data.value),t.data=t.data.data}function SD(t){var e=0;do{t.height=e}while((t=t.parent)&&t.height<++e)}function $D(t){this.data=t,this.depth=this.height=0,this.parent=null}function TD(t){return null==t?null:BD(t)}function BD(t){if("function"!=typeof t)throw new Error;return t}function zD(){return 0}function ND(t){return function(){return t}}$D.prototype=ED.prototype={constructor:$D,count:function(){return this.eachAfter(MD)},each:function(t,e){let n=-1;for(const r of this)t.call(e,r,++n,this);return this},eachAfter:function(t,e){for(var n,r,i,o=this,a=[o],s=[],u=-1;o=a.pop();)if(s.push(o),n=o.children)for(r=0,i=n.length;r=0;--r)o.push(n[r]);return this},find:function(t,e){let n=-1;for(const r of this)if(t.call(e,r,++n,this))return 
r},sum:function(t){return this.eachAfter((function(e){for(var n=+t(e.data)||0,r=e.children,i=r&&r.length;--i>=0;)n+=r[i].value;e.value=n}))},sort:function(t){return this.eachBefore((function(e){e.children&&e.children.sort(t)}))},path:function(t){for(var e=this,n=function(t,e){if(t===e)return t;var n=t.ancestors(),r=e.ancestors(),i=null;t=n.pop(),e=r.pop();for(;t===e;)i=t,t=n.pop(),e=r.pop();return i}(e,t),r=[e];e!==n;)e=e.parent,r.push(e);for(var i=r.length;t!==n;)r.splice(i,0,t),t=t.parent;return r},ancestors:function(){for(var t=this,e=[t];t=t.parent;)e.push(t);return e},descendants:function(){return Array.from(this)},leaves:function(){var t=[];return this.eachBefore((function(e){e.children||t.push(e)})),t},links:function(){var t=this,e=[];return t.each((function(n){n!==t&&e.push({source:n.parent,target:n})})),e},copy:function(){return ED(this).eachBefore(FD)},[Symbol.iterator]:function*(){var t,e,n,r,i=this,o=[i];do{for(t=o.reverse(),o=[];i=t.pop();)if(yield i,e=i.children)for(n=0,r=e.length;n0&&n*n>r*r+i*i}function jD(t,e){for(var n=0;n1e-6?(D+Math.sqrt(D*D-4*E*C))/(2*E):C/D);return{x:r+w+k*F,y:i+A+M*F,r:F}}function YD(t,e,n){var r,i,o,a,s=t.x-e.x,u=t.y-e.y,l=s*s+u*u;l?(i=e.r+n.r,i*=i,a=t.r+n.r,i>(a*=a)?(r=(l+a-i)/(2*l),o=Math.sqrt(Math.max(0,a/l-r*r)),n.x=t.x-r*s-o*u,n.y=t.y-r*u+o*s):(r=(l+i-a)/(2*l),o=Math.sqrt(Math.max(0,i/l-r*r)),n.x=e.x+r*s-o*u,n.y=e.y+r*u+o*s)):(n.x=e.x+n.r,n.y=e.y)}function GD(t,e){var n=t.r+e.r-1e-6,r=e.x-t.x,i=e.y-t.y;return n>0&&n*n>r*r+i*i}function VD(t){var e=t._,n=t.next._,r=e.r+n.r,i=(e.x*n.r+n.x*e.r)/r,o=(e.y*n.r+n.y*e.r)/r;return i*i+o*o}function XD(t){this._=t,this.next=null,this.previous=null}function JD(t,e){if(!(o=(t=function(t){return"object"==typeof t&&"length"in t?t:Array.from(t)}(t)).length))return 0;var n,r,i,o,a,s,u,l,c,f,h;if((n=t[0]).x=0,n.y=0,!(o>1))return n.r;if(r=t[1],n.x=-r.r,r.x=n.r,r.y=0,!(o>2))return n.r+r.r;YD(r,n,i=t[2]),n=new XD(n),r=new XD(r),i=new 
XD(i),n.next=i.previous=r,r.next=n.previous=i,i.next=r.previous=n;t:for(u=3;ufunction(t){t=`${t}`;let e=t.length;cC(t,e-1)&&!cC(t,e-2)&&(t=t.slice(0,-1));return"/"===t[0]?t:`/${t}`}(t(e,n,r)))),n=e.map(lC),i=new Set(e).add("");for(const t of n)i.has(t)||(i.add(t),e.push(t),n.push(lC(t)),h.push(oC));d=(t,n)=>e[n],p=(t,e)=>n[e]}for(a=0,i=h.length;a=0&&(l=h[t]).data===oC;--t)l.data=null}if(s.parent=rC,s.eachBefore((function(t){t.depth=t.parent.depth+1,--i})).eachBefore(SD),s.parent=null,i>0)throw new Error("cycle");return s}return r.id=function(t){return arguments.length?(e=TD(t),r):e},r.parentId=function(t){return arguments.length?(n=TD(t),r):n},r.path=function(e){return arguments.length?(t=TD(e),r):t},r}function lC(t){let e=t.length;if(e<2)return"";for(;--e>1&&!cC(t,e););return t.slice(0,e)}function cC(t,e){if("/"===t[e]){let n=0;for(;e>0&&"\\"===t[--e];)++n;if(0==(1&n))return!0}return!1}function fC(t,e){return t.parent===e.parent?1:2}function hC(t){var e=t.children;return e?e[0]:t.t}function dC(t){var e=t.children;return e?e[e.length-1]:t.t}function pC(t,e,n){var r=n/(e.i-t.i);e.c-=r,e.s+=n,t.c+=r,e.z+=n,e.m+=n}function gC(t,e,n){return t.a.parent===e.parent?t.a:n}function mC(t,e){this._=t,this.parent=null,this.children=null,this.A=null,this.a=this,this.z=0,this.m=0,this.c=0,this.s=0,this.t=null,this.i=e}function yC(t,e,n,r,i){for(var o,a=t.children,s=-1,u=a.length,l=t.value&&(i-n)/t.value;++sh&&(h=s),m=c*c*g,(d=Math.max(h/m,m/f))>p){c-=s;break}p=d}y.push(a={value:c,dice:u1?e:1)},n}(vC);var bC=function t(e){function n(t,n,r,i,o){if((a=t._squarify)&&a.ratio===e)for(var a,s,u,l,c,f=-1,h=a.length,d=t.value;++f1?e:1)},n}(vC);function wC(t,e,n){const r={};return t.each((t=>{const i=t.data;n(i)&&(r[e(i)]=t)})),t.lookup=r,t}function kC(t){Ja.call(this,null,t)}kC.Definition={type:"Nest",metadata:{treesource:!0,changes:!0},params:[{name:"keys",type:"field",array:!0},{name:"generate",type:"boolean"}]};const AC=t=>t.values;function MC(){const 
t=[],e={entries:t=>r(n(t,0),0),key:n=>(t.push(n),e)};function n(e,r){if(r>=t.length)return e;const i=e.length,o=t[r++],a={},s={};let u,l,c,f=-1;for(;++ft.length)return e;const i=[];for(const t in e)i.push({key:t,values:r(e[t],n)});return i}return e}function EC(t){Ja.call(this,null,t)}dt(kC,Ja,{transform(t,e){e.source||s("Nest transform requires an upstream data source.");var n=t.generate,r=t.modified(),i=e.clone(),o=this.value;return(!o||r||e.changed())&&(o&&o.each((t=>{t.children&&ma(t.data)&&i.rem.push(t.data)})),this.value=o=ED({values:V(t.keys).reduce(((t,e)=>(t.key(e),t)),MC()).entries(i.source)},AC),n&&o.each((t=>{t.children&&(t=_a(t.data),i.add.push(t),i.source.push(t))})),wC(o,ya,ya)),i.source.root=o,i}});const DC=(t,e)=>t.parent===e.parent?1:2;dt(EC,Ja,{transform(t,e){e.source&&e.source.root||s(this.constructor.name+" transform requires a backing tree data source.");const n=this.layout(t.method),r=this.fields,i=e.source.root,o=t.as||r;t.field?i.sum(t.field):i.count(),t.sort&&i.sort(ka(t.sort,(t=>t.data))),function(t,e,n){for(let r,i=0,o=e.length;ifunction(t,e,n){const r=t.data,i=e.length-1;for(let o=0;o(t=(OD*t+RD)%UD)/UD}();return i.x=e/2,i.y=n/2,t?i.eachBefore(QD(t)).eachAfter(KD(r,.5,o)).eachBefore(tC(1)):i.eachBefore(QD(ZD)).eachAfter(KD(zD,1,o)).eachAfter(KD(r,i.r/Math.min(e,n),o)).eachBefore(tC(Math.min(e,n)/(2*i.r))),i}return i.radius=function(e){return arguments.length?(t=TD(e),i):t},i.size=function(t){return arguments.length?(e=+t[0],n=+t[1],i):[e,n]},i.padding=function(t){return arguments.length?(r="function"==typeof t?t:ND(+t),i):r},i},params:["radius","size","padding"],fields:CC});const SC=["x0","y0","x1","y1","depth","children"];function $C(t){EC.call(this,t)}function 
TC(t){Ja.call(this,null,t)}$C.Definition={type:"Partition",metadata:{tree:!0,modifies:!0},params:[{name:"field",type:"field"},{name:"sort",type:"compare"},{name:"padding",type:"number",default:0},{name:"round",type:"boolean",default:!1},{name:"size",type:"number",array:!0,length:2},{name:"as",type:"string",array:!0,length:SC.length,default:SC}]},dt($C,EC,{layout:function(){var t=1,e=1,n=0,r=!1;function i(i){var o=i.height+1;return i.x0=i.y0=n,i.x1=t,i.y1=e/o,i.eachBefore(function(t,e){return function(r){r.children&&nC(r,r.x0,t*(r.depth+1)/e,r.x1,t*(r.depth+2)/e);var i=r.x0,o=r.y0,a=r.x1-n,s=r.y1-n;a=0;--i)s.push(n=e.children[i]=new mC(r[i],i)),n.parent=e;return(a.parent=new mC(null,0)).children=[a],a}(i);if(u.eachAfter(o),u.parent.m=-u.z,u.eachBefore(a),r)i.eachBefore(s);else{var l=i,c=i,f=i;i.eachBefore((function(t){t.xc.x&&(c=t),t.depth>f.depth&&(f=t)}));var h=l===c?1:t(l,c)/2,d=h-l.x,p=e/(c.x+h+d),g=n/(f.depth||1);i.eachBefore((function(t){t.x=(t.x+d)*p,t.y=t.depth*g}))}return i}function o(e){var n=e.children,r=e.parent.children,i=e.i?r[e.i-1]:null;if(n){!function(t){for(var e,n=0,r=0,i=t.children,o=i.length;--o>=0;)(e=i[o]).z+=n,e.m+=n,n+=e.s+(r+=e.c)}(e);var o=(n[0].z+n[n.length-1].z)/2;i?(e.z=i.z+t(e._,i._),e.m=e.z-o):e.z=o}else i&&(e.z=i.z+t(e._,i._));e.parent.A=function(e,n,r){if(n){for(var i,o=e,a=e,s=n,u=o.parent.children[0],l=o.m,c=a.m,f=s.m,h=u.m;s=dC(s),o=hC(o),s&&o;)u=hC(u),(a=dC(a)).a=e,(i=s.z+f-o.z-l+t(s._,o._))>0&&(pC(gC(s,e,r),e,i),l+=i,c+=i),f+=s.m,l+=o.m,h+=u.m,c+=a.m;s&&!dC(a)&&(a.t=s,a.m+=f-c),o&&!hC(u)&&(u.t=o,u.m+=l-h,r=e)}return r}(e,i,e.parent.A||r[0])}function a(t){t._.x=t.z+t.parent.m,t.m+=t.parent.m}function s(t){t.x*=e,t.y=t.depth*n}return i.separation=function(e){return arguments.length?(t=e,i):t},i.size=function(t){return arguments.length?(r=!1,e=+t[0],n=+t[1],i):r?null:[e,n]},i.nodeSize=function(t){return arguments.length?(r=!0,e=+t[0],n=+t[1],i):r?[e,n]:null},i},cluster:function(){var t=wD,e=1,n=1,r=!1;function i(i){var 
o,a=0;i.eachAfter((function(e){var n=e.children;n?(e.x=function(t){return t.reduce(kD,0)/t.length}(n),e.y=function(t){return 1+t.reduce(AD,0)}(n)):(e.x=o?a+=t(e,o):0,e.y=0,o=e)}));var s=function(t){for(var e;e=t.children;)t=e[0];return t}(i),u=function(t){for(var e;e=t.children;)t=e[e.length-1];return t}(i),l=s.x-t(s,u)/2,c=u.x+t(u,s)/2;return i.eachAfter(r?function(t){t.x=(t.x-i.x)*e,t.y=(i.y-t.y)*n}:function(t){t.x=(t.x-l)/(c-l)*e,t.y=(1-(i.y?t.y/i.y:1))*n})}return i.separation=function(e){return arguments.length?(t=e,i):t},i.size=function(t){return arguments.length?(r=!1,e=+t[0],n=+t[1],i):r?null:[e,n]},i.nodeSize=function(t){return arguments.length?(r=!0,e=+t[0],n=+t[1],i):r?[e,n]:null},i}},zC=["x","y","depth","children"];function NC(t){EC.call(this,t)}function OC(t){Ja.call(this,[],t)}NC.Definition={type:"Tree",metadata:{tree:!0,modifies:!0},params:[{name:"field",type:"field"},{name:"sort",type:"compare"},{name:"method",type:"enum",default:"tidy",values:["tidy","cluster"]},{name:"size",type:"number",array:!0,length:2},{name:"nodeSize",type:"number",array:!0,length:2},{name:"separation",type:"boolean",default:!0},{name:"as",type:"string",array:!0,length:zC.length,default:zC}]},dt(NC,EC,{layout(t){const e=t||"tidy";if(lt(BC,e))return BC[e]();s("Unrecognized Tree layout method: "+e)},params:["size","nodeSize"],fields:zC}),OC.Definition={type:"TreeLinks",metadata:{tree:!0,generates:!0,changes:!0},params:[]},dt(OC,Ja,{transform(t,e){const n=this.value,r=e.source&&e.source.root,i=e.fork(e.NO_SOURCE),o={};return r||s("TreeLinks transform requires a tree data source."),e.changed(e.ADD_REM)?(i.rem=n,e.visit(e.SOURCE,(t=>o[ya(t)]=1)),r.each((t=>{const e=t.data,n=t.parent&&t.parent.data;n&&o[ya(e)]&&o[ya(n)]&&i.add.push(_a({source:n,target:e}))})),this.value=i.add):e.changed(e.MOD)&&(e.visit(e.MOD,(t=>o[ya(t)]=1)),n.forEach((t=>{(o[ya(t.source)]||o[ya(t.target)])&&i.mod.push(t)}))),i}});const RC={binary:function(t,e,n,r,i){var o,a,s=t.children,u=s.length,l=new 
Array(u+1);for(l[0]=a=o=0;o=n-1){var c=s[e];return c.x0=i,c.y0=o,c.x1=a,void(c.y1=u)}var f=l[e],h=r/2+f,d=e+1,p=n-1;for(;d>>1;l[g]u-o){var v=r?(i*y+a*m)/r:a;t(e,d,m,i,o,v,u),t(d,n,y,v,o,a,u)}else{var _=r?(o*y+u*m)/r:u;t(e,d,m,i,o,a,_),t(d,n,y,i,_,a,u)}}(0,u,t.value,e,n,r,i)},dice:nC,slice:yC,slicedice:function(t,e,n,r,i){(1&t.depth?yC:nC)(t,e,n,r,i)},squarify:xC,resquarify:bC},UC=["x0","y0","x1","y1","depth","children"];function LC(t){EC.call(this,t)}LC.Definition={type:"Treemap",metadata:{tree:!0,modifies:!0},params:[{name:"field",type:"field"},{name:"sort",type:"compare"},{name:"method",type:"enum",default:"squarify",values:["squarify","resquarify","binary","dice","slice","slicedice"]},{name:"padding",type:"number",default:0},{name:"paddingInner",type:"number",default:0},{name:"paddingOuter",type:"number",default:0},{name:"paddingTop",type:"number",default:0},{name:"paddingRight",type:"number",default:0},{name:"paddingBottom",type:"number",default:0},{name:"paddingLeft",type:"number",default:0},{name:"ratio",type:"number",default:1.618033988749895},{name:"round",type:"boolean",default:!1},{name:"size",type:"number",array:!0,length:2},{name:"as",type:"string",array:!0,length:UC.length,default:UC}]},dt(LC,EC,{layout(){const t=function(){var t=xC,e=!1,n=1,r=1,i=[0],o=zD,a=zD,s=zD,u=zD,l=zD;function c(t){return t.x0=t.y0=0,t.x1=n,t.y1=r,t.eachBefore(f),i=[0],e&&t.eachBefore(eC),t}function f(e){var n=i[e.depth],r=e.x0+n,c=e.y0+n,f=e.x1-n,h=e.y1-n;f{const n=t.tile();n.ratio&&t.tile(n.ratio(e))},t.method=e=>{lt(RC,e)?t.tile(RC[e]):s("Unrecognized Treemap layout method: "+e)},t},params:["method","ratio","size","round","padding","paddingInner","paddingOuter","paddingTop","paddingRight","paddingBottom","paddingLeft"],fields:UC});var qC=Object.freeze({__proto__:null,nest:kC,pack:FC,partition:$C,stratify:TC,tree:NC,treelinks:OC,treemap:LC});const PC=4278190080;function jC(t,e,n){return new Uint32Array(t.getImageData(0,0,e,n).data.buffer)}function 
IC(t,e,n){if(!e.length)return;const r=e[0].mark.marktype;"group"===r?e.forEach((e=>{e.items.forEach((e=>IC(t,e.items,n)))})):zy[r].draw(t,{items:n?e.map(WC):e})}function WC(t){const e=ba(t,{});return e.stroke&&0!==e.strokeOpacity||e.fill&&0!==e.fillOpacity?{...e,strokeOpacity:1,stroke:"#000",fillOpacity:0}:e}const HC=5,YC=31,GC=32,VC=new Uint32Array(GC+1),XC=new Uint32Array(GC+1);XC[0]=0,VC[0]=~XC[0];for(let t=1;t<=GC;++t)XC[t]=XC[t-1]<<1|1,VC[t]=~XC[t];function JC(t,e,n){const r=Math.max(1,Math.sqrt(t*e/1e6)),i=~~((t+2*n+r)/r),o=~~((e+2*n+r)/r),a=t=>~~((t+n)/r);return a.invert=t=>t*r-n,a.bitmap=()=>function(t,e){const n=new Uint32Array(~~((t*e+GC)/GC));function r(t,e){n[t]|=e}function i(t,e){n[t]&=e}return{array:n,get:(e,r)=>{const i=r*t+e;return n[i>>>HC]&1<<(i&YC)},set:(e,n)=>{const i=n*t+e;r(i>>>HC,1<<(i&YC))},clear:(e,n)=>{const r=n*t+e;i(r>>>HC,~(1<<(r&YC)))},getRange:(e,r,i,o)=>{let a,s,u,l,c=o;for(;c>=r;--c)if(a=c*t+e,s=c*t+i,u=a>>>HC,l=s>>>HC,u===l){if(n[u]&VC[a&YC]&XC[1+(s&YC)])return!0}else{if(n[u]&VC[a&YC])return!0;if(n[l]&XC[1+(s&YC)])return!0;for(let t=u+1;t{let a,s,u,l,c;for(;n<=o;++n)if(a=n*t+e,s=n*t+i,u=a>>>HC,l=s>>>HC,u===l)r(u,VC[a&YC]&XC[1+(s&YC)]);else for(r(u,VC[a&YC]),r(l,XC[1+(s&YC)]),c=u+1;c{let a,s,u,l,c;for(;n<=o;++n)if(a=n*t+e,s=n*t+r,u=a>>>HC,l=s>>>HC,u===l)i(u,XC[a&YC]|VC[1+(s&YC)]);else for(i(u,XC[a&YC]),i(l,VC[1+(s&YC)]),c=u+1;cn<0||r<0||o>=e||i>=t}}(i,o),a.ratio=r,a.padding=n,a.width=t,a.height=e,a}function ZC(t,e,n,r,i,o){let a=n/2;return t-a<0||t+a>i||e-(a=r/2)<0||e+a>o}function QC(t,e,n,r,i,o,a,s){const u=i*o/(2*r),l=t(e-u),c=t(e+u),f=t(n-(o/=2)),h=t(n+o);return a.outOfBounds(l,f,c,h)||a.getRange(l,f,c,h)||s&&s.getRange(l,f,c,h)}const KC=[-1,-1,1,1],tF=[-1,1,-1,1];const eF=["right","center","left"],nF=["bottom","middle","top"];function rF(t,e,n,r,i,o,a,s,u,l,c,f){return!(i.outOfBounds(t,n,e,r)||(f&&o||i).getRange(t,n,e,r))}const 
iF={"top-left":0,top:1,"top-right":2,left:4,middle:5,right:6,"bottom-left":8,bottom:9,"bottom-right":10},oF={naive:function(t,e,n,r){const i=t.width,o=t.height;return function(t){const e=t.datum.datum.items[r].items,n=e.length,a=t.datum.fontSize,s=py.width(t.datum,t.datum.text);let u,l,c,f,h,d,p,g=0;for(let r=0;r=g&&(g=p,t.x=h,t.y=d);return h=s/2,d=a/2,u=t.x-h,l=t.x+h,c=t.y-d,f=t.y+d,t.align="center",u<0&&l<=i?t.align="left":0<=u&&i=1;)h=(d+p)/2,QC(t,c,f,l,u,h,a,s)?p=h:d=h;if(d>r)return[c,f,d,!0]}}return function(e){const s=e.datum.datum.items[r].items,l=s.length,c=e.datum.fontSize,f=py.width(e.datum,e.datum.text);let h,d,p,g,m,y,v,_,x,b,w,k,A,M,E,D,C,F=n?c:0,S=!1,$=!1,T=0;for(let r=0;rd&&(C=h,h=d,d=C),p>g&&(C=p,p=g,g=C),x=t(h),w=t(d),b=~~((x+w)/2),k=t(p),M=t(g),A=~~((k+M)/2),v=b;v>=x;--v)for(_=A;_>=k;--_)D=u(v,_,F,f,c),D&&([e.x,e.y,F,S]=D);for(v=b;v<=w;++v)for(_=A;_<=M;++_)D=u(v,_,F,f,c),D&&([e.x,e.y,F,S]=D);S||n||(E=Math.abs(d-h+g-p),m=(h+d)/2,y=(p+g)/2,E>=T&&!ZC(m,y,f,c,i,o)&&!QC(t,m,y,c,f,c,a,null)&&(T=E,e.x=m,e.y=y,$=!0))}return!(!S&&!$)&&(m=f/2,y=c/2,a.setRange(t(e.x-m),t(e.y-y),t(e.x+m),t(e.y+y)),e.align="center",e.baseline="middle",!0)}},floodfill:function(t,e,n,r){const i=t.width,o=t.height,a=e[0],s=e[1],u=t.bitmap();return function(e){const l=e.datum.datum.items[r].items,c=l.length,f=e.datum.fontSize,h=py.width(e.datum,e.datum.text),d=[];let p,g,m,y,v,_,x,b,w,k,A,M,E=n?f:0,D=!1,C=!1,F=0;for(let r=0;r=1;)A=(w+k)/2,QC(t,v,_,f,h,A,a,s)?k=A:w=A;w>E&&(e.x=v,e.y=_,E=w,D=!0)}}D||n||(M=Math.abs(g-p+y-m),v=(p+g)/2,_=(m+y)/2,M>=F&&!ZC(v,_,h,f,i,o)&&!QC(t,v,_,f,h,f,a,null)&&(F=M,e.x=v,e.y=_,C=!0))}return!(!D&&!C)&&(v=h/2,_=f/2,a.setRange(t(e.x-v),t(e.y-_),t(e.x+v),t(e.y+_)),e.align="center",e.baseline="middle",!0)}}};function aF(t,e,n,r,i,o,a,s,u,l,c){if(!t.length)return t;const f=Math.max(r.length,i.length),h=function(t,e){const n=new Float64Array(e),r=t.length;for(let e=0;e[t.x,t.x,t.x,t.y,t.y,t.y];return t?"line"===t||"area"===t?t=>i(t.datum):"line"===e?t=>{const 
e=t.datum.items[r].items;return i(e.length?e["start"===n?0:e.length-1]:{x:NaN,y:NaN})}:t=>{const e=t.datum.bounds;return[e.x1,(e.x1+e.x2)/2,e.x2,e.y1,(e.y1+e.y2)/2,e.y2]}:i}(p,g,s,u),v=null===l||l===1/0,_=m&&"naive"===c;var x;let b=-1,w=-1;const k=t.map((t=>{const e=v?py.width(t,t.text):void 0;return b=Math.max(b,e),w=Math.max(w,t.fontSize),{datum:t,opacity:0,x:void 0,y:void 0,align:void 0,baseline:void 0,boundary:y(t),textWidth:e}}));l=null===l||l===1/0?Math.max(b,w)+Math.max(...r):l;const A=JC(e[0],e[1],l);let M;if(!_){n&&k.sort(((t,e)=>n(t.datum,e.datum)));let e=!1;for(let t=0;tt.datum));M=o.length||r?function(t,e,n,r,i){const o=t.width,a=t.height,s=r||i,u=$c(o,a).getContext("2d"),l=$c(o,a).getContext("2d"),c=s&&$c(o,a).getContext("2d");n.forEach((t=>IC(u,t,!1))),IC(l,e,!1),s&&IC(c,e,!0);const f=jC(u,o,a),h=jC(l,o,a),d=s&&jC(c,o,a),p=t.bitmap(),g=s&&t.bitmap();let m,y,v,_,x,b,w,k;for(y=0;yn.set(t(e.boundary[0]),t(e.boundary[3])))),[n,void 0]}(A,a&&k)}const E=m?oF[c](A,M,a,u):function(t,e,n,r){const i=t.width,o=t.height,a=e[0],s=e[1],u=r.length;return function(e){const l=e.boundary,c=e.datum.fontSize;if(l[2]<0||l[5]<0||l[0]>i||l[3]>o)return!1;let f,h,d,p,g,m,y,v,_,x,b,w,k,A,M,E=e.textWidth??0;for(let i=0;i>>2&3)-1,d=0===f&&0===h||r[i]<0,p=f&&h?Math.SQRT1_2:1,g=r[i]<0?-1:1,m=l[1+f]+r[i]*f*p,b=l[4+h]+g*c*h/2+r[i]*h*p,v=b-c/2,_=b+c/2,w=t(m),A=t(v),M=t(_),!E){if(!rF(w,w,A,M,a,s,0,0,0,0,0,d))continue;E=py.width(e.datum,e.datum.text)}if(x=m+g*E*f/2,m=x-E/2,y=x+E/2,w=t(m),k=t(y),rF(w,k,A,M,a,s,0,0,0,0,0,d))return e.x=f?f*g<0?y:m:x,e.y=h?h*g<0?_:v:b,e.align=eF[f*g+1],e.baseline=nF[h*g+1],a.setRange(w,A,k,M),!0}return!1}}(A,M,d,h);return k.forEach((t=>t.opacity=+E(t))),k}const sF=["x","y","opacity","align","baseline"],uF=["top-left","left","bottom-left","top","bottom","top-right","right","bottom-right"];function 
lF(t){Ja.call(this,null,t)}lF.Definition={type:"Label",metadata:{modifies:!0},params:[{name:"size",type:"number",array:!0,length:2,required:!0},{name:"sort",type:"compare"},{name:"anchor",type:"string",array:!0,default:uF},{name:"offset",type:"number",array:!0,default:[1]},{name:"padding",type:"number",default:0,null:!0},{name:"lineAnchor",type:"string",values:["start","end"],default:"end"},{name:"markIndex",type:"number",default:0},{name:"avoidBaseMark",type:"boolean",default:!0},{name:"avoidMarks",type:"data",array:!0},{name:"method",type:"string",default:"naive"},{name:"as",type:"string",array:!0,length:sF.length,default:sF}]},dt(lF,Ja,{transform(t,e){const n=t.modified();if(!(n||e.changed(e.ADD_REM)||function(n){const r=t[n];return J(r)&&e.modified(r.fields)}("sort")))return;t.size&&2===t.size.length||s("Size parameter should be specified as a [width, height] array.");const r=t.as||sF;return aF(e.materialize(e.SOURCE).source||[],t.size,t.sort,V(null==t.offset?1:t.offset),V(t.anchor||uF),t.avoidMarks||[],!1!==t.avoidBaseMark,t.lineAnchor||"end",t.markIndex||0,void 0===t.padding?0:t.padding,t.method||"naive").forEach((t=>{const e=t.datum;e[r[0]]=t.x,e[r[1]]=t.y,e[r[2]]=t.opacity,e[r[3]]=t.align,e[r[4]]=t.baseline})),e.reflow(n).modifies(r)}});var cF=Object.freeze({__proto__:null,label:lF});function fF(t,e){var n,r,i,o,a,s,u=[],l=function(t){return t(o)};if(null==e)u.push(t);else for(n={},r=0,i=t.length;r{Ls(e,t.x,t.y,t.bandwidth||.3).forEach((t=>{const n={};for(let t=0;t"poly"===t?e:"quad"===t?2:1)(a,u),c=t.as||[n(t.x),n(t.y)],f=dF[a],h=[];let d=t.extent;lt(dF,a)||s("Invalid regression method: "+a),null!=d&&"log"===a&&d[0]<=0&&(e.dataflow.warn("Ignoring extent with values <= 0 for log regression."),d=null),i.forEach((n=>{if(n.length<=l)return void e.dataflow.warn("Skipping regression with more parameters than data points.");const r=f(n,t.x,t.y,u);if(t.params)return void h.push(_a({keys:n.dims,coef:r.coef,rSquared:r.rSquared}));const i=d||at(n,t.x),s=t=>{const 
e={};for(let t=0;ts([t,r.predict(t)]))):Is(r.predict,i,25,200).forEach(s)})),this.value&&(r.rem=this.value),this.value=r.add=r.source=h}return r}});var gF=Object.freeze({__proto__:null,loess:hF,regression:pF});const mF=134217729,yF=33306690738754706e-32;function vF(t,e,n,r,i){let o,a,s,u,l=e[0],c=r[0],f=0,h=0;c>l==c>-l?(o=l,l=e[++f]):(o=c,c=r[++h]);let d=0;if(fl==c>-l?(a=l+o,s=o-(a-l),l=e[++f]):(a=c+o,s=o-(a-c),c=r[++h]),o=a,0!==s&&(i[d++]=s);fl==c>-l?(a=o+l,u=a-o,s=o-(a-u)+(l-u),l=e[++f]):(a=o+c,u=a-o,s=o-(a-u)+(c-u),c=r[++h]),o=a,0!==s&&(i[d++]=s);for(;f0!=s>0)return u;const l=Math.abs(a+s);return Math.abs(u)>=33306690738754716e-32*l?u:-function(t,e,n,r,i,o,a){let s,u,l,c,f,h,d,p,g,m,y,v,_,x,b,w,k,A;const M=t-i,E=n-i,D=e-o,C=r-o;x=M*C,h=mF*M,d=h-(h-M),p=M-d,h=mF*C,g=h-(h-C),m=C-g,b=p*m-(x-d*g-p*g-d*m),w=D*E,h=mF*D,d=h-(h-D),p=D-d,h=mF*E,g=h-(h-E),m=E-g,k=p*m-(w-d*g-p*g-d*m),y=b-k,f=b-y,wF[0]=b-(y+f)+(f-k),v=x+y,f=v-x,_=x-(v-f)+(y-f),y=_-w,f=_-y,wF[1]=_-(y+f)+(f-w),A=v+y,f=A-v,wF[2]=v-(A-f)+(y-f),wF[3]=A;let F=function(t,e){let n=e[0];for(let r=1;r=S||-F>=S)return F;if(f=t-M,s=t-(M+f)+(f-i),f=n-E,l=n-(E+f)+(f-i),f=e-D,u=e-(D+f)+(f-o),f=r-C,c=r-(C+f)+(f-o),0===s&&0===u&&0===l&&0===c)return F;if(S=bF*a+yF*Math.abs(F),F+=M*c+C*s-(D*l+E*u),F>=S||-F>=S)return F;x=s*C,h=mF*s,d=h-(h-s),p=s-d,h=mF*C,g=h-(h-C),m=C-g,b=p*m-(x-d*g-p*g-d*m),w=u*E,h=mF*u,d=h-(h-u),p=u-d,h=mF*E,g=h-(h-E),m=E-g,k=p*m-(w-d*g-p*g-d*m),y=b-k,f=b-y,EF[0]=b-(y+f)+(f-k),v=x+y,f=v-x,_=x-(v-f)+(y-f),y=_-w,f=_-y,EF[1]=_-(y+f)+(f-w),A=v+y,f=A-v,EF[2]=v-(A-f)+(y-f),EF[3]=A;const $=vF(4,wF,4,EF,kF);x=M*c,h=mF*M,d=h-(h-M),p=M-d,h=mF*c,g=h-(h-c),m=c-g,b=p*m-(x-d*g-p*g-d*m),w=D*l,h=mF*D,d=h-(h-D),p=D-d,h=mF*l,g=h-(h-l),m=l-g,k=p*m-(w-d*g-p*g-d*m),y=b-k,f=b-y,EF[0]=b-(y+f)+(f-k),v=x+y,f=v-x,_=x-(v-f)+(y-f),y=_-w,f=_-y,EF[1]=_-(y+f)+(f-w),A=v+y,f=A-v,EF[2]=v-(A-f)+(y-f),EF[3]=A;const 
T=vF($,kF,4,EF,AF);x=s*c,h=mF*s,d=h-(h-s),p=s-d,h=mF*c,g=h-(h-c),m=c-g,b=p*m-(x-d*g-p*g-d*m),w=u*l,h=mF*u,d=h-(h-u),p=u-d,h=mF*l,g=h-(h-l),m=l-g,k=p*m-(w-d*g-p*g-d*m),y=b-k,f=b-y,EF[0]=b-(y+f)+(f-k),v=x+y,f=v-x,_=x-(v-f)+(y-f),y=_-w,f=_-y,EF[1]=_-(y+f)+(f-w),A=v+y,f=A-v,EF[2]=v-(A-f)+(y-f),EF[3]=A;const B=vF(T,AF,4,EF,MF);return MF[B-1]}(t,e,n,r,i,o,l)}const CF=Math.pow(2,-52),FF=new Uint32Array(512);class SF{static from(t){let e=arguments.length>1&&void 0!==arguments[1]?arguments[1]:OF,n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:RF;const r=t.length,i=new Float64Array(2*r);for(let o=0;o>1;if(e>0&&"number"!=typeof t[0])throw new Error("Expected coords to contain numbers.");this.coords=t;const n=Math.max(2*e-5,0);this._triangles=new Uint32Array(3*n),this._halfedges=new Int32Array(3*n),this._hashSize=Math.ceil(Math.sqrt(e)),this._hullPrev=new Uint32Array(e),this._hullNext=new Uint32Array(e),this._hullTri=new Uint32Array(e),this._hullHash=new Int32Array(this._hashSize).fill(-1),this._ids=new Uint32Array(e),this._dists=new Float64Array(e),this.update()}update(){const{coords:t,_hullPrev:e,_hullNext:n,_hullTri:r,_hullHash:i}=this,o=t.length>>1;let a=1/0,s=1/0,u=-1/0,l=-1/0;for(let e=0;eu&&(u=n),r>l&&(l=r),this._ids[e]=e}const c=(a+u)/2,f=(s+l)/2;let h,d,p,g=1/0;for(let e=0;e0&&(d=e,g=n)}let v=t[2*d],_=t[2*d+1],x=1/0;for(let e=0;er&&(e[n++]=i,r=this._dists[i])}return this.hull=e.subarray(0,n),this.triangles=new Uint32Array(0),void(this.halfedges=new Uint32Array(0))}if(DF(m,y,v,_,b,w)<0){const t=d,e=v,n=_;d=p,v=b,_=w,p=t,b=e,w=n}const k=function(t,e,n,r,i,o){const a=n-t,s=r-e,u=i-t,l=o-e,c=a*a+s*s,f=u*u+l*l,h=.5/(a*l-s*u),d=t+(l*c-s*f)*h,p=e+(a*f-u*c)*h;return{x:d,y:p}}(m,y,v,_,b,w);this._cx=k.x,this._cy=k.y;for(let e=0;e0&&Math.abs(l-o)<=CF&&Math.abs(c-a)<=CF)continue;if(o=l,a=c,u===h||u===d||u===p)continue;let f=0;for(let t=0,e=this._hashKey(l,c);t=0;)if(m=g,m===f){m=-1;break}if(-1===m)continue;let 
y=this._addTriangle(m,u,n[m],-1,-1,r[m]);r[u]=this._legalize(y+2),r[m]=y,A++;let v=n[m];for(;g=n[v],DF(l,c,t[2*v],t[2*v+1],t[2*g],t[2*g+1])<0;)y=this._addTriangle(v,u,g,r[u],-1,r[v]),r[u]=this._legalize(y+2),n[v]=v,A--,v=g;if(m===f)for(;g=e[m],DF(l,c,t[2*g],t[2*g+1],t[2*m],t[2*m+1])<0;)y=this._addTriangle(g,u,m,-1,r[m],r[g]),this._legalize(y+2),r[g]=y,n[m]=m,A--,m=g;this._hullStart=e[u]=m,n[m]=e[v]=u,n[u]=v,i[this._hashKey(l,c)]=u,i[this._hashKey(t[2*m],t[2*m+1])]=m}this.hull=new Uint32Array(A);for(let t=0,e=this._hullStart;t0?3-n:1+n)/4}(t-this._cx,e-this._cy)*this._hashSize)%this._hashSize}_legalize(t){const{_triangles:e,_halfedges:n,coords:r}=this;let i=0,o=0;for(;;){const a=n[t],s=t-t%3;if(o=s+(t+2)%3,-1===a){if(0===i)break;t=FF[--i];continue}const u=a-a%3,l=s+(t+1)%3,c=u+(a+2)%3,f=e[o],h=e[t],d=e[l],p=e[c];if(TF(r[2*f],r[2*f+1],r[2*h],r[2*h+1],r[2*d],r[2*d+1],r[2*p],r[2*p+1])){e[t]=p,e[a]=f;const r=n[c];if(-1===r){let e=this._hullStart;do{if(this._hullTri[e]===c){this._hullTri[e]=t;break}e=this._hullPrev[e]}while(e!==this._hullStart)}this._link(t,r),this._link(a,n[o]),this._link(o,c);const s=u+(a+1)%3;i=n&&e[t[a]]>o;)t[a+1]=t[a--];t[a+1]=r}else{let i=n+1,o=r;NF(t,n+r>>1,i),e[t[n]]>e[t[r]]&&NF(t,n,r),e[t[i]]>e[t[r]]&&NF(t,i,r),e[t[n]]>e[t[i]]&&NF(t,n,i);const a=t[i],s=e[a];for(;;){do{i++}while(e[t[i]]s);if(o=o-n?(zF(t,e,i,r),zF(t,e,n,o-1)):(zF(t,e,n,o-1),zF(t,e,i,r))}}function NF(t,e,n){const r=t[e];t[e]=t[n],t[n]=r}function OF(t){return t[0]}function RF(t){return t[1]}const UF=1e-6;class LF{constructor(){this._x0=this._y0=this._x1=this._y1=null,this._=""}moveTo(t,e){this._+=`M${this._x0=this._x1=+t},${this._y0=this._y1=+e}`}closePath(){null!==this._x1&&(this._x1=this._x0,this._y1=this._y0,this._+="Z")}lineTo(t,e){this._+=`L${this._x1=+t},${this._y1=+e}`}arc(t,e,n){const r=(t=+t)+(n=+n),i=e=+e;if(n<0)throw new Error("negative 
radius");null===this._x1?this._+=`M${r},${i}`:(Math.abs(this._x1-r)>UF||Math.abs(this._y1-i)>UF)&&(this._+="L"+r+","+i),n&&(this._+=`A${n},${n},0,1,1,${t-n},${e}A${n},${n},0,1,1,${this._x1=r},${this._y1=i}`)}rect(t,e,n,r){this._+=`M${this._x0=this._x1=+t},${this._y0=this._y1=+e}h${+n}v${+r}h${-n}Z`}value(){return this._||null}}class qF{constructor(){this._=[]}moveTo(t,e){this._.push([t,e])}closePath(){this._.push(this._[0].slice())}lineTo(t,e){this._.push([t,e])}value(){return this._.length?this._:null}}let PF=class{constructor(t){let[e,n,r,i]=arguments.length>1&&void 0!==arguments[1]?arguments[1]:[0,0,960,500];if(!((r=+r)>=(e=+e)&&(i=+i)>=(n=+n)))throw new Error("invalid bounds");this.delaunay=t,this._circumcenters=new Float64Array(2*t.points.length),this.vectors=new Float64Array(2*t.points.length),this.xmax=r,this.xmin=e,this.ymax=i,this.ymin=n,this._init()}update(){return this.delaunay.update(),this._init(),this}_init(){const{delaunay:{points:t,hull:e,triangles:n},vectors:r}=this;let i,o;const a=this.circumcenters=this._circumcenters.subarray(0,n.length/3*2);for(let r,s,u=0,l=0,c=n.length;u1;)i-=2;for(let t=2;t0){if(e>=this.ymax)return null;(i=(this.ymax-e)/r)0){if(t>=this.xmax)return null;(i=(this.xmax-t)/n)this.xmax?2:0)|(ethis.ymax?8:0)}_simplify(t){if(t&&t.length>4){for(let e=0;e1&&void 0!==arguments[1]?arguments[1]:WF,n=arguments.length>2&&void 0!==arguments[2]?arguments[2]:HF,r=arguments.length>3?arguments[3]:void 0;return new GF("length"in t?function(t,e,n,r){const i=t.length,o=new Float64Array(2*i);for(let a=0;a2&&function(t){const{triangles:e,coords:n}=t;for(let t=0;t1e-10)return!1}return!0}(t)){this.collinear=Int32Array.from({length:e.length/2},((t,e)=>e)).sort(((t,n)=>e[2*t]-e[2*n]||e[2*t+1]-e[2*n+1]));const t=this.collinear[0],n=this.collinear[this.collinear.length-1],r=[e[2*t],e[2*t+1],e[2*n],e[2*n+1]],i=1e-8*Math.hypot(r[3]-r[1],r[2]-r[0]);for(let t=0,n=e.length/2;t0&&(this.triangles=new Int32Array(3).fill(-1),this.halfedges=new 
Int32Array(3).fill(-1),this.triangles[0]=r[0],o[r[0]]=1,2===r.length&&(o[r[1]]=0,this.triangles[1]=r[1],this.triangles[2]=r[1]))}voronoi(t){return new PF(this,t)}*neighbors(t){const{inedges:e,hull:n,_hullIndex:r,halfedges:i,triangles:o,collinear:a}=this;if(a){const e=a.indexOf(t);return e>0&&(yield a[e-1]),void(e2&&void 0!==arguments[2]?arguments[2]:0;if((t=+t)!=t||(e=+e)!=e)return-1;const r=n;let i;for(;(i=this._step(n,t,e))>=0&&i!==n&&i!==r;)n=i;return i}_step(t,e,n){const{inedges:r,hull:i,_hullIndex:o,halfedges:a,triangles:s,points:u}=this;if(-1===r[t]||!u.length)return(t+1)%(u.length>>1);let l=t,c=IF(e-u[2*t],2)+IF(n-u[2*t+1],2);const f=r[t];let h=f;do{let r=s[h];const f=IF(e-u[2*r],2)+IF(n-u[2*r+1],2);if(f=f));)if(e.x=a+i,e.y=l+o,!(e.x+e.x0<0||e.y+e.y0<0||e.x+e.x1>s[0]||e.y+e.y1>s[1])&&(!n||!rS(e,t,s[0]))&&(!n||oS(e,n))){for(var g,m=e.sprite,y=e.width>>5,v=s[0]>>5,_=e.x-(y<<4),x=127&_,b=32-x,w=e.y1-e.y0,k=(e.y+e.y0)*v+(_>>5),A=0;A>>x:0);k+=v}return e.sprite=null,!0}return!1}return f.layout=function(){for(var u=function(t){t.width=t.height=1;var e=Math.sqrt(t.getContext("2d").getImageData(0,0,1,1).data.length>>2);t.width=(KF<<5)/e,t.height=tS/e;var n=t.getContext("2d");return n.fillStyle=n.strokeStyle="red",n.textAlign="center",{context:n,ratio:e}}($c()),f=function(t){var e=[],n=-1;for(;++n>5)*s[1]),d=null,p=l.length,g=-1,m=[],y=l.map((s=>({text:t(s),font:e(s),style:r(s),weight:i(s),rotate:o(s),size:~~(n(s)+1e-14),padding:a(s),xoff:0,yoff:0,x1:0,y1:0,x0:0,y0:0,hasText:!1,sprite:null,datum:s}))).sort(((t,e)=>e.size-t.size));++g>1,v.y=s[1]*(c()+.5)>>1,nS(u,v,y,g),v.hasText&&h(f,v,d)&&(m.push(v),d?iS(d,v):d=[{x:v.x+v.x0,y:v.y+v.y0},{x:v.x+v.x1,y:v.y+v.y1}],v.x-=s[0]>>1,v.y-=s[1]>>1)}return m},f.words=function(t){return arguments.length?(l=t,f):l},f.size=function(t){return arguments.length?(s=[+t[0],+t[1]],f):s},f.font=function(t){return arguments.length?(e=sS(t),f):e},f.fontStyle=function(t){return arguments.length?(r=sS(t),f):r},f.fontWeight=function(t){return 
arguments.length?(i=sS(t),f):i},f.rotate=function(t){return arguments.length?(o=sS(t),f):o},f.text=function(e){return arguments.length?(t=sS(e),f):t},f.spiral=function(t){return arguments.length?(u=uS[t]||t,f):u},f.fontSize=function(t){return arguments.length?(n=sS(t),f):n},f.padding=function(t){return arguments.length?(a=sS(t),f):a},f.random=function(t){return arguments.length?(c=t,f):c},f}function nS(t,e,n,r){if(!e.sprite){var i=t.context,o=t.ratio;i.clearRect(0,0,(KF<<5)/o,tS/o);var a,s,u,l,c,f=0,h=0,d=0,p=n.length;for(--r;++r>5<<5,u=~~Math.max(Math.abs(v+_),Math.abs(v-_))}else a=a+31>>5<<5;if(u>d&&(d=u),f+a>=KF<<5&&(f=0,h+=d,d=0),h+u>=tS)break;i.translate((f+(a>>1))/o,(h+(u>>1))/o),e.rotate&&i.rotate(e.rotate*QF),i.fillText(e.text,0,0),e.padding&&(i.lineWidth=2*e.padding,i.strokeText(e.text,0,0)),i.restore(),e.width=a,e.height=u,e.xoff=f,e.yoff=h,e.x1=a>>1,e.y1=u>>1,e.x0=-e.x1,e.y0=-e.y1,e.hasText=!0,f+=a}for(var b=i.getImageData(0,0,(KF<<5)/o,tS/o).data,w=[];--r>=0;)if((e=n[r]).hasText){for(s=(a=e.width)>>5,u=e.y1-e.y0,l=0;l>5),E=b[(h+c)*(KF<<5)+(f+l)<<2]?1<<31-l%32:0;w[M]|=E,k|=E}k?A=c:(e.y0++,u--,c--,h++)}e.y1=e.y0+A,e.sprite=w.slice(0,(e.y1-e.y0)*s)}}}function rS(t,e,n){n>>=5;for(var r,i=t.sprite,o=t.width>>5,a=t.x-(o<<4),s=127&a,u=32-s,l=t.y1-t.y0,c=(t.y+t.y0)*n+(a>>5),f=0;f>>s:0))&e[c+h])return!0;c+=n}return!1}function iS(t,e){var n=t[0],r=t[1];e.x+e.x0r.x&&(r.x=e.x+e.x1),e.y+e.y1>r.y&&(r.y=e.y+e.y1)}function oS(t,e){return t.x+t.x1>e[0].x&&t.x+t.x0e[0].y&&t.y+t.y0e(t(n))}i.forEach((t=>{t[a[0]]=NaN,t[a[1]]=NaN,t[a[3]]=0}));const c=o.words(i).text(e.text).size(e.size||[500,500]).padding(e.padding||1).spiral(e.spiral||"archimedean").rotate(e.rotate||0).font(e.font||"sans-serif").fontStyle(e.fontStyle||"normal").fontWeight(e.fontWeight||"normal").fontSize(l).random(t.random).layout(),f=o.size(),h=f[0]>>1,d=f[1]>>1,p=c.length;for(let t,e,n=0;nnew Uint8Array(t),pS=t=>new Uint16Array(t),gS=t=>new Uint32Array(t);function mS(t,e,n){const 
r=(e<257?dS:e<65537?pS:gS)(t);return n&&r.set(n),r}function yS(t,e,n){const r=1<{const r=t[e],i=t[n];return ri?1:0})),function(t,e){return Array.from(e,(e=>t[e]))}(t,e)}(h,u),a)l=e,c=t,e=Array(a+s),t=gS(a+s),function(t,e,n,r,i,o,a,s,u){let l,c=0,f=0;for(l=0;c0)for(f=0;ft,size:()=>n}}function _S(t){Ja.call(this,function(){let t=8,e=[],n=gS(0),r=mS(0,t),i=mS(0,t);return{data:()=>e,seen:()=>n=function(t,e,n){return t.length>=e?t:((n=n||new t.constructor(e)).set(t),n)}(n,e.length),add(t){for(let n,r=0,i=e.length,o=t.length;re.length,curr:()=>r,prev:()=>i,reset:t=>i[t]=r[t],all:()=>t<257?255:t<65537?65535:4294967295,set(t,e){r[t]|=e},clear(t,e){r[t]&=~e},resize(e,n){(e>r.length||n>t)&&(t=Math.max(n,t),r=mS(e,t,r),i=mS(e,t))}}}(),t),this._indices=null,this._dims=null}function xS(t){Ja.call(this,null,t)}_S.Definition={type:"CrossFilter",metadata:{},params:[{name:"fields",type:"field",array:!0,required:!0},{name:"query",type:"array",array:!0,required:!0,content:{type:"number",array:!0,length:2}}]},dt(_S,Ja,{transform(t,e){return this._dims?t.modified("fields")||t.fields.some((t=>e.modified(t.fields)))?this.reinit(t,e):this.eval(t,e):this.init(t,e)},init(t,e){const n=t.fields,r=t.query,i=this._indices={},o=this._dims=[],a=r.length;let s,u,l=0;for(;l{const t=i.remove(e,n);for(const e in r)r[e].reindex(t)}))},update(t,e,n){const r=this._dims,i=t.query,o=e.stamp,a=r.length;let s,u,l=0;for(n.filters=0,u=0;ud)for(m=d,y=Math.min(f,p);mp)for(m=Math.max(f,p),y=h;mc)for(d=c,p=Math.min(u,f);df)for(d=Math.max(u,f),p=l;ds[t]&n?null:a[t];return o.filter(o.MOD,l),i&i-1?(o.filter(o.ADD,(t=>{const e=s[t]&n;return!e&&e^u[t]&n?a[t]:null})),o.filter(o.REM,(t=>{const e=s[t]&n;return e&&!(e^e^u[t]&n)?a[t]:null}))):(o.filter(o.ADD,l),o.filter(o.REM,(t=>(s[t]&n)===i?a[t]:null))),o.filter(o.SOURCE,(t=>l(t._index)))}});var bS=Object.freeze({__proto__:null,crossfilter:_S,resolvefilter:xS});const 
wS="Literal",kS="Property",AS="ArrayExpression",MS="BinaryExpression",ES="CallExpression",DS="ConditionalExpression",CS="LogicalExpression",FS="MemberExpression",SS="ObjectExpression",$S="UnaryExpression";function TS(t){this.type=t}var BS,zS,NS,OS,RS;TS.prototype.visit=function(t){let e,n,r;if(t(this))return 1;for(e=function(t){switch(t.type){case AS:return t.elements;case MS:case CS:return[t.left,t.right];case ES:return[t.callee].concat(t.arguments);case DS:return[t.test,t.consequent,t.alternate];case FS:return[t.object,t.property];case SS:return t.properties;case kS:return[t.key,t.value];case $S:return[t.argument];default:return[]}}(this),n=0,r=e.length;n",BS[qS]="Identifier",BS[PS]="Keyword",BS[jS]="Null",BS[IS]="Numeric",BS[WS]="Punctuator",BS[HS]="String",BS[9]="RegularExpression";var YS="ArrayExpression",GS="BinaryExpression",VS="CallExpression",XS="ConditionalExpression",JS="Identifier",ZS="Literal",QS="LogicalExpression",KS="MemberExpression",t$="ObjectExpression",e$="Property",n$="UnaryExpression",r$="Unexpected token %0",i$="Unexpected number",o$="Unexpected string",a$="Unexpected identifier",s$="Unexpected reserved word",u$="Unexpected end of input",l$="Invalid regular expression",c$="Invalid regular expression: missing /",f$="Octal literals are not allowed in strict mode.",h$="Duplicate data property in object literal not allowed in strict mode",d$="ILLEGAL",p$="Disabled.",g$=new 
RegExp("[\\xAA\\xB5\\xBA\\xC0-\\xD6\\xD8-\\xF6\\xF8-\\u02C1\\u02C6-\\u02D1\\u02E0-\\u02E4\\u02EC\\u02EE\\u0370-\\u0374\\u0376\\u0377\\u037A-\\u037D\\u037F\\u0386\\u0388-\\u038A\\u038C\\u038E-\\u03A1\\u03A3-\\u03F5\\u03F7-\\u0481\\u048A-\\u052F\\u0531-\\u0556\\u0559\\u0561-\\u0587\\u05D0-\\u05EA\\u05F0-\\u05F2\\u0620-\\u064A\\u066E\\u066F\\u0671-\\u06D3\\u06D5\\u06E5\\u06E6\\u06EE\\u06EF\\u06FA-\\u06FC\\u06FF\\u0710\\u0712-\\u072F\\u074D-\\u07A5\\u07B1\\u07CA-\\u07EA\\u07F4\\u07F5\\u07FA\\u0800-\\u0815\\u081A\\u0824\\u0828\\u0840-\\u0858\\u08A0-\\u08B2\\u0904-\\u0939\\u093D\\u0950\\u0958-\\u0961\\u0971-\\u0980\\u0985-\\u098C\\u098F\\u0990\\u0993-\\u09A8\\u09AA-\\u09B0\\u09B2\\u09B6-\\u09B9\\u09BD\\u09CE\\u09DC\\u09DD\\u09DF-\\u09E1\\u09F0\\u09F1\\u0A05-\\u0A0A\\u0A0F\\u0A10\\u0A13-\\u0A28\\u0A2A-\\u0A30\\u0A32\\u0A33\\u0A35\\u0A36\\u0A38\\u0A39\\u0A59-\\u0A5C\\u0A5E\\u0A72-\\u0A74\\u0A85-\\u0A8D\\u0A8F-\\u0A91\\u0A93-\\u0AA8\\u0AAA-\\u0AB0\\u0AB2\\u0AB3\\u0AB5-\\u0AB9\\u0ABD\\u0AD0\\u0AE0\\u0AE1\\u0B05-\\u0B0C\\u0B0F\\u0B10\\u0B13-\\u0B28\\u0B2A-\\u0B30\\u0B32\\u0B33\\u0B35-\\u0B39\\u0B3D\\u0B5C\\u0B5D\\u0B5F-\\u0B61\\u0B71\\u0B83\\u0B85-\\u0B8A\\u0B8E-\\u0B90\\u0B92-\\u0B95\\u0B99\\u0B9A\\u0B9C\\u0B9E\\u0B9F\\u0BA3\\u0BA4\\u0BA8-\\u0BAA\\u0BAE-\\u0BB9\\u0BD0\\u0C05-\\u0C0C\\u0C0E-\\u0C10\\u0C12-\\u0C28\\u0C2A-\\u0C39\\u0C3D\\u0C58\\u0C59\\u0C60\\u0C61\\u0C85-\\u0C8C\\u0C8E-\\u0C90\\u0C92-\\u0CA8\\u0CAA-\\u0CB3\\u0CB5-\\u0CB9\\u0CBD\\u0CDE\\u0CE0\\u0CE1\\u0CF1\\u0CF2\\u0D05-\\u0D0C\\u0D0E-\\u0D10\\u0D12-\\u0D3A\\u0D3D\\u0D4E\\u0D60\\u0D61\\u0D7A-\\u0D7F\\u0D85-\\u0D96\\u0D9A-\\u0DB1\\u0DB3-\\u0DBB\\u0DBD\\u0DC0-\\u0DC6\\u0E01-\\u0E30\\u0E32\\u0E33\\u0E40-\\u0E46\\u0E81\\u0E82\\u0E84\\u0E87\\u0E88\\u0E8A\\u0E8D\\u0E94-\\u0E97\\u0E99-\\u0E9F\\u0EA1-\\u0EA3\\u0EA5\\u0EA7\\u0EAA\\u0EAB\\u0EAD-\\u0EB0\\u0EB2\\u0EB3\\u0EBD\\u0EC0-\\u0EC4\\u0EC6\\u0EDC-\\u0EDF\\u0F00\\u0F40-\\u0F47\\u0F49-\\u0F6C\\u0F88-\\u0F8C\\u1000-\\u102A\\u103F\\u1050-\\u1055\\u105A-\\u105D\\u1061\\u10
65\\u1066\\u106E-\\u1070\\u1075-\\u1081\\u108E\\u10A0-\\u10C5\\u10C7\\u10CD\\u10D0-\\u10FA\\u10FC-\\u1248\\u124A-\\u124D\\u1250-\\u1256\\u1258\\u125A-\\u125D\\u1260-\\u1288\\u128A-\\u128D\\u1290-\\u12B0\\u12B2-\\u12B5\\u12B8-\\u12BE\\u12C0\\u12C2-\\u12C5\\u12C8-\\u12D6\\u12D8-\\u1310\\u1312-\\u1315\\u1318-\\u135A\\u1380-\\u138F\\u13A0-\\u13F4\\u1401-\\u166C\\u166F-\\u167F\\u1681-\\u169A\\u16A0-\\u16EA\\u16EE-\\u16F8\\u1700-\\u170C\\u170E-\\u1711\\u1720-\\u1731\\u1740-\\u1751\\u1760-\\u176C\\u176E-\\u1770\\u1780-\\u17B3\\u17D7\\u17DC\\u1820-\\u1877\\u1880-\\u18A8\\u18AA\\u18B0-\\u18F5\\u1900-\\u191E\\u1950-\\u196D\\u1970-\\u1974\\u1980-\\u19AB\\u19C1-\\u19C7\\u1A00-\\u1A16\\u1A20-\\u1A54\\u1AA7\\u1B05-\\u1B33\\u1B45-\\u1B4B\\u1B83-\\u1BA0\\u1BAE\\u1BAF\\u1BBA-\\u1BE5\\u1C00-\\u1C23\\u1C4D-\\u1C4F\\u1C5A-\\u1C7D\\u1CE9-\\u1CEC\\u1CEE-\\u1CF1\\u1CF5\\u1CF6\\u1D00-\\u1DBF\\u1E00-\\u1F15\\u1F18-\\u1F1D\\u1F20-\\u1F45\\u1F48-\\u1F4D\\u1F50-\\u1F57\\u1F59\\u1F5B\\u1F5D\\u1F5F-\\u1F7D\\u1F80-\\u1FB4\\u1FB6-\\u1FBC\\u1FBE\\u1FC2-\\u1FC4\\u1FC6-\\u1FCC\\u1FD0-\\u1FD3\\u1FD6-\\u1FDB\\u1FE0-\\u1FEC\\u1FF2-\\u1FF4\\u1FF6-\\u1FFC\\u2071\\u207F\\u2090-\\u209C\\u2102\\u2107\\u210A-\\u2113\\u2115\\u2119-\\u211D\\u2124\\u2126\\u2128\\u212A-\\u212D\\u212F-\\u2139\\u213C-\\u213F\\u2145-\\u2149\\u214E\\u2160-\\u2188\\u2C00-\\u2C2E\\u2C30-\\u2C5E\\u2C60-\\u2CE4\\u2CEB-\\u2CEE\\u2CF2\\u2CF3\\u2D00-\\u2D25\\u2D27\\u2D2D\\u2D30-\\u2D67\\u2D6F\\u2D80-\\u2D96\\u2DA0-\\u2DA6\\u2DA8-\\u2DAE\\u2DB0-\\u2DB6\\u2DB8-\\u2DBE\\u2DC0-\\u2DC6\\u2DC8-\\u2DCE\\u2DD0-\\u2DD6\\u2DD8-\\u2DDE\\u2E2F\\u3005-\\u3007\\u3021-\\u3029\\u3031-\\u3035\\u3038-\\u303C\\u3041-\\u3096\\u309D-\\u309F\\u30A1-\\u30FA\\u30FC-\\u30FF\\u3105-\\u312D\\u3131-\\u318E\\u31A0-\\u31BA\\u31F0-\\u31FF\\u3400-\\u4DB5\\u4E00-\\u9FCC\\uA000-\\uA48C\\uA4D0-\\uA4FD\\uA500-\\uA60C\\uA610-\\uA61F\\uA62A\\uA62B\\uA640-\\uA66E\\uA67F-\\uA69D\\uA6A0-\\uA6EF\\uA717-\\uA71F\\uA722-\\uA788\\uA78B-\\uA78E\\uA790-\\uA7AD\\uA7B0\\uA7B1\\uA7F7-\\uA80
1\\uA803-\\uA805\\uA807-\\uA80A\\uA80C-\\uA822\\uA840-\\uA873\\uA882-\\uA8B3\\uA8F2-\\uA8F7\\uA8FB\\uA90A-\\uA925\\uA930-\\uA946\\uA960-\\uA97C\\uA984-\\uA9B2\\uA9CF\\uA9E0-\\uA9E4\\uA9E6-\\uA9EF\\uA9FA-\\uA9FE\\uAA00-\\uAA28\\uAA40-\\uAA42\\uAA44-\\uAA4B\\uAA60-\\uAA76\\uAA7A\\uAA7E-\\uAAAF\\uAAB1\\uAAB5\\uAAB6\\uAAB9-\\uAABD\\uAAC0\\uAAC2\\uAADB-\\uAADD\\uAAE0-\\uAAEA\\uAAF2-\\uAAF4\\uAB01-\\uAB06\\uAB09-\\uAB0E\\uAB11-\\uAB16\\uAB20-\\uAB26\\uAB28-\\uAB2E\\uAB30-\\uAB5A\\uAB5C-\\uAB5F\\uAB64\\uAB65\\uABC0-\\uABE2\\uAC00-\\uD7A3\\uD7B0-\\uD7C6\\uD7CB-\\uD7FB\\uF900-\\uFA6D\\uFA70-\\uFAD9\\uFB00-\\uFB06\\uFB13-\\uFB17\\uFB1D\\uFB1F-\\uFB28\\uFB2A-\\uFB36\\uFB38-\\uFB3C\\uFB3E\\uFB40\\uFB41\\uFB43\\uFB44\\uFB46-\\uFBB1\\uFBD3-\\uFD3D\\uFD50-\\uFD8F\\uFD92-\\uFDC7\\uFDF0-\\uFDFB\\uFE70-\\uFE74\\uFE76-\\uFEFC\\uFF21-\\uFF3A\\uFF41-\\uFF5A\\uFF66-\\uFFBE\\uFFC2-\\uFFC7\\uFFCA-\\uFFCF\\uFFD2-\\uFFD7\\uFFDA-\\uFFDC]"),m$=new RegExp("[\\xAA\\xB5\\xBA\\xC0-\\xD6\\xD8-\\xF6\\xF8-\\u02C1\\u02C6-\\u02D1\\u02E0-\\u02E4\\u02EC\\u02EE\\u0300-\\u0374\\u0376\\u0377\\u037A-\\u037D\\u037F\\u0386\\u0388-\\u038A\\u038C\\u038E-\\u03A1\\u03A3-\\u03F5\\u03F7-\\u0481\\u0483-\\u0487\\u048A-\\u052F\\u0531-\\u0556\\u0559\\u0561-\\u0587\\u0591-\\u05BD\\u05BF\\u05C1\\u05C2\\u05C4\\u05C5\\u05C7\\u05D0-\\u05EA\\u05F0-\\u05F2\\u0610-\\u061A\\u0620-\\u0669\\u066E-\\u06D3\\u06D5-\\u06DC\\u06DF-\\u06E8\\u06EA-\\u06FC\\u06FF\\u0710-\\u074A\\u074D-\\u07B1\\u07C0-\\u07F5\\u07FA\\u0800-\\u082D\\u0840-\\u085B\\u08A0-\\u08B2\\u08E4-\\u0963\\u0966-\\u096F\\u0971-\\u0983\\u0985-\\u098C\\u098F\\u0990\\u0993-\\u09A8\\u09AA-\\u09B0\\u09B2\\u09B6-\\u09B9\\u09BC-\\u09C4\\u09C7\\u09C8\\u09CB-\\u09CE\\u09D7\\u09DC\\u09DD\\u09DF-\\u09E3\\u09E6-\\u09F1\\u0A01-\\u0A03\\u0A05-\\u0A0A\\u0A0F\\u0A10\\u0A13-\\u0A28\\u0A2A-\\u0A30\\u0A32\\u0A33\\u0A35\\u0A36\\u0A38\\u0A39\\u0A3C\\u0A3E-\\u0A42\\u0A47\\u0A48\\u0A4B-\\u0A4D\\u0A51\\u0A59-\\u0A5C\\u0A5E\\u0A66-\\u0A75\\u0A81-\\u0A83\\u0A85-\\u0A8D\\u0A8F-\\u0A91\\u0A93-\\u0A
A8\\u0AAA-\\u0AB0\\u0AB2\\u0AB3\\u0AB5-\\u0AB9\\u0ABC-\\u0AC5\\u0AC7-\\u0AC9\\u0ACB-\\u0ACD\\u0AD0\\u0AE0-\\u0AE3\\u0AE6-\\u0AEF\\u0B01-\\u0B03\\u0B05-\\u0B0C\\u0B0F\\u0B10\\u0B13-\\u0B28\\u0B2A-\\u0B30\\u0B32\\u0B33\\u0B35-\\u0B39\\u0B3C-\\u0B44\\u0B47\\u0B48\\u0B4B-\\u0B4D\\u0B56\\u0B57\\u0B5C\\u0B5D\\u0B5F-\\u0B63\\u0B66-\\u0B6F\\u0B71\\u0B82\\u0B83\\u0B85-\\u0B8A\\u0B8E-\\u0B90\\u0B92-\\u0B95\\u0B99\\u0B9A\\u0B9C\\u0B9E\\u0B9F\\u0BA3\\u0BA4\\u0BA8-\\u0BAA\\u0BAE-\\u0BB9\\u0BBE-\\u0BC2\\u0BC6-\\u0BC8\\u0BCA-\\u0BCD\\u0BD0\\u0BD7\\u0BE6-\\u0BEF\\u0C00-\\u0C03\\u0C05-\\u0C0C\\u0C0E-\\u0C10\\u0C12-\\u0C28\\u0C2A-\\u0C39\\u0C3D-\\u0C44\\u0C46-\\u0C48\\u0C4A-\\u0C4D\\u0C55\\u0C56\\u0C58\\u0C59\\u0C60-\\u0C63\\u0C66-\\u0C6F\\u0C81-\\u0C83\\u0C85-\\u0C8C\\u0C8E-\\u0C90\\u0C92-\\u0CA8\\u0CAA-\\u0CB3\\u0CB5-\\u0CB9\\u0CBC-\\u0CC4\\u0CC6-\\u0CC8\\u0CCA-\\u0CCD\\u0CD5\\u0CD6\\u0CDE\\u0CE0-\\u0CE3\\u0CE6-\\u0CEF\\u0CF1\\u0CF2\\u0D01-\\u0D03\\u0D05-\\u0D0C\\u0D0E-\\u0D10\\u0D12-\\u0D3A\\u0D3D-\\u0D44\\u0D46-\\u0D48\\u0D4A-\\u0D4E\\u0D57\\u0D60-\\u0D63\\u0D66-\\u0D6F\\u0D7A-\\u0D7F\\u0D82\\u0D83\\u0D85-\\u0D96\\u0D9A-\\u0DB1\\u0DB3-\\u0DBB\\u0DBD\\u0DC0-\\u0DC6\\u0DCA\\u0DCF-\\u0DD4\\u0DD6\\u0DD8-\\u0DDF\\u0DE6-\\u0DEF\\u0DF2\\u0DF3\\u0E01-\\u0E3A\\u0E40-\\u0E4E\\u0E50-\\u0E59\\u0E81\\u0E82\\u0E84\\u0E87\\u0E88\\u0E8A\\u0E8D\\u0E94-\\u0E97\\u0E99-\\u0E9F\\u0EA1-\\u0EA3\\u0EA5\\u0EA7\\u0EAA\\u0EAB\\u0EAD-\\u0EB9\\u0EBB-\\u0EBD\\u0EC0-\\u0EC4\\u0EC6\\u0EC8-\\u0ECD\\u0ED0-\\u0ED9\\u0EDC-\\u0EDF\\u0F00\\u0F18\\u0F19\\u0F20-\\u0F29\\u0F35\\u0F37\\u0F39\\u0F3E-\\u0F47\\u0F49-\\u0F6C\\u0F71-\\u0F84\\u0F86-\\u0F97\\u0F99-\\u0FBC\\u0FC6\\u1000-\\u1049\\u1050-\\u109D\\u10A0-\\u10C5\\u10C7\\u10CD\\u10D0-\\u10FA\\u10FC-\\u1248\\u124A-\\u124D\\u1250-\\u1256\\u1258\\u125A-\\u125D\\u1260-\\u1288\\u128A-\\u128D\\u1290-\\u12B0\\u12B2-\\u12B5\\u12B8-\\u12BE\\u12C0\\u12C2-\\u12C5\\u12C8-\\u12D6\\u12D8-\\u1310\\u1312-\\u1315\\u1318-\\u135A\\u135D-\\u135F\\u1380-\\u138F\\u13A0-\\u13F4\\u1401-\\u16
6C\\u166F-\\u167F\\u1681-\\u169A\\u16A0-\\u16EA\\u16EE-\\u16F8\\u1700-\\u170C\\u170E-\\u1714\\u1720-\\u1734\\u1740-\\u1753\\u1760-\\u176C\\u176E-\\u1770\\u1772\\u1773\\u1780-\\u17D3\\u17D7\\u17DC\\u17DD\\u17E0-\\u17E9\\u180B-\\u180D\\u1810-\\u1819\\u1820-\\u1877\\u1880-\\u18AA\\u18B0-\\u18F5\\u1900-\\u191E\\u1920-\\u192B\\u1930-\\u193B\\u1946-\\u196D\\u1970-\\u1974\\u1980-\\u19AB\\u19B0-\\u19C9\\u19D0-\\u19D9\\u1A00-\\u1A1B\\u1A20-\\u1A5E\\u1A60-\\u1A7C\\u1A7F-\\u1A89\\u1A90-\\u1A99\\u1AA7\\u1AB0-\\u1ABD\\u1B00-\\u1B4B\\u1B50-\\u1B59\\u1B6B-\\u1B73\\u1B80-\\u1BF3\\u1C00-\\u1C37\\u1C40-\\u1C49\\u1C4D-\\u1C7D\\u1CD0-\\u1CD2\\u1CD4-\\u1CF6\\u1CF8\\u1CF9\\u1D00-\\u1DF5\\u1DFC-\\u1F15\\u1F18-\\u1F1D\\u1F20-\\u1F45\\u1F48-\\u1F4D\\u1F50-\\u1F57\\u1F59\\u1F5B\\u1F5D\\u1F5F-\\u1F7D\\u1F80-\\u1FB4\\u1FB6-\\u1FBC\\u1FBE\\u1FC2-\\u1FC4\\u1FC6-\\u1FCC\\u1FD0-\\u1FD3\\u1FD6-\\u1FDB\\u1FE0-\\u1FEC\\u1FF2-\\u1FF4\\u1FF6-\\u1FFC\\u200C\\u200D\\u203F\\u2040\\u2054\\u2071\\u207F\\u2090-\\u209C\\u20D0-\\u20DC\\u20E1\\u20E5-\\u20F0\\u2102\\u2107\\u210A-\\u2113\\u2115\\u2119-\\u211D\\u2124\\u2126\\u2128\\u212A-\\u212D\\u212F-\\u2139\\u213C-\\u213F\\u2145-\\u2149\\u214E\\u2160-\\u2188\\u2C00-\\u2C2E\\u2C30-\\u2C5E\\u2C60-\\u2CE4\\u2CEB-\\u2CF3\\u2D00-\\u2D25\\u2D27\\u2D2D\\u2D30-\\u2D67\\u2D6F\\u2D7F-\\u2D96\\u2DA0-\\u2DA6\\u2DA8-\\u2DAE\\u2DB0-\\u2DB6\\u2DB8-\\u2DBE\\u2DC0-\\u2DC6\\u2DC8-\\u2DCE\\u2DD0-\\u2DD6\\u2DD8-\\u2DDE\\u2DE0-\\u2DFF\\u2E2F\\u3005-\\u3007\\u3021-\\u302F\\u3031-\\u3035\\u3038-\\u303C\\u3041-\\u3096\\u3099\\u309A\\u309D-\\u309F\\u30A1-\\u30FA\\u30FC-\\u30FF\\u3105-\\u312D\\u3131-\\u318E\\u31A0-\\u31BA\\u31F0-\\u31FF\\u3400-\\u4DB5\\u4E00-\\u9FCC\\uA000-\\uA48C\\uA4D0-\\uA4FD\\uA500-\\uA60C\\uA610-\\uA62B\\uA640-\\uA66F\\uA674-\\uA67D\\uA67F-\\uA69D\\uA69F-\\uA6F1\\uA717-\\uA71F\\uA722-\\uA788\\uA78B-\\uA78E\\uA790-\\uA7AD\\uA7B0\\uA7B1\\uA7F7-\\uA827\\uA840-\\uA873\\uA880-\\uA8C4\\uA8D0-\\uA8D9\\uA8E0-\\uA8F7\\uA8FB\\uA900-\\uA92D\\uA930-\\uA953\\uA960-\\uA97C\\uA98
0-\\uA9C0\\uA9CF-\\uA9D9\\uA9E0-\\uA9FE\\uAA00-\\uAA36\\uAA40-\\uAA4D\\uAA50-\\uAA59\\uAA60-\\uAA76\\uAA7A-\\uAAC2\\uAADB-\\uAADD\\uAAE0-\\uAAEF\\uAAF2-\\uAAF6\\uAB01-\\uAB06\\uAB09-\\uAB0E\\uAB11-\\uAB16\\uAB20-\\uAB26\\uAB28-\\uAB2E\\uAB30-\\uAB5A\\uAB5C-\\uAB5F\\uAB64\\uAB65\\uABC0-\\uABEA\\uABEC\\uABED\\uABF0-\\uABF9\\uAC00-\\uD7A3\\uD7B0-\\uD7C6\\uD7CB-\\uD7FB\\uF900-\\uFA6D\\uFA70-\\uFAD9\\uFB00-\\uFB06\\uFB13-\\uFB17\\uFB1D-\\uFB28\\uFB2A-\\uFB36\\uFB38-\\uFB3C\\uFB3E\\uFB40\\uFB41\\uFB43\\uFB44\\uFB46-\\uFBB1\\uFBD3-\\uFD3D\\uFD50-\\uFD8F\\uFD92-\\uFDC7\\uFDF0-\\uFDFB\\uFE00-\\uFE0F\\uFE20-\\uFE2D\\uFE33\\uFE34\\uFE4D-\\uFE4F\\uFE70-\\uFE74\\uFE76-\\uFEFC\\uFF10-\\uFF19\\uFF21-\\uFF3A\\uFF3F\\uFF41-\\uFF5A\\uFF66-\\uFFBE\\uFFC2-\\uFFC7\\uFFCA-\\uFFCF\\uFFD2-\\uFFD7\\uFFDA-\\uFFDC]");function y$(t,e){if(!t)throw new Error("ASSERT: "+e)}function v$(t){return t>=48&&t<=57}function _$(t){return"0123456789abcdefABCDEF".includes(t)}function x$(t){return"01234567".includes(t)}function b$(t){return 32===t||9===t||11===t||12===t||160===t||t>=5760&&[5760,6158,8192,8193,8194,8195,8196,8197,8198,8199,8200,8201,8202,8239,8287,12288,65279].includes(t)}function w$(t){return 10===t||13===t||8232===t||8233===t}function k$(t){return 36===t||95===t||t>=65&&t<=90||t>=97&&t<=122||92===t||t>=128&&g$.test(String.fromCharCode(t))}function A$(t){return 36===t||95===t||t>=65&&t<=90||t>=97&&t<=122||t>=48&&t<=57||92===t||t>=128&&m$.test(String.fromCharCode(t))}const M$={if:1,in:1,do:1,var:1,for:1,new:1,try:1,let:1,this:1,else:1,case:1,void:1,with:1,enum:1,while:1,break:1,catch:1,throw:1,const:1,yield:1,class:1,super:1,return:1,typeof:1,delete:1,switch:1,export:1,import:1,public:1,static:1,default:1,finally:1,extends:1,package:1,private:1,function:1,continue:1,debugger:1,interface:1,protected:1,instanceof:1,implements:1};function E$(){for(;NS1114111||"}"!==t)&&I$({},r$,d$),e<=65535?String.fromCharCode(e):(n=55296+(e-65536>>10),r=56320+(e-65536&1023),String.fromCharCode(n,r))}function 
F$(){var t,e;for(t=zS.charCodeAt(NS++),e=String.fromCharCode(t),92===t&&(117!==zS.charCodeAt(NS)&&I$({},r$,d$),++NS,(t=D$("u"))&&"\\"!==t&&k$(t.charCodeAt(0))||I$({},r$,d$),e=t);NS>>="===(r=zS.substr(NS,4))?{type:WS,value:r,start:i,end:NS+=4}:">>>"===(n=r.substr(0,3))||"<<="===n||">>="===n?{type:WS,value:n,start:i,end:NS+=3}:a===(e=n.substr(0,2))[1]&&"+-<>&|".includes(a)||"=>"===e?{type:WS,value:e,start:i,end:NS+=2}:("//"===e&&I$({},r$,d$),"<>=!+-*%&|^/".includes(a)?(++NS,{type:WS,value:a,start:i,end:NS}):void I$({},r$,d$))}function T$(){var t,e,n;if(y$(v$((n=zS[NS]).charCodeAt(0))||"."===n,"Numeric literal must start with a decimal digit or a decimal point"),e=NS,t="","."!==n){if(t=zS[NS++],n=zS[NS],"0"===t){if("x"===n||"X"===n)return++NS,function(t){let e="";for(;NS=0&&I$({},l$,n),{value:n,literal:e}}(),r=function(t,e){let n=t;e.includes("u")&&(n=n.replace(/\\u\{([0-9a-fA-F]+)\}/g,((t,e)=>{if(parseInt(e,16)<=1114111)return"x";I$({},l$)})).replace(/[\uD800-\uDBFF][\uDC00-\uDFFF]/g,"x"));try{new RegExp(n)}catch(t){I$({},l$)}try{return new RegExp(t,e)}catch(t){return null}}(e.value,n.value),{literal:e.literal+n.literal,value:r,regex:{pattern:e.value,flags:n.value},start:t,end:NS}}function z$(){if(E$(),NS>=OS)return{type:LS,start:NS,end:NS};const t=zS.charCodeAt(NS);return k$(t)?S$():40===t||41===t||59===t?$$():39===t||34===t?function(){var t,e,n,r,i="",o=!1;for(y$("'"===(t=zS[NS])||'"'===t,"String literal must starts with a quote"),e=NS,++NS;NS(y$(e":case"<=":case">=":case"instanceof":case"in":e=7;break;case"<<":case">>":case">>>":e=8;break;case"+":case"-":e=9;break;case"*":case"/":case"%":e=11}return e}function aT(){var t,e;return t=function(){var t,e,n,r,i,o,a,s,u,l;if(t=RS,u=iT(),0===(i=oT(r=RS)))return 
u;for(r.prec=i,N$(),e=[t,RS],o=[u,r,a=iT()];(i=oT(RS))>0;){for(;o.length>2&&i<=o[o.length-2].prec;)a=o.pop(),s=o.pop().value,u=o.pop(),e.pop(),n=R$(s,u,a),o.push(n);(r=N$()).prec=i,o.push(r),e.push(RS),n=iT(),o.push(n)}for(n=o[l=o.length-1],e.pop();l>1;)e.pop(),n=R$(o[l-1].value,o[l-2],n),l-=2;return n}(),Y$("?")&&(N$(),e=aT(),H$(":"),t=function(t,e,n){const r=new TS(XS);return r.test=t,r.consequent=e,r.alternate=n,r}(t,e,aT())),t}function sT(){const t=aT();if(Y$(","))throw new Error(p$);return t}function uT(t){NS=0,OS=(zS=t).length,RS=null,O$();const e=sT();if(RS.type!==LS)throw new Error("Unexpect token after expression.");return e}var lT={NaN:"NaN",E:"Math.E",LN2:"Math.LN2",LN10:"Math.LN10",LOG2E:"Math.LOG2E",LOG10E:"Math.LOG10E",PI:"Math.PI",SQRT1_2:"Math.SQRT1_2",SQRT2:"Math.SQRT2",MIN_VALUE:"Number.MIN_VALUE",MAX_VALUE:"Number.MAX_VALUE"};function cT(t){function e(e,n,r){return i=>function(e,n,r,i){let o=t(n[0]);return r&&(o=r+"("+o+")",0===r.lastIndexOf("new ",0)&&(o="("+o+")")),o+"."+e+(i<0?"":0===i?"()":"("+n.slice(1).map(t).join(",")+")")}(e,i,n,r)}const n="new Date",r="String",i="RegExp";return{isNaN:"Number.isNaN",isFinite:"Number.isFinite",abs:"Math.abs",acos:"Math.acos",asin:"Math.asin",atan:"Math.atan",atan2:"Math.atan2",ceil:"Math.ceil",cos:"Math.cos",exp:"Math.exp",floor:"Math.floor",hypot:"Math.hypot",log:"Math.log",max:"Math.max",min:"Math.min",pow:"Math.pow",random:"Math.random",round:"Math.round",sin:"Math.sin",sqrt:"Math.sqrt",tan:"Math.tan",clamp:function(e){e.length<3&&s("Missing arguments to clamp function."),e.length>3&&s("Too many arguments to clamp function.");const n=e.map(t);return"Math.max("+n[1]+", 
Math.min("+n[2]+","+n[0]+"))"},now:"Date.now",utc:"Date.UTC",datetime:n,date:e("getDate",n,0),day:e("getDay",n,0),year:e("getFullYear",n,0),month:e("getMonth",n,0),hours:e("getHours",n,0),minutes:e("getMinutes",n,0),seconds:e("getSeconds",n,0),milliseconds:e("getMilliseconds",n,0),time:e("getTime",n,0),timezoneoffset:e("getTimezoneOffset",n,0),utcdate:e("getUTCDate",n,0),utcday:e("getUTCDay",n,0),utcyear:e("getUTCFullYear",n,0),utcmonth:e("getUTCMonth",n,0),utchours:e("getUTCHours",n,0),utcminutes:e("getUTCMinutes",n,0),utcseconds:e("getUTCSeconds",n,0),utcmilliseconds:e("getUTCMilliseconds",n,0),length:e("length",null,-1),parseFloat:"parseFloat",parseInt:"parseInt",upper:e("toUpperCase",r,0),lower:e("toLowerCase",r,0),substring:e("substring",r),split:e("split",r),trim:e("trim",r,0),regexp:i,test:e("test",i),if:function(e){e.length<3&&s("Missing arguments to if function."),e.length>3&&s("Too many arguments to if function.");const n=e.map(t);return"("+n[0]+"?"+n[1]+":"+n[2]+")"}}}function fT(t){const e=(t=t||{}).allowed?Bt(t.allowed):{},n=t.forbidden?Bt(t.forbidden):{},r=t.constants||lT,i=(t.functions||cT)(h),o=t.globalvar,a=t.fieldvar,u=J(o)?o:t=>`${o}["${t}"]`;let l={},c={},f=0;function h(t){if(xt(t))return t;const e=d[t.type];return null==e&&s("Unsupported type: "+t.type),e(t)}const d={Literal:t=>t.raw,Identifier:t=>{const i=t.name;return f>0?i:lt(n,i)?s("Illegal identifier: "+i):lt(r,i)?r[i]:lt(e,i)?i:(l[i]=1,u(i))},MemberExpression:t=>{const e=!t.computed,n=h(t.object);e&&(f+=1);const r=h(t.property);return n===a&&(c[function(t){const e=t&&t.length-1;return e&&('"'===t[0]&&'"'===t[e]||"'"===t[0]&&"'"===t[e])?t.slice(1,-1):t}(r)]=1),e&&(f-=1),n+(e?"."+r:"["+r+"]")},CallExpression:t=>{"Identifier"!==t.callee.type&&s("Illegal callee type: "+t.callee.type);const e=t.callee.name,n=t.arguments,r=lt(i,e)&&i[e];return r||s("Unrecognized function: 
"+e),J(r)?r(n):r+"("+n.map(h).join(",")+")"},ArrayExpression:t=>"["+t.elements.map(h).join(",")+"]",BinaryExpression:t=>"("+h(t.left)+" "+t.operator+" "+h(t.right)+")",UnaryExpression:t=>"("+t.operator+h(t.argument)+")",ConditionalExpression:t=>"("+h(t.test)+"?"+h(t.consequent)+":"+h(t.alternate)+")",LogicalExpression:t=>"("+h(t.left)+t.operator+h(t.right)+")",ObjectExpression:t=>"{"+t.properties.map(h).join(",")+"}",Property:t=>{f+=1;const e=h(t.key);return f-=1,e+":"+h(t.value)}};function p(t){const e={code:h(t),globals:Object.keys(l),fields:Object.keys(c)};return l={},c={},e}return p.functions=i,p.constants=r,p}const hT=Symbol("vega_selection_getter");function dT(t){return t.getter&&t.getter[hT]||(t.getter=l(t.field),t.getter[hT]=!0),t.getter}const pT="intersect",gT="union",mT="_vgsid_",yT=l(mT),vT="E",_T="R",xT="R-E",bT="R-LE",wT="R-RE",kT="index:unit";function AT(t,e){for(var n,r,i=e.fields,o=e.values,a=i.length,s=0;s1?e-1:0),r=1;re.includes(t))):e},R_union:function(t,e){var n=S(e[0]),r=S(e[1]);return n>r&&(n=e[1],r=e[0]),t.length?(t[0]>n&&(t[0]=n),t[1]r&&(n=e[1],r=e[0]),t.length?rr&&(t[1]=r),t):[n,r]}};function FT(t,e,n,r){e[0].type!==wS&&s("First argument to selection functions must be a string literal.");const i=e[0].value,o="unit",a="@"+o,u=":"+i;(e.length>=2&&F(e).value)!==pT||lt(r,a)||(r[a]=n.getData(i).indataRef(n,o)),lt(r,u)||(r[u]=n.getData(i).tuplesRef())}function ST(t){const e=this.context.data[t];return e?e.values.value:[]}const $T=t=>function(e,n){return this.context.dataflow.locale()[t](n)(e)},TT=$T("format"),BT=$T("timeFormat"),zT=$T("utcFormat"),NT=$T("timeParse"),OT=$T("utcParse"),RT=new Date(2e3,0,1);function UT(t,e,n){return Number.isInteger(t)&&Number.isInteger(e)?(RT.setYear(2e3),RT.setMonth(t),RT.setDate(e),BT.call(this,RT,n)):""}const LT="%",qT="$";function PT(t,e,n,r){e[0].type!==wS&&s("First argument to data functions must be a string literal.");const 
i=e[0].value,o=":"+i;if(!lt(o,r))try{r[o]=n.getData(i).tuplesRef()}catch(t){}}function jT(t,e,n,r){if(e[0].type===wS)IT(n,r,e[0].value);else for(t in n.scales)IT(n,r,t)}function IT(t,e,n){const r=LT+n;if(!lt(e,r))try{e[r]=t.scaleRef(n)}catch(t){}}function WT(t,e){if(J(t))return t;if(xt(t)){const n=e.scales[t];return n&&function(t){return t&&!0===t[Gd]}(n.value)?n.value:void 0}}function HT(t,e,n){e.__bandwidth=t=>t&&t.bandwidth?t.bandwidth():0,n._bandwidth=jT,n._range=jT,n._scale=jT;const r=e=>"_["+(e.type===wS?Ct(LT+e.value):Ct(LT)+"+"+t(e))+"]";return{_bandwidth:t=>`this.__bandwidth(${r(t[0])})`,_range:t=>`${r(t[0])}.range()`,_scale:e=>`${r(e[0])}(${t(e[1])})`}}function YT(t,e){return function(n,r,i){if(n){const e=WT(n,(i||this).context);return e&&e.path[t](r)}return e(r)}}const GT=YT("area",(function(t){return bw=new se,rw(t,ww),2*bw})),VT=YT("bounds",(function(t){var e,n,r,i,o,a,s;if(hw=fw=-(lw=cw=1/0),vw=[],rw(t,Jw),n=vw.length){for(vw.sort(ok),e=1,o=[r=vw[0]];eik(r[0],r[1])&&(r[1]=i[1]),ik(i[0],r[1])>ik(r[0],r[1])&&(r[0]=i[0])):o.push(r=i);for(a=-1/0,e=0,r=o[n=o.length-1];e<=n;r=i,++e)i=o[e],(s=ik(r[1],i[0]))>a&&(a=s,lw=i[0],fw=r[1])}return vw=_w=null,lw===1/0||cw===1/0?[[NaN,NaN],[NaN,NaN]]:[[lw,cw],[fw,hw]]})),XT=YT("centroid",(function(t){zw=Nw=Ow=Rw=Uw=Lw=qw=Pw=0,jw=new se,Iw=new se,Ww=new se,rw(t,sk);var e=+jw,n=+Iw,r=+Ww,i=jb(e,n,r);return itB(t,e)}const nB={};function rB(t){return k(t)||ArrayBuffer.isView(t)?t:null}function iB(t){return rB(t)||(xt(t)?t:null)}const oB=t=>t.data;function aB(t,e){const n=ST.call(e,t);return n.root&&n.root.lookup||{}}const sB=()=>"undefined"!=typeof window&&window||null;function uB(t,e,n){if(!t)return[];const[r,i]=t,o=(new Rg).set(r[0],r[1],i[0],i[1]);return w_(n||this.context.dataflow.scenegraph().root,o,function(t){let e=null;if(t){const n=V(t.marktype),r=V(t.markname);e=t=>(!n.length||n.some((e=>t.marktype===e)))&&(!r.length||r.some((e=>t.name===e)))}return e}(e))}const 
lB={random:()=>t.random(),cumulativeNormal:hs,cumulativeLogNormal:vs,cumulativeUniform:As,densityNormal:fs,densityLogNormal:ys,densityUniform:ks,quantileNormal:ds,quantileLogNormal:_s,quantileUniform:Ms,sampleNormal:cs,sampleLogNormal:ms,sampleUniform:ws,isArray:k,isBoolean:gt,isDate:mt,isDefined:t=>void 0!==t,isNumber:vt,isObject:A,isRegExp:_t,isString:xt,isTuple:ma,isValid:t=>null!=t&&t==t,toBoolean:Ft,toDate:t=>$t(t),toNumber:S,toString:Tt,indexof:function(t){for(var e=arguments.length,n=new Array(e>1?e-1:0),r=1;r1?e-1:0),r=1;r1?e-1:0),r=1;r1?e-1:0),r=1;rat(t),inScope:function(t){const e=this.context.group;let n=!1;if(e)for(;t;){if(t===e){n=!0;break}t=t.mark.group}return n},intersect:uB,clampRange:X,pinchDistance:function(t){const e=t.touches,n=e[0].clientX-e[1].clientX,r=e[0].clientY-e[1].clientY;return Math.hypot(n,r)},pinchAngle:function(t){const e=t.touches;return Math.atan2(e[0].clientY-e[1].clientY,e[0].clientX-e[1].clientX)},screen:function(){const t=sB();return t?t.screen:{}},containerSize:function(){const t=this.context.dataflow,e=t.container&&t.container();return e?[e.clientWidth,e.clientHeight]:[void 0,void 0]},windowSize:function(){const t=sB();return t?[t.innerWidth,t.innerHeight]:[void 0,void 0]},bandspace:function(t,e,n){return xd(t||0,e||0,n||0)},setdata:function(t,e){const n=this.context.dataflow,r=this.context.data[t].input;return n.pulse(r,n.changeset().remove(p).insert(e)),1},pathShape:function(t){let e=null;return function(n){return n?ag(n,e=e||Xp(t)):t}},panLinear:R,panLog:U,panPow:L,panSymlog:q,zoomLinear:j,zoomLog:I,zoomPow:W,zoomSymlog:H,encode:function(t,e,n){if(t){const n=this.context.dataflow,r=t.mark.source;n.pulse(r,n.changeset().encode(t,e))}return void 0!==n?n:t},modify:function(t,e,n,r,i,o){const a=this.context.dataflow,s=this.context.data[t],u=s.input,l=a.stamp();let c,f,h=s.changes;if(!1===a._trigger||!(u.value.length||e||r))return 
0;if((!h||h.stamp{s.modified=!0,a.pulse(u,h).run()}),!0,1)),n&&(c=!0===n?p:k(n)||ma(n)?n:eB(n),h.remove(c)),e&&h.insert(e),r&&(c=eB(r),u.value.some(c)?h.remove(c):h.insert(r)),i)for(f in o)h.modify(i,f,o[f]);return 1},lassoAppend:function(t,e,n){let r=arguments.length>3&&void 0!==arguments[3]?arguments[3]:5;const i=(t=V(t))[t.length-1];return void 0===i||Math.hypot(i[0]-e,i[1]-n)>r?[...t,[e,n]]:t},lassoPath:function(t){return V(t).reduce(((e,n,r)=>{let[i,o]=n;return e+(0==r?`M ${i},${o} `:r===t.length-1?" Z":`L ${i},${o} `)}),"")},intersectLasso:function(t,e,n){const{x:r,y:i,mark:o}=n,a=(new Rg).set(Number.MAX_SAFE_INTEGER,Number.MAX_SAFE_INTEGER,Number.MIN_SAFE_INTEGER,Number.MIN_SAFE_INTEGER);for(const[t,n]of e)ta.x2&&(a.x2=t),na.y2&&(a.y2=n);return a.translate(r,i),uB([[a.x1,a.y1],[a.x2,a.y2]],t,o).filter((t=>function(t,e,n){let r=0;for(let i=0,o=n.length-1;ie!=s>e&&t<(a-u)*(e-l)/(s-l)+u&&r++}return 1&r}(t.x,t.y,e)))}},cB=["view","item","group","xy","x","y"],fB="this.",hB={},dB={forbidden:["_"],allowed:["datum","event","item"],fieldvar:"datum",globalvar:t=>`_[${Ct(qT+t)}]`,functions:function(t){const e=cT(t);cB.forEach((t=>e[t]="event.vega."+t));for(const t in lB)e[t]=fB+t;return ot(e,HT(t,lB,hB)),e},constants:lT,visitors:hB},pB=fT(dB);function gB(t,e,n){return 1===arguments.length?lB[t]:(lB[t]=e,n&&(hB[t]=n),pB&&(pB.functions[t]=fB+t),this)}function mB(t,e){const n={};let r;try{r=uT(t=xt(t)?t:Ct(t)+"")}catch(e){s("Expression parse error: "+t)}r.visit((t=>{if(t.type!==ES)return;const r=t.callee.name,i=dB.visitors[r];i&&i(r,t.arguments,e,n)}));const i=pB(r);return i.globals.forEach((t=>{const r=qT+t;!lt(n,r)&&e.getSignal(t)&&(n[r]=e.signalRef(t))})),{$expr:ot({code:i.code},e.options.ast?{ast:r}:null),$fields:i.fields,$params:n}}gB("bandwidth",(function(t,e){const n=WT(t,(e||this).context);return n&&n.bandwidth?n.bandwidth():0}),jT),gB("copy",(function(t,e){const n=WT(t,(e||this).context);return n?n.copy():void 0}),jT),gB("domain",(function(t,e){const 
n=WT(t,(e||this).context);return n?n.domain():[]}),jT),gB("range",(function(t,e){const n=WT(t,(e||this).context);return n&&n.range?n.range():[]}),jT),gB("invert",(function(t,e,n){const r=WT(t,(n||this).context);return r?k(e)?(r.invertRange||r.invert)(e):(r.invert||r.invertExtent)(e):void 0}),jT),gB("scale",(function(t,e,n){const r=WT(t,(n||this).context);return r?r(e):void 0}),jT),gB("gradient",(function(t,e,n,r,i){t=WT(t,(i||this).context);const o=Pp(e,n);let a=t.domain(),s=a[0],u=F(a),l=f;return u-s?l=up(t,s,u):t=(t.interpolator?Xd("sequential")().interpolator(t.interpolator()):Xd("linear")().interpolate(t.interpolate()).range(t.range())).domain([s=0,u=1]),t.ticks&&(a=t.ticks(+r||15),s!==a[0]&&a.unshift(s),u!==F(a)&&a.push(u)),a.forEach((e=>o.stop(l(e),t(e)))),o}),jT),gB("geoArea",GT,jT),gB("geoBounds",VT,jT),gB("geoCentroid",XT,jT),gB("geoShape",(function(t,e,n){const r=WT(t,(n||this).context);return function(t){return r?r.path.context(t)(e):""}}),jT),gB("geoScale",(function(t,e){const n=WT(t,(e||this).context);return n&&n.scale()}),jT),gB("indata",(function(t,e,n){const r=this.context.data[t]["index:"+e],i=r?r.value.get(n):void 0;return i?i.count:i}),(function(t,e,n,r){e[0].type!==wS&&s("First argument to indata must be a string literal."),e[1].type!==wS&&s("Second argument to indata must be a string literal.");const i=e[0].value,o=e[1].value,a="@"+o;lt(a,r)||(r[a]=n.getData(i).indataRef(n,o))})),gB("data",ST,PT),gB("treePath",(function(t,e,n){const r=aB(t,this),i=r[e],o=r[n];return i&&o?i.path(o).map(oB):void 0}),PT),gB("treeAncestors",(function(t,e){const n=aB(t,this)[e];return n?n.ancestors().map(oB):void 0}),PT),gB("vlSelectionTest",(function(t,e,n){for(var r,i,o,a,s,u=this.context.data[t],l=u?u.values.value:[],c=u?u[kT]&&u[kT].value:void 0,f=n===pT,h=l.length,d=0;d(t[o[n].field]=e,t)),{}))}else 
u=mT,l=yT(i),(f=(c=v[u]||(v[u]={}))[s]||(c[s]=[])).push(l),n&&(f=_[s]||(_[s]=[])).push({[mT]:l});if(e=e||gT,v[mT]?v[mT]=CT[`${mT}_${e}`](...Object.values(v[mT])):Object.keys(v).forEach((t=>{v[t]=Object.keys(v[t]).map((e=>v[t][e])).reduce(((n,r)=>void 0===n?r:CT[`${x[t]}_${e}`](n,r)))})),y=Object.keys(_),n&&y.length){v[r?"vlPoint":"vlMulti"]=e===gT?{or:y.reduce(((t,e)=>(t.push(..._[e]),t)),[])}:{and:y.map((t=>({or:_[t]})))}}return v}),FT),gB("vlSelectionTuples",(function(t,e){return t.map((t=>ot(e.fields?{values:e.fields.map((e=>dT(e)(t.datum)))}:{[mT]:yT(t.datum)},e)))}));const yB=Bt(["rule"]),vB=Bt(["group","image","rect"]);function _B(t){return(t+"").toLowerCase()}function xB(t,e,n){n.endsWith(";")||(n="return("+n+");");const r=Function(...e.concat(n));return t&&t.functions?r.bind(t.functions):r}var bB={operator:(t,e)=>xB(t,["_"],e.code),parameter:(t,e)=>xB(t,["datum","_"],e.code),event:(t,e)=>xB(t,["event"],e.code),handler:(t,e)=>xB(t,["_","event"],`var datum=event.item&&event.item.datum;return ${e.code};`),encode:(t,e)=>{const{marktype:n,channels:r}=e;let i="var o=item,datum=o.datum,m=0,$;";for(const t in r){const e="o["+Ct(t)+"]";i+=`$=${r[t].code};if(${e}!==$)${e}=$,m=1;`}return i+=function(t,e){let n="";return yB[e]||(t.x2&&(t.x?(vB[e]&&(n+="if(o.x>o.x2)$=o.x,o.x=o.x2,o.x2=$;"),n+="o.width=o.x2-o.x;"):n+="o.x=o.x2-(o.width||0);"),t.xc&&(n+="o.x=o.xc-(o.width||0)/2;"),t.y2&&(t.y?(vB[e]&&(n+="if(o.y>o.y2)$=o.y,o.y=o.y2,o.y2=$;"),n+="o.height=o.y2-o.y;"):n+="o.y=o.y2-(o.height||0);"),t.yc&&(n+="o.y=o.yc-(o.height||0)/2;")),n}(r,n),i+="return m;",xB(t,["item","_"],i)},codegen:{get(t){const e=`[${t.map(Ct).join("][")}]`,n=Function("_",`return _${e};`);return n.path=e,n},comparator(t,e){let n;const r=Function("a","b","var u, v; return "+t.map(((t,r)=>{const i=e[r];let o,a;return t.path?(o=`a${t.path}`,a=`b${t.path}`):((n=n||{})["f"+r]=t,o=`this.f${r}(a)`,a=`this.f${r}(b)`),function(t,e,n,r){return`((u = ${t}) < (v = ${e}) || u == null) && v != null ? 
${n}\n : (u > v || v == null) && u != null ? ${r}\n : ((v = v instanceof Date ? +v : v), (u = u instanceof Date ? +u : u)) !== u && v === v ? ${n}\n : v !== v && u === u ? ${r} : `}(o,a,-i,i)})).join("")+"0;");return n?r.bind(n):r}}};function wB(t,e,n){if(!t||!A(t))return t;for(let r,i=0,o=kB.length;it&&t.$tupleid?ya:t));return e.fn[n]||(e.fn[n]=Q(r,t.$order,e.expr.codegen))}},{key:"$context",parse:function(t,e){return e}},{key:"$subflow",parse:function(t,e){const n=t.$subflow;return function(t,r,i){const o=e.fork().parse(n),a=o.get(n.operators[0].id),s=o.signals.parent;return s&&s.set(i),a.detachSubflow=()=>e.detach(o),a}}},{key:"$tupleid",parse:function(){return ya}}];const AB={skip:!0};function MB(t,e,n,r){return new EB(t,e,n,r)}function EB(t,e,n,r){this.dataflow=t,this.transforms=e,this.events=t.events.bind(t),this.expr=r||bB,this.signals={},this.scales={},this.nodes={},this.data={},this.fn={},n&&(this.functions=Object.create(n),this.functions.context=this)}function DB(t){this.dataflow=t.dataflow,this.transforms=t.transforms,this.events=t.events,this.expr=t.expr,this.signals=Object.create(t.signals),this.scales=Object.create(t.scales),this.nodes=Object.create(t.nodes),this.data=Object.create(t.data),this.fn=Object.create(t.fn),t.functions&&(this.functions=Object.create(t.functions),this.functions.context=this)}function CB(t,e){t&&(null==e?t.removeAttribute("aria-label"):t.setAttribute("aria-label",e))}EB.prototype=DB.prototype={fork(){const t=new DB(this);return(this.subcontext||(this.subcontext=[])).push(t),t},detach(t){this.subcontext=this.subcontext.filter((e=>e!==t));const e=Object.keys(t.nodes);for(const n of e)t.nodes[n]._targets=null;for(const n of e)t.nodes[n].detach();t.nodes=null},get(t){return this.nodes[t]},set(t,e){return this.nodes[t]=e},add(t,e){const 
n=this,r=n.dataflow,i=t.value;if(n.set(t.id,e),function(t){return"collect"===_B(t)}(t.type)&&i&&(i.$ingest?r.ingest(e,i.$ingest,i.$format):i.$request?r.preload(e,i.$request,i.$format):r.pulse(e,r.changeset().insert(i))),t.root&&(n.root=e),t.parent){let i=n.get(t.parent.$ref);i?(r.connect(i,[e]),e.targets().add(i)):(n.unresolved=n.unresolved||[]).push((()=>{i=n.get(t.parent.$ref),r.connect(i,[e]),e.targets().add(i)}))}if(t.signal&&(n.signals[t.signal]=e),t.scale&&(n.scales[t.scale]=e),t.data)for(const r in t.data){const i=n.data[r]||(n.data[r]={});t.data[r].forEach((t=>i[t]=e))}},resolve(){return(this.unresolved||[]).forEach((t=>t())),delete this.unresolved,this},operator(t,e){this.add(t,this.dataflow.add(t.value,e))},transform(t,e){this.add(t,this.dataflow.add(this.transforms[_B(e)]))},stream(t,e){this.set(t.id,e)},update(t,e,n,r,i){this.dataflow.on(e,n,r,i,t.options)},operatorExpression(t){return this.expr.operator(this,t)},parameterExpression(t){return this.expr.parameter(this,t)},eventExpression(t){return this.expr.event(this,t)},handlerExpression(t){return this.expr.handler(this,t)},encodeExpression(t){return this.expr.encode(this,t)},parse:function(t){const e=this,n=t.operators||[];return t.background&&(e.background=t.background),t.eventConfig&&(e.eventConfig=t.eventConfig),t.locale&&(e.locale=t.locale),n.forEach((t=>e.parseOperator(t))),n.forEach((t=>e.parseOperatorParameters(t))),(t.streams||[]).forEach((t=>e.parseStream(t))),(t.updates||[]).forEach((t=>e.parseUpdate(t))),e.resolve()},parseOperator:function(t){const e=this;!function(t){return"operator"===_B(t)}(t.type)&&t.type?e.transform(t,t.type):e.operator(t,t.update?e.operatorExpression(t.update):null)},parseOperatorParameters:function(t){const e=this;if(t.params){const n=e.get(t.id);n||s("Invalid operator id: "+t.id),e.dataflow.connect(n,n.parameters(e.parseParameters(t.params),t.react,t.initonly))}},parseParameters:function(t,e){e=e||{};const n=this;for(const r in t){const 
i=t[r];e[r]=k(i)?i.map((t=>wB(t,n,e))):wB(i,n,e)}return e},parseStream:function(t){var e,n=this,r=null!=t.filter?n.eventExpression(t.filter):void 0,i=null!=t.stream?n.get(t.stream):void 0;t.source?i=n.events(t.source,t.type,r):t.merge&&(i=(e=t.merge.map((t=>n.get(t))))[0].merge.apply(e[0],e.slice(1))),t.between&&(e=t.between.map((t=>n.get(t))),i=i.between(e[0],e[1])),t.filter&&(i=i.filter(r)),null!=t.throttle&&(i=i.throttle(+t.throttle)),null!=t.debounce&&(i=i.debounce(+t.debounce)),null==i&&s("Invalid stream definition: "+JSON.stringify(t)),t.consume&&i.consume(!0),n.stream(t,i)},parseUpdate:function(t){var e,n=this,r=A(r=t.source)?r.$ref:r,i=n.get(r),o=t.update,a=void 0;i||s("Source not defined: "+t.source),e=t.target&&t.target.$expr?n.eventExpression(t.target.$expr):n.get(t.target),o&&o.$expr&&(o.$params&&(a=n.parseParameters(o.$params)),o=n.handlerExpression(o.$expr)),n.update(t,i,e,o,a)},getState:function(t){var e=this,n={};if(t.signals){var r=n.signals={};Object.keys(e.signals).forEach((n=>{const i=e.signals[n];t.signals(n,i)&&(r[n]=i.value)}))}if(t.data){var i=n.data={};Object.keys(e.data).forEach((n=>{const r=e.data[n];t.data(n,r)&&(i[n]=r.input.value)}))}return e.subcontext&&!1!==t.recurse&&(n.subcontext=e.subcontext.map((e=>e.getState(t)))),n},setState:function(t){var e=this,n=e.dataflow,r=t.data,i=t.signals;Object.keys(i||{}).forEach((t=>{n.update(e.signals[t],i[t],AB)})),Object.keys(r||{}).forEach((t=>{n.pulse(e.data[t].input,n.changeset().remove(p).insert(r[t]))})),(t.subcontext||[]).forEach(((t,n)=>{const r=e.subcontext[n];r&&r.setState(t)}))}};const FB="default";function SB(t,e){const n=t.globalCursor()?"undefined"!=typeof document&&document.body:t.container();if(n)return null==e?n.style.removeProperty("cursor"):n.style.cursor=e}function $B(t,e){var n=t._runtime.data;return lt(n,e)||s("Unrecognized data set: "+e),n[e]}function TB(t,e){Aa(e)||s("Second argument to changes must be a changeset.");const n=$B(this,t);return 
n.modified=!0,this.pulse(n.input,e)}function BB(t){var e=t.padding();return Math.max(0,t._viewWidth+e.left+e.right)}function zB(t){var e=t.padding();return Math.max(0,t._viewHeight+e.top+e.bottom)}function NB(t){var e=t.padding(),n=t._origin;return[e.left+n[0],e.top+n[1]]}function OB(t,e,n){var r,i,o=t._renderer,a=o&&o.canvas();return a&&(i=NB(t),(r=Xy(e.changedTouches?e.changedTouches[0]:e,a))[0]-=i[0],r[1]-=i[1]),e.dataflow=t,e.item=n,e.vega=function(t,e,n){const r=e?"group"===e.mark.marktype?e:e.mark.group:null;function i(t){var n,i=r;if(t)for(n=e;n;n=n.mark.group)if(n.mark.name===t){i=n;break}return i&&i.mark&&i.mark.interactive?i:{}}function o(t){if(!t)return n;xt(t)&&(t=i(t));const e=n.slice();for(;t;)e[0]-=t.x||0,e[1]-=t.y||0,t=t.mark&&t.mark.group;return e}return{view:rt(t),item:rt(e||{}),group:i,xy:o,x:t=>o(t)[0],y:t=>o(t)[1]}}(t,n,r),e}const RB="view",UB={trap:!1};function LB(t,e,n,r){t._eventListeners.push({type:n,sources:V(e),handler:r})}function qB(t,e,n){const r=t._eventConfig&&t._eventConfig[e];return!(!1===r||A(r)&&!r[n])||(t.warn(`Blocked ${e} ${n} event listener.`),!1)}function PB(t){return t.item}function jB(t){return t.item.mark.source}function IB(t){return function(e,n){return n.vega.view().changeset().encode(n.item,t)}}function WB(t,e,n){const r=document.createElement(t);for(const t in e)r.setAttribute(t,e[t]);return null!=n&&(r.textContent=n),r}const HB="vega-bind",YB="vega-bind-name",GB="vega-bind-radio";function VB(t,e,n,r){const i=n.event||"input",o=()=>t.update(e.value);r.signal(n.signal,e.value),e.addEventListener(i,o),LB(r,e,i,o),t.set=t=>{e.value=t,e.dispatchEvent(function(t){return"undefined"!=typeof Event?new Event(t):{type:t}}(i))}}function XB(t,e,n,r){const i=r.signal(n.signal),o=WB("div",{class:HB}),a="radio"===n.input?o:o.appendChild(WB("label"));a.appendChild(WB("span",{class:YB},n.name||n.signal)),e.appendChild(o);let 
s=JB;switch(n.input){case"checkbox":s=ZB;break;case"select":s=QB;break;case"radio":s=KB;break;case"range":s=tz}s(t,a,n,i)}function JB(t,e,n,r){const i=WB("input");for(const t in n)"signal"!==t&&"element"!==t&&i.setAttribute("input"===t?"type":t,n[t]);i.setAttribute("name",n.signal),i.value=r,e.appendChild(i),i.addEventListener("input",(()=>t.update(i.value))),t.elements=[i],t.set=t=>i.value=t}function ZB(t,e,n,r){const i={type:"checkbox",name:n.signal};r&&(i.checked=!0);const o=WB("input",i);e.appendChild(o),o.addEventListener("change",(()=>t.update(o.checked))),t.elements=[o],t.set=t=>o.checked=!!t||null}function QB(t,e,n,r){const i=WB("select",{name:n.signal}),o=n.labels||[];n.options.forEach(((t,e)=>{const n={value:t};ez(t,r)&&(n.selected=!0),i.appendChild(WB("option",n,(o[e]||t)+""))})),e.appendChild(i),i.addEventListener("change",(()=>{t.update(n.options[i.selectedIndex])})),t.elements=[i],t.set=t=>{for(let e=0,r=n.options.length;e{const s={type:"radio",name:n.signal,value:e};ez(e,r)&&(s.checked=!0);const u=WB("input",s);u.addEventListener("change",(()=>t.update(e)));const l=WB("label",{},(o[a]||e)+"");return l.prepend(u),i.appendChild(l),u})),t.set=e=>{const n=t.elements,r=n.length;for(let t=0;t{u.textContent=s.value,t.update(+s.value)};s.addEventListener("input",l),s.addEventListener("change",l),t.elements=[s],t.set=t=>{s.value=t,u.textContent=t}}function ez(t,e){return t===e||t+""==e+""}function nz(t,e,n,r,i,o){return(e=e||new r(t.loader())).initialize(n,BB(t),zB(t),NB(t),i,o).background(t.background())}function rz(t,e){return e?function(){try{e.apply(this,arguments)}catch(e){t.error(e)}}:null}function iz(t,e,n){if("string"==typeof e){if("undefined"==typeof document)return t.error("DOM document instance not found."),null;if(!(e=document.querySelector(e)))return t.error("Signal bind element not found: "+e),null}if(e&&n)try{e.textContent=""}catch(n){e=null,t.error(n)}return e}const oz=t=>+t||0;function az(t){return 
A(t)?{top:oz(t.top),bottom:oz(t.bottom),left:oz(t.left),right:oz(t.right)}:(t=>({top:t,bottom:t,left:t,right:t}))(oz(t))}async function sz(t,e,n,r){const i=b_(e),o=i&&i.headless;return o||s("Unrecognized renderer type: "+e),await t.runAsync(),nz(t,null,null,o,n,r).renderAsync(t._scenegraph.root)}var uz="width",lz="height",cz="padding",fz={skip:!0};function hz(t,e){var n=t.autosize(),r=t.padding();return e-(n&&n.contains===cz?r.left+r.right:0)}function dz(t,e){var n=t.autosize(),r=t.padding();return e-(n&&n.contains===cz?r.top+r.bottom:0)}function pz(t,e){return e.modified&&k(e.input.value)&&!t.startsWith("_:vega:_")}function gz(t,e){return!("parent"===t||e instanceof Za.proxy)}function mz(t,e,n,r){const i=t.element();i&&i.setAttribute("title",function(t){return null==t?"":k(t)?yz(t):A(t)&&!mt(t)?(e=t,Object.keys(e).map((t=>{const n=e[t];return t+": "+(k(n)?yz(n):vz(n))})).join("\n")):t+"";var e}(r))}function yz(t){return"["+t.map(vz).join(", ")+"]"}function vz(t){return k(t)?"[…]":A(t)&&!mt(t)?"{…}":t}function _z(t,e){const n=this;if(e=e||{},Va.call(n),e.loader&&n.loader(e.loader),e.logger&&n.logger(e.logger),null!=e.logLevel&&n.logLevel(e.logLevel),e.locale||t.locale){const r=ot({},t.locale,e.locale);n.locale(Ro(r.number,r.time))}n._el=null,n._elBind=null,n._renderType=e.renderer||__.Canvas,n._scenegraph=new jy;const r=n._scenegraph.root;n._renderer=null,n._tooltip=e.tooltip||mz,n._redraw=!0,n._handler=(new vv).scene(r),n._globalCursor=!1,n._preventDefault=!1,n._timers=[],n._eventListeners=[],n._resizeListeners=[],n._eventConfig=function(t){const e=ot({defaults:{}},t),n=(t,e)=>{e.forEach((e=>{k(t[e])&&(t[e]=Bt(t[e]))}))};return n(e.defaults,["prevent","allow"]),n(e,["view","window","selector"]),e}(t.eventConfig),n.globalCursor(n._eventConfig.globalCursor);const i=function(t,e,n){return 
MB(t,Za,lB,n).parse(e)}(n,t,e.expr);n._runtime=i,n._signals=i.signals,n._bind=(t.bindings||[]).map((t=>({state:null,param:ot({},t)}))),i.root&&i.root.set(r),r.source=i.data.root.input,n.pulse(i.data.root.input,n.changeset().insert(r.items)),n._width=n.width(),n._height=n.height(),n._viewWidth=hz(n,n._width),n._viewHeight=dz(n,n._height),n._origin=[0,0],n._resize=0,n._autosize=1,function(t){var e=t._signals,n=e[uz],r=e[lz],i=e[cz];function o(){t._autosize=t._resize=1}t._resizeWidth=t.add(null,(e=>{t._width=e.size,t._viewWidth=hz(t,e.size),o()}),{size:n}),t._resizeHeight=t.add(null,(e=>{t._height=e.size,t._viewHeight=dz(t,e.size),o()}),{size:r});const a=t.add(null,o,{pad:i});t._resizeWidth.rank=n.rank+1,t._resizeHeight.rank=r.rank+1,a.rank=i.rank+1}(n),function(t){t.add(null,(e=>(t._background=e.bg,t._resize=1,e.bg)),{bg:t._signals.background})}(n),function(t){const e=t._signals.cursor||(t._signals.cursor=t.add({user:FB,item:null}));t.on(t.events("view","pointermove"),e,((t,n)=>{const r=e.value,i=r?xt(r)?r:r.user:FB,o=n.item&&n.item.cursor||null;return r&&i===r.user&&o==r.item?r:{user:i,item:o}})),t.add(null,(function(e){let n=e.cursor,r=this.value;return xt(n)||(r=n.item,n=n.user),SB(t,n&&n!==FB?n:r||n),r}),{cursor:e})}(n),n.description(t.description),e.hover&&n.hover(),e.container&&n.initialize(e.container,e.bind),e.watchPixelRatio&&n._watchPixelRatio()}function xz(t,e){return lt(t._signals,e)?t._signals[e]:s("Unrecognized signal name: "+Ct(e))}function bz(t,e){const n=(t._targets||[]).filter((t=>t._update&&t._update.handler===e));return n.length?n[0]:null}function wz(t,e,n,r){let i=bz(n,r);return i||(i=rz(t,(()=>r(e,n.value))),i.handler=r,t.on(n,null,i)),t}function kz(t,e,n){const r=bz(e,n);return r&&e._targets.remove(r),t}dt(_z,Va,{async evaluate(t,e,n){if(await Va.prototype.evaluate.call(this,t,e),this._redraw||this._resize)try{this._renderer&&(this._resize&&(this._resize=0,function(t){var 
e=NB(t),n=BB(t),r=zB(t);t._renderer.background(t.background()),t._renderer.resize(n,r,e),t._handler.origin(e),t._resizeListeners.forEach((e=>{try{e(n,r)}catch(e){t.error(e)}}))}(this)),await this._renderer.renderAsync(this._scenegraph.root)),this._redraw=!1}catch(t){this.error(t)}return n&&da(this,n),this},dirty(t){this._redraw=!0,this._renderer&&this._renderer.dirty(t)},description(t){if(arguments.length){const e=null!=t?t+"":null;return e!==this._desc&&CB(this._el,this._desc=e),this}return this._desc},container(){return this._el},scenegraph(){return this._scenegraph},origin(){return this._origin.slice()},signal(t,e,n){const r=xz(this,t);return 1===arguments.length?r.value:this.update(r,e,n)},width(t){return arguments.length?this.signal("width",t):this.signal("width")},height(t){return arguments.length?this.signal("height",t):this.signal("height")},padding(t){return arguments.length?this.signal("padding",az(t)):az(this.signal("padding"))},autosize(t){return arguments.length?this.signal("autosize",t):this.signal("autosize")},background(t){return arguments.length?this.signal("background",t):this.signal("background")},renderer(t){return arguments.length?(b_(t)||s("Unrecognized renderer type: "+t),t!==this._renderType&&(this._renderType=t,this._resetRenderer()),this):this._renderType},tooltip(t){return arguments.length?(t!==this._tooltip&&(this._tooltip=t,this._resetRenderer()),this):this._tooltip},loader(t){return arguments.length?(t!==this._loader&&(Va.prototype.loader.call(this,t),this._resetRenderer()),this):this._loader},resize(){return this._autosize=1,this.touch(xz(this,"autosize"))},_resetRenderer(){this._renderer&&(this._renderer=null,this.initialize(this._el,this._elBind))},_resizeView:function(t,e,n,r,i,o){this.runAfter((a=>{let 
s=0;a._autosize=0,a.width()!==n&&(s=1,a.signal(uz,n,fz),a._resizeWidth.skip(!0)),a.height()!==r&&(s=1,a.signal(lz,r,fz),a._resizeHeight.skip(!0)),a._viewWidth!==t&&(a._resize=1,a._viewWidth=t),a._viewHeight!==e&&(a._resize=1,a._viewHeight=e),a._origin[0]===i[0]&&a._origin[1]===i[1]||(a._resize=1,a._origin=i),s&&a.run("enter"),o&&a.runAfter((t=>t.resize()))}),!1,1)},addEventListener(t,e,n){let r=e;return n&&!1===n.trap||(r=rz(this,e),r.raw=e),this._handler.on(t,r),this},removeEventListener(t,e){for(var n,r,i=this._handler.handlers(t),o=i.length;--o>=0;)if(r=i[o].type,n=i[o].handler,t===r&&(e===n||e===n.raw)){this._handler.off(r,n);break}return this},addResizeListener(t){const e=this._resizeListeners;return e.includes(t)||e.push(t),this},removeResizeListener(t){var e=this._resizeListeners,n=e.indexOf(t);return n>=0&&e.splice(n,1),this},addSignalListener(t,e){return wz(this,t,xz(this,t),e)},removeSignalListener(t,e){return kz(this,xz(this,t),e)},addDataListener(t,e){return wz(this,t,$B(this,t).values,e)},removeDataListener(t,e){return kz(this,$B(this,t).values,e)},globalCursor(t){if(arguments.length){if(this._globalCursor!==!!t){const e=SB(this,null);this._globalCursor=!!t,e&&SB(this,e)}return this}return this._globalCursor},preventDefault(t){return arguments.length?(this._preventDefault=t,this):this._preventDefault},timer:function(t,e){this._timers.push(function(t,e,n){var r=new QE,i=e;return null==e?(r.restart(t,e,n),r):(r._restart=r.restart,r.restart=function(t,e,n){e=+e,n=null==n?JE():+n,r._restart((function o(a){a+=i,r._restart(o,i+=e,n),t(a)}),e,n)},r.restart(t,e,n),r)}((function(e){t({timestamp:Date.now(),elapsed:e})}),e))},events:function(t,e,n){var r,i=this,o=new Ba(n),a=function(n,r){i.runAsync(null,(()=>{t===RB&&function(t,e){var n=t._eventConfig.defaults,r=n.prevent,i=n.allow;return!1!==r&&!0!==i&&(!0===r||!1===i||(r?r[e]:i?!i[e]:t.preventDefault()))}(i,e)&&n.preventDefault(),o.receive(OB(i,n,r))}))};if("timer"===t)qB(i,"timer",e)&&i.timer(a,e);else 
if(t===RB)qB(i,"view",e)&&i.addEventListener(e,a,UB);else if("window"===t?qB(i,"window",e)&&"undefined"!=typeof window&&(r=[window]):"undefined"!=typeof document&&qB(i,"selector",e)&&(r=Array.from(document.querySelectorAll(t))),r){for(var s=0,u=r.length;s=0;)i[t].stop();for(t=o.length;--t>=0;)for(e=(n=o[t]).sources.length;--e>=0;)n.sources[e].removeEventListener(n.type,n.handler);return r&&r.call(this,this._handler,null,null,null),this},hover:function(t,e){return e=[e||"update",(t=[t||"hover"])[0]],this.on(this.events("view","pointerover",PB),jB,IB(t)),this.on(this.events("view","pointerout",PB),jB,IB(e)),this},data:function(t,e){return arguments.length<2?$B(this,t).values.value:TB.call(this,t,Ma().remove(p).insert(e))},change:TB,insert:function(t,e){return TB.call(this,t,Ma().insert(e))},remove:function(t,e){return TB.call(this,t,Ma().remove(e))},scale:function(t){var e=this._runtime.scales;return lt(e,t)||s("Unrecognized scale or projection: "+t),e[t].value},initialize:function(t,e){const n=this,r=n._renderType,i=n._eventConfig.bind,o=b_(r);t=n._el=t?iz(n,t,!0):null,function(t){const e=t.container();e&&(e.setAttribute("role","graphics-document"),e.setAttribute("aria-roleDescription","visualization"),CB(e,t.description()))}(n),o||n.error("Unrecognized renderer type: "+r);const a=o.handler||vv,s=t?o.renderer:o.headless;return n._renderer=s?nz(n,n._renderer,t,s):null,n._handler=function(t,e,n,r){const i=new r(t.loader(),rz(t,t.tooltip())).scene(t.scenegraph().root).initialize(n,NB(t),t);return e&&e.handlers().forEach((t=>{i.on(t.type,t.handler)})),i}(n,n._handler,t,a),n._redraw=!0,t&&"none"!==i&&(e=e?n._elBind=iz(n,e,!0):t.appendChild(WB("form",{class:"vega-bindings"})),n._bind.forEach((t=>{t.param.element&&"container"!==i&&(t.element=iz(n,t.param.element,!!t.param.input))})),n._bind.forEach((t=>{!function(t,e,n){if(!e)return;const r=n.param;let 
i=n.state;i||(i=n.state={elements:null,active:!1,set:null,update:e=>{e!=t.signal(r.signal)&&t.runAsync(null,(()=>{i.source=!0,t.signal(r.signal,e)}))}},r.debounce&&(i.update=it(r.debounce,i.update))),(null==r.input&&r.element?VB:XB)(i,e,r,t),i.active||(t.on(t._signals[r.signal],null,(()=>{i.source?i.source=!1:i.set(t.signal(r.signal))})),i.active=!0)}(n,t.element||e,t)}))),n},toImageURL:async function(t,e){t!==__.Canvas&&t!==__.SVG&&t!==__.PNG&&s("Unrecognized image type: "+t);const n=await sz(this,t,e);return t===__.SVG?function(t,e){const n=new Blob([t],{type:e});return window.URL.createObjectURL(n)}(n.svg(),"image/svg+xml"):n.canvas().toDataURL("image/png")},toCanvas:async function(t,e){return(await sz(this,__.Canvas,t,e)).canvas()},toSVG:async function(t){return(await sz(this,__.SVG,t)).svg()},getState:function(t){return this._runtime.getState(t||{data:pz,signals:gz,recurse:!0})},setState:function(t){return this.runAsync(null,(e=>{e._trigger=!1,e._runtime.setState(t)}),(t=>{t._trigger=!0})),this},_watchPixelRatio:function(){if("canvas"===this.renderer()&&this._renderer._canvas){let t=null;const e=()=>{null!=t&&t();const n=matchMedia(`(resolution: ${window.devicePixelRatio}dppx)`);n.addEventListener("change",e),t=()=>{n.removeEventListener("change",e)},this._renderer._canvas.getContext("2d").pixelRatio=window.devicePixelRatio||1,this._redraw=!0,this._resize=1,this.resize().runAsync()};e()}}});const Az="view",Mz="[",Ez="]",Dz="{",Cz="}",Fz=":",Sz=",",$z="@",Tz=">",Bz=/[[\]{}]/,zz={"*":1,arc:1,area:1,group:1,image:1,line:1,path:1,rect:1,rule:1,shape:1,symbol:1,text:1,trail:1};let Nz,Oz;function Rz(t,e,n){return Nz=e||Az,Oz=n||zz,Lz(t.trim()).map(qz)}function Uz(t,e,n,r,i){const o=t.length;let a,s=0;for(;e' after between selector: "+t;n=n.map(qz);const i=qz(t.slice(1).trim());if(i.between)return{between:n,stream:i};i.between=n;return i}(t):function(t){const e={source:Nz},n=[];let 
r,i,o=[0,0],a=0,s=0,u=t.length,l=0;if(t[u-1]===Cz){if(l=t.lastIndexOf(Dz),!(l>=0))throw"Unmatched right brace: "+t;try{o=function(t){const e=t.split(Sz);if(!t.length||e.length>2)throw t;return e.map((e=>{const n=+e;if(n!=n)throw t;return n}))}(t.substring(l+1,u-1))}catch(e){throw"Invalid throttle specification: "+t}u=(t=t.slice(0,l).trim()).length,l=0}if(!u)throw t;t[0]===$z&&(a=++l);r=Uz(t,l,Fz),r1?(e.type=n[1],a?e.markname=n[0].slice(1):!function(t){return Oz[t]}(n[0])?e.source=n[0]:e.marktype=n[0]):e.type=n[0];"!"===e.type.slice(-1)&&(e.consume=!0,e.type=e.type.slice(0,-1));null!=i&&(e.filter=i);o[0]&&(e.throttle=o[0]);o[1]&&(e.debounce=o[1]);return e}(t)}function Pz(t){return A(t)?t:{type:t||"pad"}}const jz=t=>+t||0,Iz=t=>({top:t,bottom:t,left:t,right:t});function Wz(t){return A(t)?t.signal?t:{top:jz(t.top),bottom:jz(t.bottom),left:jz(t.left),right:jz(t.right)}:Iz(jz(t))}const Hz=t=>A(t)&&!k(t)?ot({},t):{value:t};function Yz(t,e,n,r){if(null!=n){return A(n)&&!k(n)||k(n)&&n.length&&A(n[0])?t.update[e]=n:t[r||"enter"][e]={value:n},1}return 0}function Gz(t,e,n){for(const n in e)Yz(t,n,e[n]);for(const e in n)Yz(t,e,n[e],"update")}function Vz(t,e,n){for(const r in e)n&<(n,r)||(t[r]=ot(t[r]||{},e[r]));return t}function Xz(t,e){return e&&(e.enter&&e.enter[t]||e.update&&e.update[t])}const Jz="mark",Zz="frame",Qz="scope",Kz="axis",tN="axis-domain",eN="axis-grid",nN="axis-label",rN="axis-tick",iN="axis-title",oN="legend",aN="legend-band",sN="legend-entry",uN="legend-gradient",lN="legend-label",cN="legend-symbol",fN="legend-title",hN="title",dN="title-text",pN="title-subtitle";function gN(t,e,n){t[e]=n&&n.signal?{signal:n.signal}:{value:n}}const mN=t=>xt(t)?Ct(t):t.signal?`(${t.signal})`:xN(t);function yN(t){if(null!=t.gradient)return function(t){const e=[t.start,t.stop,t.count].map((t=>null==t?null:Ct(t)));for(;e.length&&null==F(e);)e.pop();return e.unshift(mN(t.gradient)),`gradient(${e.join(",")})`}(t);let e=t.signal?`(${t.signal})`:t.color?function(t){return 
t.c?vN("hcl",t.h,t.c,t.l):t.h||t.s?vN("hsl",t.h,t.s,t.l):t.l||t.a?vN("lab",t.l,t.a,t.b):t.r||t.g||t.b?vN("rgb",t.r,t.g,t.b):null}(t.color):null!=t.field?xN(t.field):void 0!==t.value?Ct(t.value):void 0;return null!=t.scale&&(e=function(t,e){const n=mN(t.scale);null!=t.range?e=`lerp(_range(${n}), ${+t.range})`:(void 0!==e&&(e=`_scale(${n}, ${e})`),t.band&&(e=(e?e+"+":"")+`_bandwidth(${n})`+(1==+t.band?"":"*"+_N(t.band)),t.extra&&(e=`(datum.extra ? _scale(${n}, datum.extra.value) : ${e})`)),null==e&&(e="0"));return e}(t,e)),void 0===e&&(e=null),null!=t.exponent&&(e=`pow(${e},${_N(t.exponent)})`),null!=t.mult&&(e+=`*${_N(t.mult)}`),null!=t.offset&&(e+=`+${_N(t.offset)}`),t.round&&(e=`round(${e})`),e}const vN=(t,e,n,r)=>`(${t}(${[e,n,r].map(yN).join(",")})+'')`;function _N(t){return A(t)?"("+yN(t)+")":t}function xN(t){return bN(A(t)?t:{datum:t})}function bN(t){let e,n,r;if(t.signal)e="datum",r=t.signal;else if(t.group||t.parent){for(n=Math.max(1,t.level||1),e="item";n-- >0;)e+=".mark.group";t.parent?(r=t.parent,e+=".datum"):r=t.group}else t.datum?(e="datum",r=t.datum):s("Invalid field reference: "+Ct(t));return t.signal||(r=xt(r)?u(r).map(Ct).join("]["):bN(r)),e+"["+r+"]"}function wN(t,e,n,r,i,o){const a={};(o=o||{}).encoders={$encode:a},t=function(t,e,n,r,i){const o={},a={};let s,u,l,c;for(u in u="lineBreak","text"!==e||null==i[u]||Xz(u,t)||gN(o,u,i[u]),("legend"==n||String(n).startsWith("axis"))&&(n=null),c=n===Zz?i.group:n===Jz?ot({},i.mark,i[e]):null,c)l=Xz(u,t)||("fill"===u||"stroke"===u)&&(Xz("fill",t)||Xz("stroke",t)),l||gN(o,u,c[u]);for(u in V(r).forEach((e=>{const n=i.style&&i.style[e];for(const e in n)Xz(e,t)||gN(o,e,n[e])})),t=ot({},t),o)c=o[u],c.signal?(s=s||{})[u]=c:a[u]=c;return t.enter=ot(a,t.enter),s&&(t.update=ot(s,t.update)),t}(t,e,n,r,i.config);for(const n in t)a[n]=kN(t[n],e,o,i);return o}function kN(t,e,n,r){const i={},o={};for(const e in t)null!=t[e]&&(i[e]=AN((a=t[e],k(a)?function(t){let e="";return t.forEach((t=>{const 
n=yN(t);e+=t.test?`(${t.test})?${n}:`:n})),":"===F(e)&&(e+="null"),e}(a):yN(a)),r,n,o));var a;return{$expr:{marktype:e,channels:i},$fields:Object.keys(o),$output:Object.keys(t)}}function AN(t,e,n,r){const i=mB(t,e);return i.$fields.forEach((t=>r[t]=1)),ot(n,i.$params),i.$expr}const MN="outer",EN=["value","update","init","react","bind"];function DN(t,e){s(t+' for "outer" push: '+Ct(e))}function CN(t,e){const n=t.name;if(t.push===MN)e.signals[n]||DN("No prior signal definition",n),EN.forEach((e=>{void 0!==t[e]&&DN("Invalid property ",e)}));else{const r=e.addSignal(n,t.value);!1===t.react&&(r.react=!1),t.bind&&e.addBinding(n,t.bind)}}function FN(t,e,n,r){this.id=-1,this.type=t,this.value=e,this.params=n,r&&(this.parent=r)}function SN(t,e,n,r){return new FN(t,e,n,r)}function $N(t,e){return SN("operator",t,e)}function TN(t){const e={$ref:t.id};return t.id<0&&(t.refs=t.refs||[]).push(e),e}function BN(t,e){return e?{$field:t,$name:e}:{$field:t}}const zN=BN("key");function NN(t,e){return{$compare:t,$order:e}}const ON="descending";function RN(t,e){return(t&&t.signal?"$"+t.signal:t||"")+(t&&e?"_":"")+(e&&e.signal?"$"+e.signal:e||"")}const UN="scope",LN="view";function qN(t){return t&&t.signal}function PN(t){if(qN(t))return!0;if(A(t))for(const e in t)if(PN(t[e]))return!0;return!1}function jN(t,e){return null!=t?t:e}function IN(t){return t&&t.signal||t}const WN="timer";function HN(t,e){return(t.merge?YN:t.stream?GN:t.type?VN:s("Invalid stream specification: "+Ct(t)))(t,e)}function YN(t,e){const n=XN({merge:t.merge.map((t=>HN(t,e)))},t,e);return e.addStream(n).id}function GN(t,e){const n=XN({stream:HN(t.stream,e)},t,e);return e.addStream(n).id}function VN(t,e){let n;t.type===WN?(n=e.event(WN,t.throttle),t={between:t.between,filter:t.filter}):n=e.event(function(t){return t===UN?LN:t||LN}(t.source),t.type);const r=XN({stream:n},t,e);return 1===Object.keys(r).length?n:e.addStream(r).id}function XN(t,e,n){let r=e.between;return r&&(2!==r.length&&s('Stream "between" parameter must 
have 2 entries: '+Ct(e)),t.between=[HN(r[0],n),HN(r[1],n)]),r=e.filter?[].concat(e.filter):[],(e.marktype||e.markname||e.markrole)&&r.push(function(t,e,n){const r="event.item";return r+(t&&"*"!==t?"&&"+r+".mark.marktype==='"+t+"'":"")+(n?"&&"+r+".mark.role==='"+n+"'":"")+(e?"&&"+r+".mark.name==='"+e+"'":"")}(e.marktype,e.markname,e.markrole)),e.source===UN&&r.push("inScope(event.item)"),r.length&&(t.filter=mB("("+r.join(")&&(")+")",n).$expr),null!=(r=e.throttle)&&(t.throttle=+r),null!=(r=e.debounce)&&(t.debounce=+r),e.consume&&(t.consume=!0),t}const JN={code:"_.$value",ast:{type:"Identifier",value:"value"}};function ZN(t,e,n){const r=t.encode,i={target:n};let o=t.events,a=t.update,u=[];o||s("Signal update missing events specification."),xt(o)&&(o=Rz(o,e.isSubscope()?UN:LN)),o=V(o).filter((t=>t.signal||t.scale?(u.push(t),0):1)),u.length>1&&(u=[QN(u)]),o.length&&u.push(o.length>1?{merge:o}:o[0]),null!=r&&(a&&s("Signal encode and update are mutually exclusive."),a="encode(item(),"+Ct(r)+")"),i.update=xt(a)?mB(a,e):null!=a.expr?mB(a.expr,e):null!=a.value?a.value:null!=a.signal?{$expr:JN,$params:{$value:e.signalRef(a.signal)}}:s("Invalid signal update specification."),t.force&&(i.options={force:!0}),u.forEach((t=>e.addUpdate(ot(function(t,e){return{source:t.signal?e.signalRef(t.signal):t.scale?e.scaleRef(t.scale):HN(t,e)}}(t,e),i))))}function QN(t){return{signal:"["+t.map((t=>t.scale?'scale("'+t.scale+'")':t.signal))+"]"}}const KN=t=>(e,n,r)=>SN(t,n,e||void 0,r),tO=KN("aggregate"),eO=KN("axisticks"),nO=KN("bound"),rO=KN("collect"),iO=KN("compare"),oO=KN("datajoin"),aO=KN("encode"),sO=KN("expression"),uO=KN("facet"),lO=KN("field"),cO=KN("key"),fO=KN("legendentries"),hO=KN("load"),dO=KN("mark"),pO=KN("multiextent"),gO=KN("multivalues"),mO=KN("overlap"),yO=KN("params"),vO=KN("prefacet"),_O=KN("projection"),xO=KN("proxy"),bO=KN("relay"),wO=KN("render"),kO=KN("scale"),AO=KN("sieve"),MO=KN("sortitems"),EO=KN("viewlayout"),DO=KN("values");let CO=0;const 
FO={min:"min",max:"max",count:"sum"};function SO(t,e){const n=e.getScale(t.name).params;let r;for(r in n.domain=zO(t.domain,t,e),null!=t.range&&(n.range=jO(t,e,n)),null!=t.interpolate&&function(t,e){e.interpolate=$O(t.type||t),null!=t.gamma&&(e.interpolateGamma=$O(t.gamma))}(t.interpolate,n),null!=t.nice&&(n.nice=function(t){return A(t)?{interval:$O(t.interval),step:$O(t.step)}:$O(t)}(t.nice)),null!=t.bins&&(n.bins=function(t,e){return t.signal||k(t)?TO(t,e):e.objectProperty(t)}(t.bins,e)),t)lt(n,r)||"name"===r||(n[r]=$O(t[r],e))}function $O(t,e){return A(t)?t.signal?e.signalRef(t.signal):s("Unsupported object: "+Ct(t)):t}function TO(t,e){return t.signal?e.signalRef(t.signal):t.map((t=>$O(t,e)))}function BO(t){s("Can not find data set: "+Ct(t))}function zO(t,e,n){if(t)return t.signal?n.signalRef(t.signal):(k(t)?NO:t.fields?RO:OO)(t,e,n);null==e.domainMin&&null==e.domainMax||s("No scale domain defined for domainMin/domainMax to override.")}function NO(t,e,n){return t.map((t=>$O(t,n)))}function OO(t,e,n){const r=n.getData(t.data);return r||BO(t.data),Kd(e.type)?r.valuesRef(n,t.field,LO(t.sort,!1)):rp(e.type)?r.domainRef(n,t.field):r.extentRef(n,t.field)}function RO(t,e,n){const r=t.data,i=t.fields.reduce(((t,e)=>(e=xt(e)?{data:r,field:e}:k(e)||e.signal?function(t,e){const n="_:vega:_"+CO++,r=rO({});if(k(t))r.value={$ingest:t};else if(t.signal){const i="setdata("+Ct(n)+","+t.signal+")";r.params.input=e.signalRef(i)}return e.addDataPipeline(n,[r,AO({})]),{data:n,field:"data"}}(e,n):e,t.push(e),t)),[]);return(Kd(e.type)?UO:rp(e.type)?qO:PO)(t,n,i)}function UO(t,e,n){const r=LO(t.sort,!0);let i,o;const a=n.map((t=>{const n=e.getData(t.data);return n||BO(t.data),n.countsRef(e,t.field,r)})),s={groupby:zN,pulse:a};r&&(i=r.op||"count",o=r.field?RN(i,r.field):"count",s.ops=[FO[i]],s.fields=[e.fieldRef(o)],s.as=[o]),i=e.add(tO(s));const u=e.add(rO({pulse:TN(i)}));return o=e.add(DO({field:zN,sort:e.sortRef(r),pulse:TN(u)})),TN(o)}function LO(t,e){return 
t&&(t.field||t.op?t.field||"count"===t.op?e&&t.field&&t.op&&!FO[t.op]&&s("Multiple domain scales can not be sorted using "+t.op):s("No field provided for sort aggregate op: "+t.op):A(t)?t.field="key":t={field:"key"}),t}function qO(t,e,n){const r=n.map((t=>{const n=e.getData(t.data);return n||BO(t.data),n.domainRef(e,t.field)}));return TN(e.add(gO({values:r})))}function PO(t,e,n){const r=n.map((t=>{const n=e.getData(t.data);return n||BO(t.data),n.extentRef(e,t.field)}));return TN(e.add(pO({extents:r})))}function jO(t,e,n){const r=e.config.range;let i=t.range;if(i.signal)return e.signalRef(i.signal);if(xt(i)){if(r&<(r,i))return jO(t=ot({},t,{range:r[i]}),e,n);"width"===i?i=[0,{signal:"width"}]:"height"===i?i=Kd(t.type)?[0,{signal:"height"}]:[{signal:"height"},0]:s("Unrecognized scale range value: "+Ct(i))}else{if(i.scheme)return n.scheme=k(i.scheme)?TO(i.scheme,e):$O(i.scheme,e),i.extent&&(n.schemeExtent=TO(i.extent,e)),void(i.count&&(n.schemeCount=$O(i.count,e)));if(i.step)return void(n.rangeStep=$O(i.step,e));if(Kd(t.type)&&!k(i))return zO(i,t,e);k(i)||s("Unsupported range type: "+Ct(i))}return i.map((t=>(k(t)?TO:$O)(t,e)))}function IO(t,e,n){return k(t)?t.map((t=>IO(t,e,n))):A(t)?t.signal?n.signalRef(t.signal):"fit"===e?t:s("Unsupported parameter object: "+Ct(t)):t}const WO="top",HO="left",YO="right",GO="bottom",VO="center",XO="vertical",JO="start",ZO="end",QO="index",KO="label",tR="offset",eR="perc",nR="perc2",rR="value",iR="guide-label",oR="guide-title",aR="group-title",sR="group-subtitle",uR="symbol",lR="gradient",cR="discrete",fR="size",hR=[fR,"shape","fill","stroke","strokeWidth","strokeDash","opacity"],dR={name:1,style:1,interactive:1},pR={value:0},gR={value:1},mR="group",yR="rect",vR="rule",_R="symbol",xR="text";function bR(t){return t.type=mR,t.interactive=t.interactive||!1,t}function wR(t,e){const n=(n,r)=>jN(t[n],jN(e[n],r));return 
n.isVertical=n=>XO===jN(t.direction,e.direction||(n?e.symbolDirection:e.gradientDirection)),n.gradientLength=()=>jN(t.gradientLength,e.gradientLength||e.gradientWidth),n.gradientThickness=()=>jN(t.gradientThickness,e.gradientThickness||e.gradientHeight),n.entryColumns=()=>jN(t.columns,jN(e.columns,+n.isVertical(!0))),n}function kR(t,e){const n=e&&(e.update&&e.update[t]||e.enter&&e.enter[t]);return n&&n.signal?n:n?n.value:null}function AR(t,e,n){return`item.anchor === '${JO}' ? ${t} : item.anchor === '${ZO}' ? ${e} : ${n}`}const MR=AR(Ct(HO),Ct(YO),Ct(VO));function ER(t,e){return e?t?A(t)?Object.assign({},t,{offset:ER(t.offset,e)}):{value:t,offset:e}:e:t}function DR(t,e){return e?(t.name=e.name,t.style=e.style||t.style,t.interactive=!!e.interactive,t.encode=Vz(t.encode,e,dR)):t.interactive=!1,t}function CR(t,e,n,r){const i=wR(t,n),o=i.isVertical(),a=i.gradientThickness(),s=i.gradientLength();let u,l,c,f,h;o?(l=[0,1],c=[0,0],f=a,h=s):(l=[0,0],c=[1,0],f=s,h=a);const d={enter:u={opacity:pR,x:pR,y:pR,width:Hz(f),height:Hz(h)},update:ot({},u,{opacity:gR,fill:{gradient:e,start:l,stop:c}}),exit:{opacity:pR}};return Gz(d,{stroke:i("gradientStrokeColor"),strokeWidth:i("gradientStrokeWidth")},{opacity:i("gradientOpacity")}),DR({type:yR,role:uN,encode:d},r)}function FR(t,e,n,r,i){const o=wR(t,n),a=o.isVertical(),s=o.gradientThickness(),u=o.gradientLength();let l,c,f,h,d="";a?(l="y",f="y2",c="x",h="width",d="1-"):(l="x",f="x2",c="y",h="height");const p={opacity:pR,fill:{scale:e,field:rR}};p[l]={signal:d+"datum."+eR,mult:u},p[c]=pR,p[f]={signal:d+"datum."+nR,mult:u},p[h]=Hz(s);const g={enter:p,update:ot({},p,{opacity:gR}),exit:{opacity:pR}};return Gz(g,{stroke:o("gradientStrokeColor"),strokeWidth:o("gradientStrokeWidth")},{opacity:o("gradientOpacity")}),DR({type:yR,role:aN,key:rR,from:i,encode:g},r)}const SR=`datum.${eR}<=0?"${HO}":datum.${eR}>=1?"${YO}":"${VO}"`,$R=`datum.${eR}<=0?"${GO}":datum.${eR}>=1?"${WO}":"middle"`;function TR(t,e,n,r){const 
i=wR(t,e),o=i.isVertical(),a=Hz(i.gradientThickness()),s=i.gradientLength();let u,l,c,f,h=i("labelOverlap"),d="";const p={enter:u={opacity:pR},update:l={opacity:gR,text:{field:KO}},exit:{opacity:pR}};return Gz(p,{fill:i("labelColor"),fillOpacity:i("labelOpacity"),font:i("labelFont"),fontSize:i("labelFontSize"),fontStyle:i("labelFontStyle"),fontWeight:i("labelFontWeight"),limit:jN(t.labelLimit,e.gradientLabelLimit)}),o?(u.align={value:"left"},u.baseline=l.baseline={signal:$R},c="y",f="x",d="1-"):(u.align=l.align={signal:SR},u.baseline={value:"top"},c="x",f="y"),u[c]=l[c]={signal:d+"datum."+eR,mult:s},u[f]=l[f]=a,a.offset=jN(t.labelOffset,e.gradientLabelOffset)||0,h=h?{separation:i("labelSeparation"),method:h,order:"datum."+QO}:void 0,DR({type:xR,role:lN,style:iR,key:rR,from:r,encode:p,overlap:h},n)}function BR(t,e,n,r,i){const o=wR(t,e),a=n.entries,s=!(!a||!a.interactive),u=a?a.name:void 0,l=o("clipHeight"),c=o("symbolOffset"),f={data:"value"},h=`(${i}) ? datum.${tR} : datum.${fR}`,d=l?Hz(l):{field:fR},p=`datum.${QO}`,g=`max(1, ${i})`;let m,y,v,_,x;d.mult=.5,m={enter:y={opacity:pR,x:{signal:h,mult:.5,offset:c},y:d},update:v={opacity:gR,x:y.x,y:y.y},exit:{opacity:pR}};let b=null,w=null;t.fill||(b=e.symbolBaseFillColor,w=e.symbolBaseStrokeColor),Gz(m,{fill:o("symbolFillColor",b),shape:o("symbolType"),size:o("symbolSize"),stroke:o("symbolStrokeColor",w),strokeDash:o("symbolDash"),strokeDashOffset:o("symbolDashOffset"),strokeWidth:o("symbolStrokeWidth")},{opacity:o("symbolOpacity")}),hR.forEach((e=>{t[e]&&(v[e]=y[e]={scale:t[e],field:rR})}));const k=DR({type:_R,role:cN,key:rR,from:f,clip:!!l||void 
0,encode:m},n.symbols),A=Hz(c);A.offset=o("labelOffset"),m={enter:y={opacity:pR,x:{signal:h,offset:A},y:d},update:v={opacity:gR,text:{field:KO},x:y.x,y:y.y},exit:{opacity:pR}},Gz(m,{align:o("labelAlign"),baseline:o("labelBaseline"),fill:o("labelColor"),fillOpacity:o("labelOpacity"),font:o("labelFont"),fontSize:o("labelFontSize"),fontStyle:o("labelFontStyle"),fontWeight:o("labelFontWeight"),limit:o("labelLimit")});const M=DR({type:xR,role:lN,style:iR,key:rR,from:f,encode:m},n.labels);return m={enter:{noBound:{value:!l},width:pR,height:l?Hz(l):pR,opacity:pR},exit:{opacity:pR},update:v={opacity:gR,row:{signal:null},column:{signal:null}}},o.isVertical(!0)?(_=`ceil(item.mark.items.length / ${g})`,v.row.signal=`${p}%${_}`,v.column.signal=`floor(${p} / ${_})`,x={field:["row",p]}):(v.row.signal=`floor(${p} / ${g})`,v.column.signal=`${p} % ${g}`,x={field:p}),v.column.signal=`(${i})?${v.column.signal}:${p}`,bR({role:Qz,from:r={facet:{data:r,name:"value",groupby:QO}},encode:Vz(m,a,dR),marks:[k,M],name:u,interactive:s,sort:x})}const zR='item.orient === "left"',NR='item.orient === "right"',OR=`(${zR} || ${NR})`,RR=`datum.vgrad && ${OR}`,UR=AR('"top"','"bottom"','"middle"'),LR=`datum.vgrad && ${NR} ? (${AR('"right"','"left"','"center"')}) : (${OR} && !(datum.vgrad && ${zR})) ? "left" : ${MR}`,qR=`item._anchor || (${OR} ? "middle" : "start")`,PR=`${RR} ? (${zR} ? -90 : 90) : 0`,jR=`${OR} ? (datum.vgrad ? (${NR} ? 
"bottom" : "top") : ${UR}) : "top"`;function IR(t,e){let n;return A(t)&&(t.signal?n=t.signal:t.path?n="pathShape("+WR(t.path)+")":t.sphere&&(n="geoShape("+WR(t.sphere)+', {type: "Sphere"})')),n?e.signalRef(n):!!t}function WR(t){return A(t)&&t.signal?t.signal:Ct(t)}function HR(t){const e=t.role||"";return e.startsWith("axis")||e.startsWith("legend")||e.startsWith("title")?e:t.type===mR?Qz:e||Jz}function YR(t){return{marktype:t.type,name:t.name||void 0,role:t.role||HR(t),zindex:+t.zindex||void 0,aria:t.aria,description:t.description}}function GR(t,e){return t&&t.signal?e.signalRef(t.signal):!1!==t}function VR(t,e){const n=Qa(t.type);n||s("Unrecognized transform type: "+Ct(t.type));const r=SN(n.type.toLowerCase(),null,XR(n,t,e));return t.signal&&e.addSignal(t.signal,e.proxy(r)),r.metadata=n.metadata||{},r}function XR(t,e,n){const r={},i=t.params.length;for(let o=0;oQR(t,e,n)))):QR(t,r,n)}(t,e,n):"projection"===r?n.projectionRef(e[t.name]):t.array&&!qN(i)?i.map((e=>ZR(t,e,n))):ZR(t,i,n):void(t.required&&s("Missing required "+Ct(e.type)+" parameter: "+Ct(t.name)))}function ZR(t,e,n){const r=t.type;if(qN(e))return nU(r)?s("Expression references can not be signals."):rU(r)?n.fieldRef(e):iU(r)?n.compareRef(e):n.signalRef(e.signal);{const i=t.expr||rU(r);return i&&KR(e)?n.exprRef(e.expr,e.as):i&&tU(e)?BN(e.field,e.as):nU(r)?mB(e,n):eU(r)?TN(n.getData(e).values):rU(r)?BN(e):iU(r)?n.compareRef(e):e}}function QR(t,e,n){const r=t.params.length;let i;for(let n=0;nt&&t.expr,tU=t=>t&&t.field,eU=t=>"data"===t,nU=t=>"expr"===t,rU=t=>"field"===t,iU=t=>"compare"===t;function oU(t,e){return t.$ref?t:t.data&&t.data.$ref?t.data:TN(e.getData(t.data).output)}function aU(t,e,n,r,i){this.scope=t,this.input=e,this.output=n,this.values=r,this.aggregate=i,this.index={}}function sU(t){return xt(t)?t:null}function uU(t,e,n){const r=RN(n.op,n.field);let i;if(e.ops){for(let t=0,n=e.as.length;tnull==t?"null":t)).join(",")+"),0)",e);u.update=l.$expr,u.params=l.$params}function fU(t,e){const 
n=HR(t),r=t.type===mR,i=t.from&&t.from.facet,o=t.overlap;let a,u,l,c,f,h,d,p=t.layout||n===Qz||n===Zz;const g=n===Jz||p||i,m=function(t,e,n){let r,i,o,a,u;return t?(r=t.facet)&&(e||s("Only group marks can be faceted."),null!=r.field?a=u=oU(r,n):(t.data?u=TN(n.getData(t.data).aggregate):(o=VR(ot({type:"aggregate",groupby:V(r.groupby)},r.aggregate),n),o.params.key=n.keyRef(r.groupby),o.params.pulse=oU(r,n),a=u=TN(n.add(o))),i=n.keyRef(r.groupby,!0))):a=TN(n.add(rO(null,[{}]))),a||(a=oU(t,n)),{key:i,pulse:a,parent:u}}(t.from,r,e);u=e.add(oO({key:m.key||(t.key?BN(t.key):void 0),pulse:m.pulse,clean:!r}));const y=TN(u);u=l=e.add(rO({pulse:y})),u=e.add(dO({markdef:YR(t),interactive:GR(t.interactive,e),clip:IR(t.clip,e),context:{$context:!0},groups:e.lookup(),parent:e.signals.parent?e.signalRef("parent"):null,index:e.markpath(),pulse:TN(u)}));const v=TN(u);u=c=e.add(aO(wN(t.encode,t.type,n,t.style,e,{mod:!1,pulse:v}))),u.params.parent=e.encode(),t.transform&&t.transform.forEach((t=>{const n=VR(t,e),r=n.metadata;(r.generates||r.changes)&&s("Mark transforms should not generate new data."),r.nomod||(c.params.mod=!0),n.params.pulse=TN(u),e.add(u=n)})),t.sort&&(u=e.add(MO({sort:e.compareRef(t.sort),pulse:TN(u)})));const _=TN(u);(i||p)&&(p=e.add(EO({layout:e.objectProperty(t.layout),legends:e.legends,mark:v,pulse:_})),h=TN(p));const x=e.add(nO({mark:v,pulse:h||_}));d=TN(x),r&&(g&&(a=e.operators,a.pop(),p&&a.pop()),e.pushState(_,h||d,y),i?function(t,e,n){const r=t.from.facet,i=r.name,o=oU(r,e);let a;r.name||s("Facet must have a name: "+Ct(r)),r.data||s("Facet must reference a data set: "+Ct(r)),r.field?a=e.add(vO({field:e.fieldRef(r.field),pulse:o})):r.groupby?a=e.add(uO({key:e.keyRef(r.groupby),group:TN(e.proxy(n.parent)),pulse:o})):s("Facet must specify groupby or field: "+Ct(r));const u=e.fork(),l=u.add(rO()),c=u.add(AO({pulse:TN(l)}));u.addData(i,new aU(u,l,l,c)),u.addSignal("parent",null),a.params.subflow={$subflow:u.parse(t).toRuntime()}}(t,e,m):g?function(t,e,n){const 
r=e.add(vO({pulse:n.pulse})),i=e.fork();i.add(AO()),i.addSignal("parent",null),r.params.subflow={$subflow:i.parse(t).toRuntime()}}(t,e,m):e.parse(t),e.popState(),g&&(p&&a.push(p),a.push(x))),o&&(d=function(t,e,n){const r=t.method,i=t.bound,o=t.separation,a={separation:qN(o)?n.signalRef(o.signal):o,method:qN(r)?n.signalRef(r.signal):r,pulse:e};t.order&&(a.sort=n.compareRef({field:t.order}));if(i){const t=i.tolerance;a.boundTolerance=qN(t)?n.signalRef(t.signal):+t,a.boundScale=n.scaleRef(i.scale),a.boundOrient=i.orient}return TN(n.add(mO(a)))}(o,d,e));const b=e.add(wO({pulse:d})),w=e.add(AO({pulse:TN(b)},void 0,e.parent()));null!=t.name&&(f=t.name,e.addData(f,new aU(e,l,b,w)),t.on&&t.on.forEach((t=>{(t.insert||t.remove||t.toggle)&&s("Marks only support modify triggers."),cU(t,e,f)})))}function hU(t,e){const n=e.config.legend,r=t.encode||{},i=wR(t,n),o=r.legend||{},a=o.name||void 0,u=o.interactive,l=o.style,c={};let f,h,d,p=0;hR.forEach((e=>t[e]?(c[e]=t[e],p=p||t[e]):0)),p||s("Missing valid scale for legend.");const g=function(t,e){let n=t.type||uR;t.type||1!==function(t){return hR.reduce(((e,n)=>e+(t[n]?1:0)),0)}(t)||!t.fill&&!t.stroke||(n=Qd(e)?lR:tp(e)?cR:uR);return n!==lR?n:tp(e)?cR:lR}(t,e.scaleType(p)),m={title:null!=t.title,scales:c,type:g,vgrad:"symbol"!==g&&i.isVertical()},y=TN(e.add(rO(null,[m]))),v=TN(e.add(fO(h={type:g,scale:e.scaleRef(p),count:e.objectProperty(i("tickCount")),limit:e.property(i("symbolLimit")),values:e.objectProperty(t.values),minstep:e.property(t.tickMinStep),formatType:e.property(t.formatType),formatSpecifier:e.property(t.format)})));return g===lR?(d=[CR(t,p,n,r.gradient),TR(t,n,r.labels,v)],h.count=h.count||e.signalRef(`max(2,2*floor((${IN(i.gradientLength())})/100))`)):g===cR?d=[FR(t,p,n,r.gradient,v),TR(t,n,r.labels,v)]:(f=function(t,e){const 
n=wR(t,e);return{align:n("gridAlign"),columns:n.entryColumns(),center:{row:!0,column:!1},padding:{row:n("rowPadding"),column:n("columnPadding")}}}(t,n),d=[BR(t,n,r,v,IN(f.columns))],h.size=function(t,e,n){const r=IN(pU("size",t,n)),i=IN(pU("strokeWidth",t,n)),o=IN(function(t,e,n){return kR("fontSize",t)||function(t,e,n){const r=e.config.style[n];return r&&r[t]}("fontSize",e,n)}(n[1].encode,e,iR));return mB(`max(ceil(sqrt(${r})+${i}),${o})`,e)}(t,e,d[0].marks)),d=[bR({role:sN,from:y,encode:{enter:{x:{value:0},y:{value:0}}},marks:d,layout:f,interactive:u})],m.title&&d.push(function(t,e,n,r){const i=wR(t,e),o={enter:{opacity:pR},update:{opacity:gR,x:{field:{group:"padding"}},y:{field:{group:"padding"}}},exit:{opacity:pR}};return Gz(o,{orient:i("titleOrient"),_anchor:i("titleAnchor"),anchor:{signal:qR},angle:{signal:PR},align:{signal:LR},baseline:{signal:jR},text:t.title,fill:i("titleColor"),fillOpacity:i("titleOpacity"),font:i("titleFont"),fontSize:i("titleFontSize"),fontStyle:i("titleFontStyle"),fontWeight:i("titleFontWeight"),limit:i("titleLimit"),lineHeight:i("titleLineHeight")},{align:i("titleAlign"),baseline:i("titleBaseline")}),DR({type:xR,role:fN,style:oR,from:r,encode:o},n)}(t,n,r.title,y)),fU(bR({role:oN,from:y,encode:Vz(dU(i,t,n),o,dR),marks:d,aria:i("aria"),description:i("description"),zindex:i("zindex"),name:a,interactive:u,style:l}),e)}function dU(t,e,n){const r={enter:{},update:{}};return Gz(r,{orient:t("orient"),offset:t("offset"),padding:t("padding"),titlePadding:t("titlePadding"),cornerRadius:t("cornerRadius"),fill:t("fillColor"),stroke:t("strokeColor"),strokeWidth:n.strokeWidth,strokeDash:n.strokeDash,x:t("legendX"),y:t("legendY"),format:e.format,formatType:e.formatType}),r}function pU(t,e,n){return e[t]?`scale("${e[t]}",datum)`:kR(t,n[0].encode)}aU.fromEntries=function(t,e){const n=e.length,r=e[n-1],i=e[n-2];let 
o=e[0],a=null,s=1;for(o&&"load"===o.type&&(o=e[1]),t.add(e[0]);s{n.push(VR(t,e))})),t.on&&t.on.forEach((n=>{cU(n,e,t.name)})),e.addDataPipeline(t.name,function(t,e,n){const r=[];let i,o,a,s,u,l=null,c=!1,f=!1;t.values?qN(t.values)||PN(t.format)?(r.push(xU(e,t)),r.push(l=_U())):r.push(l=_U({$ingest:t.values,$format:t.format})):t.url?PN(t.url)||PN(t.format)?(r.push(xU(e,t)),r.push(l=_U())):r.push(l=_U({$request:t.url,$format:t.format})):t.source&&(l=i=V(t.source).map((t=>TN(e.getData(t).output))),r.push(null));for(o=0,a=n.length;ot===GO||t===WO,wU=(t,e,n)=>qN(t)?FU(t.signal,e,n):t===HO||t===WO?e:n,kU=(t,e,n)=>qN(t)?DU(t.signal,e,n):bU(t)?e:n,AU=(t,e,n)=>qN(t)?CU(t.signal,e,n):bU(t)?n:e,MU=(t,e,n)=>qN(t)?SU(t.signal,e,n):t===WO?{value:e}:{value:n},EU=(t,e,n)=>qN(t)?$U(t.signal,e,n):t===YO?{value:e}:{value:n},DU=(t,e,n)=>TU(`${t} === '${WO}' || ${t} === '${GO}'`,e,n),CU=(t,e,n)=>TU(`${t} !== '${WO}' && ${t} !== '${GO}'`,e,n),FU=(t,e,n)=>zU(`${t} === '${HO}' || ${t} === '${WO}'`,e,n),SU=(t,e,n)=>zU(`${t} === '${WO}'`,e,n),$U=(t,e,n)=>zU(`${t} === '${YO}'`,e,n),TU=(t,e,n)=>(e=null!=e?Hz(e):e,n=null!=n?Hz(n):n,BU(e)&&BU(n)?{signal:`${t} ? (${e=e?e.signal||Ct(e.value):null}) : (${n=n?n.signal||Ct(n.value):null})`}:[ot({test:t},e)].concat(n||[])),BU=t=>null==t||1===Object.keys(t).length,zU=(t,e,n)=>({signal:`${t} ? (${OU(e)}) : (${OU(n)})`}),NU=(t,e,n,r,i)=>({signal:(null!=r?`${t} === '${HO}' ? (${OU(r)}) : `:"")+(null!=n?`${t} === '${GO}' ? (${OU(n)}) : `:"")+(null!=i?`${t} === '${YO}' ? (${OU(i)}) : `:"")+(null!=e?`${t} === '${WO}' ? 
(${OU(e)}) : `:"")+"(null)"}),OU=t=>qN(t)?t.signal:null==t?null:Ct(t),RU=(t,e)=>0===e?0:qN(t)?{signal:`(${t.signal}) * ${e}`}:{value:t*e},UU=(t,e)=>{const n=t.signal;return n&&n.endsWith("(null)")?{signal:n.slice(0,-6)+e.signal}:t};function LU(t,e,n,r){let i;if(e&<(e,t))return e[t];if(lt(n,t))return n[t];if(t.startsWith("title")){switch(t){case"titleColor":i="fill";break;case"titleFont":case"titleFontSize":case"titleFontWeight":i=t[5].toLowerCase()+t.slice(6)}return r[oR][i]}if(t.startsWith("label")){switch(t){case"labelColor":i="fill";break;case"labelFont":case"labelFontSize":i=t[5].toLowerCase()+t.slice(6)}return r[iR][i]}return null}function qU(t){const e={};for(const n of t)if(n)for(const t in n)e[t]=1;return Object.keys(e)}function PU(t,e){return{scale:t.scale,range:e}}function jU(t,e,n,r,i){const o=wR(t,e),a=t.orient,s=t.gridScale,u=wU(a,1,-1),l=function(t,e){if(1===e);else if(A(t)){let n=t=ot({},t);for(;null!=n.mult;){if(!A(n.mult))return n.mult=qN(e)?{signal:`(${n.mult}) * (${e.signal})`}:n.mult*e,t;n=n.mult=ot({},n.mult)}n.mult=e}else t=qN(e)?{signal:`(${e.signal}) * (${t||0})`}:e*(t||0);return t}(t.offset,u);let c,f,h;const d={enter:c={opacity:pR},update:h={opacity:gR},exit:f={opacity:pR}};Gz(d,{stroke:o("gridColor"),strokeCap:o("gridCap"),strokeDash:o("gridDash"),strokeDashOffset:o("gridDashOffset"),strokeOpacity:o("gridOpacity"),strokeWidth:o("gridWidth")});const p={scale:t.scale,field:rR,band:i.band,extra:i.extra,offset:i.offset,round:o("tickRound")},g=kU(a,{signal:"height"},{signal:"width"}),m=s?{scale:s,range:0,mult:u,offset:l}:{value:0,offset:l},y=s?{scale:s,range:1,mult:u,offset:l}:ot(g,{mult:u,offset:l});return c.x=h.x=kU(a,p,m),c.y=h.y=AU(a,p,m),c.x2=h.x2=AU(a,y),c.y2=h.y2=kU(a,y),f.x=kU(a,p),f.y=AU(a,p),DR({type:vR,role:eN,key:rR,from:r,encode:d},n)}function IU(t,e,n,r,i){return{signal:'flush(range("'+t+'"), scale("'+t+'", datum.value), '+e+","+n+","+r+","+i+")"}}function WU(t,e,n,r){const i=wR(t,e),o=t.orient,a=wU(o,-1,1);let s,u;const 
l={enter:s={opacity:pR,anchor:Hz(i("titleAnchor",null)),align:{signal:MR}},update:u=ot({},s,{opacity:gR,text:Hz(t.title)}),exit:{opacity:pR}},c={signal:`lerp(range("${t.scale}"), ${AR(0,1,.5)})`};return u.x=kU(o,c),u.y=AU(o,c),s.angle=kU(o,pR,RU(a,90)),s.baseline=kU(o,MU(o,GO,WO),{value:GO}),u.angle=s.angle,u.baseline=s.baseline,Gz(l,{fill:i("titleColor"),fillOpacity:i("titleOpacity"),font:i("titleFont"),fontSize:i("titleFontSize"),fontStyle:i("titleFontStyle"),fontWeight:i("titleFontWeight"),limit:i("titleLimit"),lineHeight:i("titleLineHeight")},{align:i("titleAlign"),angle:i("titleAngle"),baseline:i("titleBaseline")}),function(t,e,n,r){const i=(t,e)=>null!=t?(n.update[e]=UU(Hz(t),n.update[e]),!1):!Xz(e,r),o=i(t("titleX"),"x"),a=i(t("titleY"),"y");n.enter.auto=a===o?Hz(a):kU(e,Hz(a),Hz(o))}(i,o,l,n),l.update.align=UU(l.update.align,s.align),l.update.angle=UU(l.update.angle,s.angle),l.update.baseline=UU(l.update.baseline,s.baseline),DR({type:xR,role:iN,style:oR,from:r,encode:l},n)}function HU(t,e){const n=function(t,e){var n,r,i,o=e.config,a=o.style,s=o.axis,u="band"===e.scaleType(t.scale)&&o.axisBand,l=t.orient;if(qN(l)){const t=qU([o.axisX,o.axisY]),e=qU([o.axisTop,o.axisBottom,o.axisLeft,o.axisRight]);for(i of(n={},t))n[i]=kU(l,LU(i,o.axisX,s,a),LU(i,o.axisY,s,a));for(i of(r={},e))r[i]=NU(l.signal,LU(i,o.axisTop,s,a),LU(i,o.axisBottom,s,a),LU(i,o.axisLeft,s,a),LU(i,o.axisRight,s,a))}else n=l===WO||l===GO?o.axisX:o.axisY,r=o["axis"+l[0].toUpperCase()+l.slice(1)];return n||r||u?ot({},s,n,r,u):s}(t,e),r=t.encode||{},i=r.axis||{},o=i.name||void 0,a=i.interactive,s=i.style,u=wR(t,n),l=function(t){const e=t("tickBand");let n,r,i=t("tickOffset");return e?e.signal?(n={signal:`(${e.signal}) === 'extent' ? 1 : 0.5`},r={signal:`(${e.signal}) === 'extent'`},A(i)||(i={signal:`(${e.signal}) === 'extent' ? 
0 : ${i}`})):"extent"===e?(n=1,r=!0,i=0):(n=.5,r=!1):(n=t("bandPosition"),r=t("tickExtra")),{extra:r,band:n,offset:i}}(u),c={scale:t.scale,ticks:!!u("ticks"),labels:!!u("labels"),grid:!!u("grid"),domain:!!u("domain"),title:null!=t.title},f=TN(e.add(rO({},[c]))),h=TN(e.add(eO({scale:e.scaleRef(t.scale),extra:e.property(l.extra),count:e.objectProperty(t.tickCount),values:e.objectProperty(t.values),minstep:e.property(t.tickMinStep),formatType:e.property(t.formatType),formatSpecifier:e.property(t.format)}))),d=[];let p;return c.grid&&d.push(jU(t,n,r.grid,h,l)),c.ticks&&(p=u("tickSize"),d.push(function(t,e,n,r,i,o){const a=wR(t,e),s=t.orient,u=wU(s,-1,1);let l,c,f;const h={enter:l={opacity:pR},update:f={opacity:gR},exit:c={opacity:pR}};Gz(h,{stroke:a("tickColor"),strokeCap:a("tickCap"),strokeDash:a("tickDash"),strokeDashOffset:a("tickDashOffset"),strokeOpacity:a("tickOpacity"),strokeWidth:a("tickWidth")});const d=Hz(i);d.mult=u;const p={scale:t.scale,field:rR,band:o.band,extra:o.extra,offset:o.offset,round:a("tickRound")};return f.y=l.y=kU(s,pR,p),f.y2=l.y2=kU(s,d),c.x=kU(s,p),f.x=l.x=AU(s,pR,p),f.x2=l.x2=AU(s,d),c.y=AU(s,p),DR({type:vR,role:rN,key:rR,from:r,encode:h},n)}(t,n,r.ticks,h,p,l))),c.labels&&(p=c.ticks?p:0,d.push(function(t,e,n,r,i,o){const a=wR(t,e),s=t.orient,u=t.scale,l=wU(s,-1,1),c=IN(a("labelFlush")),f=IN(a("labelFlushOffset")),h=a("labelAlign"),d=a("labelBaseline");let p,g=0===c||!!c;const m=Hz(i);m.mult=l,m.offset=Hz(a("labelPadding")||0),m.offset.mult=l;const y={scale:u,field:rR,band:.5,offset:ER(o.offset,a("labelOffset"))},v=kU(s,g?IU(u,c,'"left"','"right"','"center"'):{value:"center"},EU(s,"left","right")),_=kU(s,MU(s,"bottom","top"),g?IU(u,c,'"top"','"bottom"','"middle"'):{value:"middle"}),x=IU(u,c,`-(${f})`,f,0);g=g&&f;const 
b={opacity:pR,x:kU(s,y,m),y:AU(s,y,m)},w={enter:b,update:p={opacity:gR,text:{field:KO},x:b.x,y:b.y,align:v,baseline:_},exit:{opacity:pR,x:b.x,y:b.y}};Gz(w,{dx:!h&&g?kU(s,x):null,dy:!d&&g?AU(s,x):null}),Gz(w,{angle:a("labelAngle"),fill:a("labelColor"),fillOpacity:a("labelOpacity"),font:a("labelFont"),fontSize:a("labelFontSize"),fontWeight:a("labelFontWeight"),fontStyle:a("labelFontStyle"),limit:a("labelLimit"),lineHeight:a("labelLineHeight")},{align:h,baseline:d});const k=a("labelBound");let A=a("labelOverlap");return A=A||k?{separation:a("labelSeparation"),method:A,order:"datum.index",bound:k?{scale:u,orient:s,tolerance:k}:null}:void 0,p.align!==v&&(p.align=UU(p.align,v)),p.baseline!==_&&(p.baseline=UU(p.baseline,_)),DR({type:xR,role:nN,style:iR,key:rR,from:r,encode:w,overlap:A},n)}(t,n,r.labels,h,p,l))),c.domain&&d.push(function(t,e,n,r){const i=wR(t,e),o=t.orient;let a,s;const u={enter:a={opacity:pR},update:s={opacity:gR},exit:{opacity:pR}};Gz(u,{stroke:i("domainColor"),strokeCap:i("domainCap"),strokeDash:i("domainDash"),strokeDashOffset:i("domainDashOffset"),strokeWidth:i("domainWidth"),strokeOpacity:i("domainOpacity")});const l=PU(t,0),c=PU(t,1);return a.x=s.x=kU(o,l,pR),a.x2=s.x2=kU(o,c),a.y=s.y=AU(o,l,pR),a.y2=s.y2=AU(o,c),DR({type:vR,role:tN,from:r,encode:u},n)}(t,n,r.domain,f)),c.title&&d.push(WU(t,n,r.title,f)),fU(bR({role:Kz,from:f,encode:Vz(YU(u,t),i,dR),marks:d,aria:u("aria"),description:u("description"),zindex:u("zindex"),name:o,interactive:a,style:s}),e)}function YU(t,e){const n={enter:{},update:{}};return Gz(n,{orient:t("orient"),offset:t("offset")||0,position:jN(e.position,0),titlePadding:t("titlePadding"),minExtent:t("minExtent"),maxExtent:t("maxExtent"),range:{signal:`abs(span(range("${e.scale}")))`},translate:t("translate"),format:e.format,formatType:e.formatType}),n}function GU(t,e,n){const r=V(t.signals),i=V(t.scales);return n||r.forEach((t=>CN(t,e))),V(t.projections).forEach((t=>function(t,e){const n=e.config.projection||{},r={};for(const n in 
t)"name"!==n&&(r[n]=IO(t[n],n,e));for(const t in n)null==r[t]&&(r[t]=IO(n[t],t,e));e.addProjection(t.name,r)}(t,e))),i.forEach((t=>function(t,e){const n=t.type||"linear";Jd(n)||s("Unrecognized scale type: "+Ct(n)),e.addScale(t.name,{type:n,domain:void 0})}(t,e))),V(t.data).forEach((t=>vU(t,e))),i.forEach((t=>SO(t,e))),(n||r).forEach((t=>function(t,e){const n=e.getSignal(t.name);let r=t.update;t.init&&(r?s("Signals can not include both init and update expressions."):(r=t.init,n.initonly=!0)),r&&(r=mB(r,e),n.update=r.$expr,n.params=r.$params),t.on&&t.on.forEach((t=>ZN(t,e,n.id)))}(t,e))),V(t.axes).forEach((t=>HU(t,e))),V(t.marks).forEach((t=>fU(t,e))),V(t.legends).forEach((t=>hU(t,e))),t.title&&mU(t.title,e),e.parseLambdas(),e}const VU=t=>Vz({enter:{x:{value:0},y:{value:0}},update:{width:{signal:"width"},height:{signal:"height"}}},t);function XU(t,e){const n=e.config,r=TN(e.root=e.add($N())),i=function(t,e){const n=n=>jN(t[n],e[n]),r=[JU("background",n("background")),JU("autosize",Pz(n("autosize"))),JU("padding",Wz(n("padding"))),JU("width",n("width")||0),JU("height",n("height")||0)],i=r.reduce(((t,e)=>(t[e.name]=e,t)),{}),o={};return V(t.signals).forEach((t=>{lt(i,t.name)?t=ot(i[t.name],t):r.push(t),o[t.name]=t})),V(e.signals).forEach((t=>{lt(o,t.name)||lt(i,t.name)||r.push(t)})),r}(t,n);i.forEach((t=>CN(t,e))),e.description=t.description||n.description,e.eventConfig=n.events,e.legends=e.objectProperty(n.legend&&n.legend.layout),e.locale=n.locale;const o=e.add(rO()),a=e.add(aO(wN(VU(t.encode),mR,Zz,t.style,e,{pulse:TN(o)}))),s=e.add(EO({layout:e.objectProperty(t.layout),legends:e.legends,autosize:e.signalRef("autosize"),mark:r,pulse:TN(a)}));e.operators.pop(),e.pushState(TN(a),TN(s),null),GU(t,e,i),e.operators.push(s);let u=e.add(nO({mark:r,pulse:TN(s)}));return u=e.add(wO({pulse:TN(u)})),u=e.add(AO({pulse:TN(u)})),e.addData("root",new aU(e,o,o,u)),e}function JU(t,e){return e&&e.signal?{name:t,update:e.signal}:{name:t,value:e}}function 
ZU(t,e){this.config=t||{},this.options=e||{},this.bindings=[],this.field={},this.signals={},this.lambdas={},this.scales={},this.events={},this.data={},this.streams=[],this.updates=[],this.operators=[],this.eventConfig=null,this.locale=null,this._id=0,this._subid=0,this._nextsub=[0],this._parent=[],this._encode=[],this._lookup=[],this._markpath=[]}function QU(t){this.config=t.config,this.options=t.options,this.legends=t.legends,this.field=Object.create(t.field),this.signals=Object.create(t.signals),this.lambdas=Object.create(t.lambdas),this.scales=Object.create(t.scales),this.events=Object.create(t.events),this.data=Object.create(t.data),this.streams=[],this.updates=[],this.operators=[],this._id=0,this._subid=++t._nextsub[0],this._nextsub=t._nextsub,this._parent=t._parent.slice(),this._encode=t._encode.slice(),this._lookup=t._lookup.slice(),this._markpath=t._markpath}function KU(t){return(k(t)?tL:eL)(t)}function tL(t){const e=t.length;let n="[";for(let r=0;r0?",":"")+(A(e)?e.signal||KU(e):Ct(e))}return n+"]"}function eL(t){let e,n,r="{",i=0;for(e in t)n=t[e],r+=(++i>1?",":"")+Ct(e)+":"+(A(n)?n.signal||KU(n):Ct(n));return r+"}"}ZU.prototype=QU.prototype={parse(t){return GU(t,this)},fork(){return new QU(this)},isSubscope(){return this._subid>0},toRuntime(){return this.finish(),{description:this.description,operators:this.operators,streams:this.streams,updates:this.updates,bindings:this.bindings,eventConfig:this.eventConfig,locale:this.locale}},id(){return(this._subid?this._subid+":":0)+this._id++},add(t){return this.operators.push(t),t.id=this.id(),t.refs&&(t.refs.forEach((e=>{e.$ref=t.id})),t.refs=null),t},proxy(t){const e=t instanceof FN?TN(t):t;return this.add(xO({value:e}))},addStream(t){return this.streams.push(t),t.id=this.id(),t},addUpdate(t){return this.updates.push(t),t},finish(){let t,e;for(t in this.root&&(this.root.root=!0),this.signals)this.signals[t].signal=t;for(t in this.scales)this.scales[t].scale=t;function n(t,e,n){let 
r,i;t&&(r=t.data||(t.data={}),i=r[e]||(r[e]=[]),i.push(n))}for(t in this.data){e=this.data[t],n(e.input,t,"input"),n(e.output,t,"output"),n(e.values,t,"values");for(const r in e.index)n(e.index[r],t,"index:"+r)}return this},pushState(t,e,n){this._encode.push(TN(this.add(AO({pulse:t})))),this._parent.push(e),this._lookup.push(n?TN(this.proxy(n)):null),this._markpath.push(-1)},popState(){this._encode.pop(),this._parent.pop(),this._lookup.pop(),this._markpath.pop()},parent(){return F(this._parent)},encode(){return F(this._encode)},lookup(){return F(this._lookup)},markpath(){const t=this._markpath;return++t[t.length-1]},fieldRef(t,e){if(xt(t))return BN(t,e);t.signal||s("Unsupported field reference: "+Ct(t));const n=t.signal;let r=this.field[n];if(!r){const t={name:this.signalRef(n)};e&&(t.as=e),this.field[n]=r=TN(this.add(lO(t)))}return r},compareRef(t){let e=!1;const n=t=>qN(t)?(e=!0,this.signalRef(t.signal)):function(t){return t&&t.expr}(t)?(e=!0,this.exprRef(t.expr)):t,r=V(t.field).map(n),i=V(t.order).map(n);return e?TN(this.add(iO({fields:r,orders:i}))):NN(r,i)},keyRef(t,e){let n=!1;const r=this.signals;return t=V(t).map((t=>qN(t)?(n=!0,TN(r[t.signal])):t)),n?TN(this.add(cO({fields:t,flat:e}))):function(t,e){const n={$key:t};return e&&(n.$flat=!0),n}(t,e)},sortRef(t){if(!t)return t;const e=RN(t.op,t.field),n=t.order||"ascending";return n.signal?TN(this.add(iO({fields:e,orders:this.signalRef(n.signal)}))):NN(e,n)},event(t,e){const n=t+":"+e;if(!this.events[n]){const r=this.id();this.streams.push({id:r,source:t,type:e}),this.events[n]=r}return this.events[n]},hasOwnSignal(t){return lt(this.signals,t)},addSignal(t,e){this.hasOwnSignal(t)&&s("Duplicate signal name: "+Ct(t));const n=e instanceof FN?e:this.add($N(e));return this.signals[t]=n},getSignal(t){return this.signals[t]||s("Unrecognized signal name: "+Ct(t)),this.signals[t]},signalRef(t){return 
this.signals[t]?TN(this.signals[t]):(lt(this.lambdas,t)||(this.lambdas[t]=this.add($N(null))),TN(this.lambdas[t]))},parseLambdas(){const t=Object.keys(this.lambdas);for(let e=0,n=t.length;er+Math.floor(o*t.random()),pdf:t=>t===Math.floor(t)&&t>=r&&t=i?1:(e-r+1)/o},icdf:t=>t>=0&&t<=1?r-1+Math.floor(t*o):NaN};return a.min(e).max(n)},t.randomKDE=gs,t.randomLCG=function(t){return function(){return(t=(1103515245*t+12345)%2147483647)/2147483647}},t.randomLogNormal=xs,t.randomMixture=bs,t.randomNormal=ps,t.randomUniform=Es,t.read=ca,t.regressionConstant=Ds,t.regressionExp=zs,t.regressionLinear=Ts,t.regressionLoess=Ls,t.regressionLog=Bs,t.regressionPoly=Rs,t.regressionPow=Ns,t.regressionQuad=Os,t.renderModule=b_,t.repeat=Mt,t.resetDefaultLocale=function(){return Co(),Bo(),Uo()},t.resetSVGClipId=Ng,t.resetSVGDefIds=function(){Ng(),Op=0},t.responseType=la,t.runtimeContext=MB,t.sampleCurve=Is,t.sampleLogNormal=ms,t.sampleNormal=cs,t.sampleUniform=ws,t.scale=Xd,t.sceneEqual=F_,t.sceneFromJSON=qy,t.scenePickVisit=Fm,t.sceneToJSON=Ly,t.sceneVisit=Cm,t.sceneZOrder=Dm,t.scheme=dp,t.serializeXML=Wv,t.setHybridRendererOptions=function(t){h_.svgMarkTypes=t.svgMarkTypes??["text"],h_.svgOnTop=t.svgOnTop??!0,h_.debug=t.debug??!1},t.setRandom=function(e){t.random=e},t.span=Dt,t.splitAccessPath=u,t.stringValue=Ct,t.textMetrics=py,t.timeBin=Jr,t.timeFloor=wr,t.timeFormatLocale=No,t.timeInterval=Cr,t.timeOffset=$r,t.timeSequence=zr,t.timeUnitSpecifier=rr,t.timeUnits=er,t.toBoolean=Ft,t.toDate=$t,t.toNumber=S,t.toSet=Bt,t.toString=Tt,t.transform=Ka,t.transforms=Za,t.truncate=zt,t.truthy=p,t.tupleid=ya,t.typeParsers=Zo,t.utcFloor=Mr,t.utcInterval=Fr,t.utcOffset=Tr,t.utcSequence=Nr,t.utcdayofyear=hr,t.utcquarter=G,t.utcweek=dr,t.version="5.27.0",t.visitArray=Nt,t.week=sr,t.writeConfig=D,t.zero=h,t.zoomLinear=j,t.zoomLog=I,t.zoomPow=W,t.zoomSymlog=H})); //# sourceMappingURL=vega.min.js.map ================================================ FILE: docs/autogen_rst.py 
================================================ import logging import os import shutil from pathlib import Path log = logging.getLogger(os.path.basename(__file__)) def module_template(module_qualname: str): module_name = module_qualname.split(".")[-1] title = module_name.replace("_", r"\_") return f"""{title} {"=" * len(title)} .. automodule:: {module_qualname} :members: :undoc-members: """ def index_template(package_name: str, doc_references: list[str] | None = None, text_prefix=""): doc_references = doc_references or "" if doc_references: doc_references = "\n" + "\n".join(f"* :doc:`{ref}`" for ref in doc_references) + "\n" dirname = package_name.split(".")[-1] title = dirname.replace("_", r"\_") if title == "tianshou": title = "Tianshou API Reference" return f"{title}\n{'=' * len(title)}" + text_prefix + doc_references def write_to_file(content: str, path: str): os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w") as f: f.write(content) os.chmod(path, 0o666) _SUBTITLE = ( "\n Here is the autogenerated documentation of the Tianshou API. \n \n " "The Table of Contents to the left has the same structure as the " "repository's package code. The links at each page point to the submodules and subpackages. \n\n " "Enjoy scrolling through! \n" ) def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""): """Creates/updates documentation in form of rst files for modules and packages. Does not delete any existing rst files. Thus, rst files for packages or modules that have been removed or renamed should be deleted by hand. This method should be executed from the project's top-level directory :param src_root: path to library base directory, typically "src/" :param clean: whether to completely clean the target directory beforehand, removing any existing .rst files :param overwrite: whether to overwrite existing rst files. 
This should be used with caution as it will delete all manual changes to documentation files :package_prefix: a prefix to prepend to each module (for the case where the src_root is not the base package), which, if not empty, should end with a "." :return: """ rst_root = os.path.abspath(rst_root) if clean and os.path.isdir(rst_root): shutil.rmtree(rst_root) base_package_name = package_prefix + os.path.basename(src_root) # TODO: reduce duplication with same logic for subpackages below files_in_dir = os.listdir(src_root) module_names = [f[:-3] for f in files_in_dir if f.endswith(".py") and not f.startswith("_")] subdir_refs = [ f"{f}/index" for f in files_in_dir if os.path.isdir(os.path.join(src_root, f)) and not f.startswith("_") and not f.startswith(".") ] package_index_rst_path = os.path.join( rst_root, "index.rst", ) log.info(f"Writing {package_index_rst_path}") write_to_file( index_template( base_package_name, doc_references=module_names + subdir_refs, text_prefix=_SUBTITLE, ), package_index_rst_path, ) for root, dirnames, filenames in os.walk(src_root): if os.path.basename(root).startswith("_"): continue base_package_relpath = os.path.relpath(root, start=src_root) base_package_qualname = package_prefix + os.path.relpath( root, start=os.path.dirname(src_root), ).replace(os.path.sep, ".") for dirname in dirnames: if dirname.startswith("_"): log.debug(f"Skipping {dirname}") continue files_in_dir = os.listdir(os.path.join(root, dirname)) module_names = [ f[:-3] for f in files_in_dir if f.endswith(".py") and not f.startswith("_") ] subdir_refs = [ f"{f}/index" for f in files_in_dir if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_") ] if not module_names and "__init__.py" not in files_in_dir: log.debug(f"Skipping {dirname} as it does not contain any modules or __init__.py") continue package_qualname = f"{base_package_qualname}.{dirname}" package_index_rst_path = os.path.join( rst_root, base_package_relpath, dirname, "index.rst", ) 
log.info(f"Writing {package_index_rst_path}") write_to_file( index_template(package_qualname, doc_references=module_names + subdir_refs), package_index_rst_path, ) for filename in filenames: base_name, ext = os.path.splitext(filename) if ext == ".py" and not filename.startswith("_"): module_qualname = f"{base_package_qualname}.{filename[:-3]}" module_rst_path = os.path.join(rst_root, base_package_relpath, f"{base_name}.rst") if os.path.exists(module_rst_path) and not overwrite: log.debug(f"{module_rst_path} already exists, skipping it") log.info(f"Writing module documentation to {module_rst_path}") write_to_file(module_template(module_qualname), module_rst_path) if __name__ == "__main__": logging.basicConfig(level=logging.INFO) docs_root = Path(__file__).parent make_rst( docs_root / ".." / "tianshou", docs_root / "03_api", clean=True, ) ================================================ FILE: docs/bibtex.json ================================================ { "cited": { "tutorials/dqn": [ "DQN", "DDPG", "PPO" ] } } ================================================ FILE: docs/create_toc.py ================================================ import os from pathlib import Path # This script provides a platform-independent way of making the jupyter-book call (used in pyproject.toml) toc_file = Path(__file__).parent / "_toc.yml" cmd = f'jupyter-book toc from-project docs -e .rst -e .md -e .ipynb >"{toc_file}"' print(cmd) os.system(cmd) ================================================ FILE: docs/index.rst ================================================ Welcome to Tianshou! ==================== **Tianshou** (`天授 `_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. 
The supported interface algorithms include: * :class:`~tianshou.algorithm.modelfree.dqn.DQN` `Deep Q-Network `_ * :class:`~tianshou.algorithm.modelfree.dqn.DQN` `Double DQN `_ * :class:`~tianshou.algorithm.modelfree.dqn.DQN` `Dueling DQN `_ * :class:`~tianshou.algorithm.modelfree.bdqn.BDQN` `Branching DQN `_ * :class:`~tianshou.algorithm.modelfree.c51.C51` `Categorical DQN `_ * :class:`~tianshou.algorithm.modelfree.rainbow.RainbowDQN` `Rainbow DQN `_ * :class:`~tianshou.algorithm.modelfree.qrdqn.QRDQN` `Quantile Regression DQN `_ * :class:`~tianshou.algorithm.modelfree.iqn.IQN` `Implicit Quantile Network `_ * :class:`~tianshou.algorithm.modelfree.fqf.FQF` `Fully-parameterized Quantile Function `_ * :class:`~tianshou.algorithm.modelfree.reinforce.Reinforce` `Reinforce/Vanilla Policy Gradients `_ * :class:`~tianshou.algorithm.modelfree.npg.NPG` `Natural Policy Gradient `_ * :class:`~tianshou.algorithm.modelfree.a2c.A2C` `Advantage Actor-Critic `_ * :class:`~tianshou.algorithm.modelfree.trpo.TRPO` `Trust Region Policy Optimization `_ * :class:`~tianshou.algorithm.modelfree.ppo.PPO` `Proximal Policy Optimization `_ * :class:`~tianshou.algorithm.modelfree.ddpg.DDPG` `Deep Deterministic Policy Gradient `_ * :class:`~tianshou.algorithm.modelfree.td3.TD3` `Twin Delayed DDPG `_ * :class:`~tianshou.algorithm.modelfree.sac.SAC` `Soft Actor-Critic `_ * :class:`~tianshou.algorithm.modelfree.redq.REDQ` `Randomized Ensembled Double Q-Learning `_ * :class:`~tianshou.algorithm.modelfree.discrete_sac.DiscreteSAC` `Discrete Soft Actor-Critic `_ * :class:`~tianshou.algorithm.imitation.imitation_base.ImitationPolicy` Imitation Learning * :class:`~tianshou.algorithm.imitation.bcq.BCQ` `Batch-Constrained deep Q-Learning `_ * :class:`~tianshou.algorithm.imitation.cql.CQL` `Conservative Q-Learning `_ * :class:`~tianshou.algorithm.imitation.td3_bc.TD3BC` `Twin Delayed DDPG with Behavior Cloning `_ * :class:`~tianshou.algorithm.imitation.discrete_cql.DiscreteCQL` `Discrete Conservative 
Q-Learning `_ * :class:`~tianshou.algorithm.imitation.discrete_bcq.DiscreteBCQ` `Discrete Batch-Constrained deep Q-Learning `_ * :class:`~tianshou.algorithm.imitation.discrete_crr.DiscreteCRR` `Critic Regularized Regression `_ * :class:`~tianshou.algorithm.imitation.gail.GAIL` `Generative Adversarial Imitation Learning `_ * :class:`~tianshou.algorithm.modelbased.psrl.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.algorithm.modelbased.icm.ICMOffPolicyWrapper`, :class:`~tianshou.algorithm.modelbased.icm.ICMOnPolicyWrapper` `Intrinsic Curiosity Module `_ * :class:`~tianshou.data.buffer.prio.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.algorithm.algorithm_base.Algorithm.compute_episodic_return` `Generalized Advantage Estimator `_ * :class:`~tianshou.data.buffer.her.HERReplayBuffer` `Hindsight Experience Replay `_ Installation ------------ Tianshou is available through `PyPI `_. New releases require Python >= 3.11. Install Tianshou with the following command: .. code-block:: bash $ pip install tianshou Alternatively, install the current version on GitHub: .. code-block:: bash $ pip install git+https://github.com/thu-ml/tianshou.git@master --upgrade After installation, open your python console and type :: import tianshou print(tianshou.__version__) If no error occurs, you have successfully installed Tianshou. 
Indices and tables ------------------ * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/nbstripout.py ================================================ """Implements a platform-independent way of calling nbstripout (used in pyproject.toml).""" import glob import os from pathlib import Path if __name__ == "__main__": docs_dir = Path(__file__).parent for path in glob.glob(str(docs_dir / "02_notebooks" / "*.ipynb")): cmd = f"nbstripout {path}" os.system(cmd) ================================================ FILE: docs/refs.bib ================================================ @article{DQN, author = {Volodymyr Mnih and Koray Kavukcuoglu and David Silver and Andrei A. Rusu and Joel Veness and Marc G. Bellemare and Alex Graves and Martin A. Riedmiller and Andreas Fidjeland and Georg Ostrovski and Stig Petersen and Charles Beattie and Amir Sadik and Ioannis Antonoglou and Helen King and Dharshan Kumaran and Daan Wierstra and Shane Legg and Demis Hassabis}, title = {Human-level control through deep reinforcement learning}, journal = {Nature}, volume = {518}, number = {7540}, pages = {529--533}, year = {2015}, url = {https://doi.org/10.1038/nature14236}, doi = {10.1038/nature14236}, timestamp = {Wed, 14 Nov 2018 10:30:43 +0100}, biburl = {https://dblp.org/rec/journals/nature/MnihKSRVBGRFOPB15.bib}, bibsource = {dblp computer science bibliography, https://dblp.org} } @inproceedings{DDPG, author = {Timothy P. Lillicrap and Jonathan J. 
Hunt and Alexander Pritzel and Nicolas Heess and Tom Erez and Yuval Tassa and David Silver and Daan Wierstra}, title = {Continuous control with deep reinforcement learning}, booktitle = {4th International Conference on Learning Representations, {ICLR} 2016, San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings}, year = {2016}, url = {http://arxiv.org/abs/1509.02971}, timestamp = {Thu, 25 Jul 2019 14:25:37 +0200}, biburl = {https://dblp.org/rec/journals/corr/LillicrapHPHETS15.bib}, bibsource = {dblp computer science bibliography, https://dblp.org} } @article{PPO, author = {John Schulman and Filip Wolski and Prafulla Dhariwal and Alec Radford and Oleg Klimov}, title = {Proximal Policy Optimization Algorithms}, journal = {CoRR}, volume = {abs/1707.06347}, year = {2017}, url = {http://arxiv.org/abs/1707.06347}, archivePrefix = {arXiv}, eprint = {1707.06347}, timestamp = {Mon, 13 Aug 2018 16:47:34 +0200}, biburl = {https://dblp.org/rec/journals/corr/SchulmanWDRK17.bib}, bibsource = {dblp computer science bibliography, https://dblp.org} } ================================================ FILE: examples/__init__.py ================================================ ================================================ FILE: examples/atari/README.md ================================================ # Atari Environment ## EnvPool We highly recommend using envpool to run the following experiments. To install it on a Linux machine, type: ```bash pip install envpool ``` After that, `atari_wrapper` will automatically switch to envpool's Atari env. EnvPool's implementation is much faster (about 2\~3x faster in pure execution speed, 1.5x for the overall RL training pipeline) than the Python vectorized env implementation, and its behavior is consistent with that approach (the OpenAI wrapper), which is described below.
For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool). ## ALE-py The sample speed is \~3000 env steps per second (in fact \~12000 Atari frames per second, since with frame skipping each env step spans 4 Atari frames) in normal mode (using a CNN policy and a collector, and also storing data in the buffer). The env wrapper is crucial: without wrappers, the agent cannot perform well on Atari games. Many existing RL codebases use the [OpenAI wrapper](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py), but it is not the original DeepMind version ([related issue](https://github.com/openai/baselines/issues/240)). Dopamine has a different [wrapper](https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py), but unfortunately it does not work well with our codebase. # DQN (single run) One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| task | best reward | reward curve | parameters | time cost | |-----------------------------|-------------|---------------------------------------|---------------------------------------------------------------------------------|---------------------| | PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) | | BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | | EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | | QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | | MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | | SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | | SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | Note: The `eps_train_final` and `eps_test` in the original DQN paper are 0.1 and 0.01, respectively, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that a smaller eps improves performance. Also, a larger batch size (say 64 instead of 32) helps convergence but slows down training. We haven't tuned this result to the best, so have fun playing with these hyperparameters! # C51 (single run) One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| task | best reward | reward curve | parameters | |-----------------------------|-------------|---------------------------------------|--------------------------------------------------------------------| | PongNoFrameskip-v4 | 20 | ![](results/c51/Pong_rew.png) | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 536.6 | ![](results/c51/Breakout_rew.png) | `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1032 | ![](results/c51/Enduro_rew.png) | `python3 atari_c51.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 16245 | ![](results/c51/Qbert_rew.png) | `python3 atari_c51.py --task "QbertNoFrameskip-v4"` | | MsPacmanNoFrameskip-v4 | 3133 | ![](results/c51/MsPacman_rew.png) | `python3 atari_c51.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` | Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper. # QRDQN (single run) One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| task | best reward | reward curve | parameters | |-----------------------------|-------------|-----------------------------------------|----------------------------------------------------------------------| | PongNoFrameskip-v4 | 20 | ![](results/qrdqn/Pong_rew.png) | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 409.2 | ![](results/qrdqn/Breakout_rew.png) | `python3 atari_qrdqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1055.9 | ![](results/qrdqn/Enduro_rew.png) | `python3 atari_qrdqn.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 14990 | ![](results/qrdqn/Qbert_rew.png) | `python3 atari_qrdqn.py --task "QbertNoFrameskip-v4"` | | MsPacmanNoFrameskip-v4 | 2886 | ![](results/qrdqn/MsPacman_rew.png) | `python3 atari_qrdqn.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 5676 | ![](results/qrdqn/Seaquest_rew.png) | `python3 atari_qrdqn.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 938 | ![](results/qrdqn/SpaceInvader_rew.png) | `python3 atari_qrdqn.py --task "SpaceInvadersNoFrameskip-v4"` | # IQN (single run) One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | |-----------------------------|-------------|----------------------------------------|--------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.3 | ![](results/iqn/Pong_rew.png) | `python3 atari_iqn.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 496.7 | ![](results/iqn/Breakout_rew.png) | `python3 atari_iqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1545 | ![](results/iqn/Enduro_rew.png) | `python3 atari_iqn.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 15342.5 | ![](results/iqn/Qbert_rew.png) | `python3 atari_iqn.py --task "QbertNoFrameskip-v4"` | | MsPacmanNoFrameskip-v4 | 2915 | ![](results/iqn/MsPacman_rew.png) | `python3 atari_iqn.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 4874 | ![](results/iqn/Seaquest_rew.png) | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 1498.5 | ![](results/iqn/SpaceInvaders_rew.png) | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` | # FQF (single run) One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | |-----------------------------|-------------|----------------------------------------|--------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.7 | ![](results/fqf/Pong_rew.png) | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 517.3 | ![](results/fqf/Breakout_rew.png) | `python3 atari_fqf.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 2240.5 | ![](results/fqf/Enduro_rew.png) | `python3 atari_fqf.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 16172.5 | ![](results/fqf/Qbert_rew.png) | `python3 atari_fqf.py --task "QbertNoFrameskip-v4"` | | MsPacmanNoFrameskip-v4 | 2429 | ![](results/fqf/MsPacman_rew.png) | `python3 atari_fqf.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 10775 | ![](results/fqf/Seaquest_rew.png) | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 2482 | ![](results/fqf/SpaceInvaders_rew.png) | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` | # Rainbow (single run) One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | |-----------------------------|-------------|--------------------------------------------|------------------------------------------------------------------------| | PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 684.6 | ![](results/rainbow/Breakout_rew.png) | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1625.9 | ![](results/rainbow/Enduro_rew.png) | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 16192.5 | ![](results/rainbow/Qbert_rew.png) | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` | | MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` | # PPO (single run) One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | |-----------------------------|-------------|----------------------------------------|------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.2 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` | | BreakoutNoFrameskip-v4 | 441.8 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` | | EnduroNoFrameskip-v4 | 1245.4 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 17395 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` | | MsPacmanNoFrameskip-v4 | 2098 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 882 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 1e-4` | | SpaceInvadersNoFrameskip-v4 | 1340.5 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` | # SAC (single run) One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | |-----------------------------|-------------|-------------------------------------------------|----------------------------------------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.1 | ![](results/discrete_sac/Pong_rew.png) | `python3 atari_sac.py --task "PongNoFrameskip-v4"` | | BreakoutNoFrameskip-v4 | 211.2 | ![](results/discrete_sac/Breakout_rew.png) | `python3 atari_sac.py --task "BreakoutNoFrameskip-v4" --n-step 1 --actor-lr 1e-4 --critic-lr 1e-4` | | EnduroNoFrameskip-v4 | 1290.7 | ![](results/discrete_sac/Enduro_rew.png) | `python3 atari_sac.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 13157.5 | ![](results/discrete_sac/Qbert_rew.png) | `python3 atari_sac.py --task "QbertNoFrameskip-v4"` | | MsPacmanNoFrameskip-v4 | 3836 | ![](results/discrete_sac/MsPacman_rew.png) | `python3 atari_sac.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 1772 | ![](results/discrete_sac/Seaquest_rew.png) | `python3 atari_sac.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 649 | ![](results/discrete_sac/SpaceInvaders_rew.png) | `python3 atari_sac.py --task "SpaceInvadersNoFrameskip-v4"` |

================================================
FILE: examples/atari/__init__.py
================================================

================================================
FILE: examples/atari/atari_c51.py
================================================
#!/usr/bin/env python3
"""Train a C51 agent on an Atari task with an off-policy trainer, then watch/evaluate it."""

import datetime
import os
import pprint
import sys

import numpy as np
import torch
from sensai.util import logging

from tianshou.algorithm import C51
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.c51 import C51Policy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env.atari.atari_network import C51Net
from tianshou.env.atari.atari_wrapper import make_atari_env
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OffPolicyTrainerParams

log = logging.getLogger(__name__)


def main(
    task: str = "PongNoFrameskip-v4",
    seed: int = 0,
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    num_atoms: int = 51,
    v_min: float = -10.0,
    v_max: float = 10.0,
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    epoch_num_steps: int = 100000,
    collection_step_num_env_steps: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    num_training_envs: int = 10,
    num_test_envs: int = 10,
    persistence_base_dir: str = "log",
    render: float = 0.0,
    device: str | None = None,
    frames_stack: int = 4,
    resume_path: str | None = None,
    resume_id: str | None = None,
    logger_type: str = "tensorboard",
    wandb_project: str = "atari.benchmark",
    watch: bool = False,
    save_buffer_name: str | None = None,
) -> None:
    """Train a C51 agent on an Atari task, then run ``watch_fn`` on the result.

    :param task: Atari environment id, passed to ``make_atari_env``
    :param eps_train: initial training epsilon; ``train_fn`` decays it linearly to ``eps_train_final``
        over the first 1M env steps and keeps it there afterwards
    :param eps_test: epsilon passed to the policy as ``eps_inference``
    :param num_atoms: C51 distribution parameter, passed to ``C51Net`` and ``C51Policy``
    :param v_min: C51 distribution parameter, passed to ``C51Policy``
    :param v_max: C51 distribution parameter, passed to ``C51Policy``
    :param n_step: passed to ``C51`` as ``n_step_return_horizon``
    :param update_per_step: passed to the trainer as ``update_step_num_gradient_steps_per_sample``
    :param persistence_base_dir: base directory; logs/checkpoints go to ``<base>/<task>/c51/<seed>/<timestamp>``
    :param device: torch device; resolved to "cuda" if available, else "cpu", when None
    :param frames_stack: passed to ``make_atari_env`` as ``frame_stack`` and to the replay buffers as ``stack_num``
    :param resume_path: if given, an algorithm state dict loaded (via ``torch.load``) before training
    :param resume_id: run id forwarded to the logger
    :param logger_type: "wandb" selects the wandb logger; any other value selects tensorboard
    :param watch: if True, only run ``watch_fn`` once and exit without training
    :param save_buffer_name: if given, ``watch_fn`` collects ``buffer_size`` env steps with the agent and
        saves the resulting buffer as HDF5 to this path instead of running test episodes
    """
    # Resolve the device default at call time
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # At this point locals() holds only the function arguments (with `device` already resolved);
    # the dict is logged and later handed to the logger as its config
    params_log_info = locals()
    log.info(f"Starting training with config:\n{params_log_info}")

    env, training_envs, test_envs = make_atari_env(
        task,
        seed,
        num_training_envs,
        num_test_envs,
        scale=scale_obs,
        frame_stack=frames_stack,
    )
    state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
    action_shape = env.action_space.shape or env.action_space.n  # type: ignore

    # should be N_FRAMES x H x W
    log.info(f"Observations shape: {state_shape}")
    log.info(f"Actions shape: {action_shape}")

    # seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # define model (unpacking requires a 3-dimensional observation shape)
    c, h, w = state_shape
    net = C51Net(c=c, h=h, w=w, action_shape=action_shape, num_atoms=num_atoms)

    # define policy and algorithm
    optim = AdamOptimizerFactory(lr=lr)
    policy = C51Policy(
        model=net,
        action_space=env.action_space,
        num_atoms=num_atoms,
        v_min=v_min,
        v_max=v_max,
        eps_training=eps_train,
        eps_inference=eps_test,
    )
    algorithm: C51 = C51(
        policy=policy,
        optim=optim,
        gamma=gamma,
        n_step_return_horizon=n_step,
        target_update_freq=target_update_freq,
    ).to(device)

    # load a previous model
    # NOTE(review): torch.load unpickles the file, so only resume from trusted checkpoints
    if resume_path:
        algorithm.load_state_dict(torch.load(resume_path, map_location=device))
        log.info(f"Loaded agent from: {resume_path}")

    # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        buffer_size,
        buffer_num=len(training_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=frames_stack,
    )

    # collectors
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    algo_name = "c51"
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(persistence_base_dir, log_name)

    # logger
    logger_factory = LoggerFactoryDefault()
    if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=resume_id,
        config_dict=params_log_info,
    )

    # Saves the state dict of the object passed by the trainer to <log_path>/policy.pth.
    # (The parameter is typed as Algorithm; its name shadows the outer `policy` inside this function.)
    def save_best_fn(policy: Algorithm) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    # Stop once the mean reward reaches the env spec's reward threshold; if that threshold is
    # unset (falsy), Pong tasks stop at a mean reward of 20 and all other tasks never stop early.
    def stop_fn(mean_rewards: float) -> bool:
        if env.spec.reward_threshold:  # type: ignore
            return mean_rewards >= env.spec.reward_threshold  # type: ignore
        if "Pong" in task:
            return mean_rewards >= 20
        return False

    # Called by the trainer during training: updates the policy's training epsilon and
    # logs it every 1000 env steps.
    def train_fn(epoch: int, env_step: int) -> None:
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
        else:
            eps = eps_train_final
        policy.set_eps_training(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    # Either collects a buffer with the agent and saves it as HDF5 (if save_buffer_name is set),
    # or runs num_test_envs test episodes; prints the collection stats in both cases.
    def watch_fn() -> None:
        log.info("Setup test envs ...")
        test_envs.seed(seed)
        if save_buffer_name:
            log.info(f"Generate buffer with size {buffer_size}")
            buffer = VectorReplayBuffer(
                buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=frames_stack,
            )
            collector = Collector[CollectStats](
                algorithm, test_envs, buffer, exploration_noise=True
            )
            result = collector.collect(n_step=buffer_size, reset_before_collect=True)
            log.info(f"Save buffer into {save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(save_buffer_name)
        else:
            log.info("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=num_test_envs, render=render)
        result.pprint_asdict()

    if watch:
        watch_fn()
        sys.exit(0)

    # test training_collector and start filling replay buffer
    training_collector.reset()
    training_collector.collect(n_step=batch_size * num_training_envs)

    # trainer
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=epoch,
            epoch_num_steps=epoch_num_steps,
            collection_step_num_env_steps=collection_step_num_env_steps,
            test_step_num_episodes=num_test_envs,
            batch_size=batch_size,
            training_fn=train_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            update_step_num_gradient_steps_per_sample=update_per_step,
            test_in_training=False,
        )
    )

    pprint.pprint(result)
    watch_fn()


if __name__ == "__main__":
    # run_cli exposes the parameters of `main` on the command line
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/atari/atari_dqn.py
================================================
#!/usr/bin/env python3

import datetime
import os
import pprint
import sys

import numpy as np
import torch
from sensai.util import logging

from tianshou.algorithm import DQN
from tianshou.algorithm.algorithm_base import Algorithm
from
tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import IntrinsicCuriosityModule log = logging.getLogger(__name__) def main( task: str = "PongNoFrameskip-v4", seed: int = 0, scale_obs: int = 0, eps_test: float = 0.005, eps_train: float = 1.0, eps_train_final: float = 0.05, buffer_size: int = 100000, lr: float = 0.0001, gamma: float = 0.99, n_step: int = 3, target_update_freq: int = 500, epoch: int = 100, epoch_num_steps: int = 100000, collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 32, num_training_envs: int = 10, num_test_envs: int = 10, persistence_base_dir: str = "log", render: float = 0.0, device: str | None = None, frames_stack: int = 4, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = "tensorboard", wandb_project: str = "atari.benchmark", watch: bool = False, save_buffer_name: str | None = None, icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, ) -> None: # Set defaults for mutable arguments if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config (excluding internal/temporary ones) params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") state_shape: tuple[int, ...] 
| int action_shape: int env, training_envs, test_envs = make_atari_env( task, seed, num_training_envs, num_test_envs, scale=scale_obs, frame_stack=frames_stack, ) state_shape = env.observation_space.shape or env.observation_space.n # type: ignore action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") # seed np.random.seed(seed) torch.manual_seed(seed) # define model c, h, w = state_shape net = DQNet(c=c, h=h, w=w, action_shape=action_shape).to(device) optim = AdamOptimizerFactory(lr=lr) # define policy and algorithm policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=eps_train, eps_inference=eps_test, ) algorithm: DQN | ICMOffPolicyWrapper algorithm = DQN( policy=policy, optim=optim, gamma=gamma, n_step_return_horizon=n_step, target_update_freq=target_update_freq, ) if icm_lr_scale > 0: c, h, w = state_shape feature_net = DQNet(c=c, h=h, w=w, action_shape=action_shape, features_only=True) action_dim = int(np.prod(action_shape)) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( feature_net=feature_net.net, feature_dim=feature_dim, action_dim=action_dim, hidden_sizes=[512], ) icm_optim = AdamOptimizerFactory(lr=lr) algorithm = ICMOffPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=icm_lr_scale, reward_scale=icm_reward_scale, forward_loss_weight=icm_forward_loss_weight, ).to(device) # load a previous model if resume_path: algorithm.load_state_dict(torch.load(resume_path, map_location=device)) log.info(f"Loaded agent from: {resume_path}") # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( buffer_size, buffer_num=len(training_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=frames_stack, ) # collectors training_collector = Collector[CollectStats]( 
algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "dqn_icm" if icm_lr_scale > 0 else "dqn" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec.reward_threshold: # type: ignore return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in task: return mean_rewards >= 20 return False def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final) else: eps = eps_train_final policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save({"model": algorithm.state_dict()}, ckpt_path) return ckpt_path def watch_fn() -> None: log.info("Setup test envs ...") test_envs.seed(seed) if save_buffer_name: log.info(f"Generate buffer with size {buffer_size}") buffer = VectorReplayBuffer( buffer_size, buffer_num=len(test_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=frames_stack, ) collector = Collector[CollectStats]( 
algorithm, test_envs, buffer, exploration_noise=True ) result = collector.collect(n_step=buffer_size, reset_before_collect=True) log.info(f"Save buffer into {save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(save_buffer_name) else: log.info("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=num_test_envs, render=render) result.pprint_asdict() if watch: watch_fn() sys.exit(0) # test training_collector and start filling replay buffer training_collector.reset() training_collector.collect(n_step=batch_size * num_training_envs) # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=num_test_envs, batch_size=batch_size, training_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=update_per_step, test_in_training=False, resume_from_log=resume_id is not None, save_checkpoint_fn=save_checkpoint_fn, ) ) pprint.pprint(result) watch_fn() if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/atari/atari_dqn_hl.py ================================================ #!/usr/bin/env python3 import os from typing import Literal from sensai.util import logging from tianshou.env.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, ) from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( DQNExperimentBuilder, ExperimentConfig, ) from tianshou.highlevel.params.algorithm_params import DQNParams from tianshou.highlevel.trainer import ( EpochTestCallbackDQNSetEps, EpochTrainCallbackDQNEpsLinearDecay, ) def 
main( task: str = "PongNoFrameskip-v4", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "sequential", max_epochs: int = 100, epoch_num_steps: int = 100000, ) -> None: """ Train an agent using DQN on a specified Atari task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the Atari task to train on. :param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. :param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. """ persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OffPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, batch_size=32, num_training_envs=10, num_test_envs=10, buffer_size=100000, collection_step_num_env_steps=10, update_step_num_gradient_steps_per_sample=0.1, replay_buffer_stack_num=4, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, ) env_factory = AtariEnvFactory(task, 4, scale=False) experiment_builder = ( DQNExperimentBuilder(env_factory, experiment_config, training_config) .with_dqn_params( DQNParams( gamma=0.99, n_step_return_horizon=3, lr=0.0001, target_update_freq=500, ), ) .with_model_factory(IntermediateModuleFactoryAtariDQN()) .with_epoch_train_callback( EpochTrainCallbackDQNEpsLinearDecay(1.0, 0.05), ) .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.005)) 
.with_epoch_stop_callback(AtariEpochStopCallback(task)) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/atari/atari_fqf.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint import sys import numpy as np import torch from sensai.util import logging from tianshou.algorithm import FQF from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.fqf import FQFPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction log = logging.getLogger(__name__) def main( task: str = "PongNoFrameskip-v4", seed: int = 3128, scale_obs: int = 0, eps_test: float = 0.005, eps_train: float = 1.0, eps_train_final: float = 0.05, buffer_size: int = 100000, lr: float = 5e-5, fraction_lr: float = 2.5e-9, gamma: float = 0.99, num_fractions: int = 32, num_cosines: int = 64, ent_coef: float = 10.0, hidden_sizes: list | None = None, n_step: int = 3, target_update_freq: int = 500, epoch: int = 100, epoch_num_steps: int = 100000, collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 32, num_training_envs: int = 10, num_test_envs: int = 10, persistence_base_dir: str = "log", render: float = 0.0, device: str | None = None, frames_stack: int = 4, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = "tensorboard", wandb_project: str = "atari.benchmark", watch: 
bool = False, save_buffer_name: str | None = None, ) -> None: # Set defaults for mutable arguments if hidden_sizes is None: hidden_sizes = [512] if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config (excluding internal/temporary ones) params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") env, training_envs, test_envs = make_atari_env( task, seed, num_training_envs, num_test_envs, scale=scale_obs, frame_stack=frames_stack, ) state_shape = env.observation_space.shape or env.observation_space.n # type: ignore action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") # seed np.random.seed(seed) torch.manual_seed(seed) # define model c, h, w = state_shape feature_net = DQNet(c=c, h=h, w=w, action_shape=action_shape, features_only=True) net = FullQuantileFunction( preprocess_net=feature_net, action_shape=action_shape, hidden_sizes=hidden_sizes, num_cosines=num_cosines, ).to(device) optim = AdamOptimizerFactory(lr=lr) fraction_net = FractionProposalNetwork(num_fractions, net.input_dim) fraction_optim = RMSpropOptimizerFactory(lr=fraction_lr) # define policy and algorithm policy = FQFPolicy( model=net, fraction_model=fraction_net, action_space=env.action_space, ) algorithm: FQF = FQF( policy=policy, optim=optim, fraction_optim=fraction_optim, gamma=gamma, num_fractions=num_fractions, ent_coef=ent_coef, n_step_return_horizon=n_step, target_update_freq=target_update_freq, ).to(device) # load a previous policy if resume_path: algorithm.load_state_dict(torch.load(resume_path, map_location=device)) log.info(f"Loaded agent from: {resume_path}") # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( buffer_size, buffer_num=len(training_envs), ignore_obs_next=True, 
save_only_last_obs=True, stack_num=frames_stack, ) # collectors training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "fqf" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec.reward_threshold: # type: ignore return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in task: return mean_rewards >= 20 return False def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final) else: eps = eps_train_final policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps_training(eps_test) def watch_fn() -> None: log.info("Setup test envs ...") policy.set_eps_training(eps_test) test_envs.seed(seed) if save_buffer_name: log.info(f"Generate buffer with size {buffer_size}") buffer = VectorReplayBuffer( buffer_size, buffer_num=len(test_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=frames_stack, ) collector = Collector[CollectStats]( algorithm, test_envs, buffer, exploration_noise=True ) result = 
collector.collect(n_step=buffer_size, reset_before_collect=True) log.info(f"Save buffer into {save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(save_buffer_name) else: log.info("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=num_test_envs, render=render) result.pprint_asdict() if watch: watch_fn() sys.exit(0) # test training_collector and start filling replay buffer training_collector.reset() training_collector.collect(n_step=batch_size * num_training_envs) # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=num_test_envs, batch_size=batch_size, training_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=update_per_step, test_in_training=False, ) ) pprint.pprint(result) watch_fn() if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/atari/atari_iqn.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint import sys import numpy as np import torch from sensai.util import logging from tianshou.algorithm import IQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.iqn import IQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import ImplicitQuantileNetwork log = 
logging.getLogger(__name__) def main( task: str = "PongNoFrameskip-v4", seed: int = 1234, scale_obs: int = 0, eps_test: float = 0.005, eps_train: float = 1.0, eps_train_final: float = 0.05, buffer_size: int = 100000, lr: float = 0.0001, gamma: float = 0.99, sample_size: int = 32, online_sample_size: int = 8, target_sample_size: int = 8, num_cosines: int = 64, hidden_sizes: list | None = None, n_step: int = 3, target_update_freq: int = 500, epoch: int = 100, epoch_num_steps: int = 100000, collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 32, num_training_envs: int = 10, num_test_envs: int = 10, persistence_base_dir: str = "log", render: float = 0.0, device: str | None = None, frames_stack: int = 4, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = "tensorboard", wandb_project: str = "atari.benchmark", watch: bool = False, save_buffer_name: str | None = None, ) -> None: # Set defaults for mutable arguments if hidden_sizes is None: hidden_sizes = [512] if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config (excluding internal/temporary ones) params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") env, training_envs, test_envs = make_atari_env( task, seed, num_training_envs, num_test_envs, scale=scale_obs, frame_stack=frames_stack, ) state_shape = env.observation_space.shape or env.observation_space.n # type: ignore action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") # seed np.random.seed(seed) torch.manual_seed(seed) # define model c, h, w = state_shape feature_net = DQNet(c=c, h=h, w=w, action_shape=action_shape, features_only=True) net = ImplicitQuantileNetwork( preprocess_net=feature_net, action_shape=action_shape, hidden_sizes=hidden_sizes, num_cosines=num_cosines, 
).to(device) optim = AdamOptimizerFactory(lr=lr) # define policy and algorithm policy = IQNPolicy( model=net, action_space=env.action_space, sample_size=sample_size, online_sample_size=online_sample_size, target_sample_size=target_sample_size, eps_training=eps_train, eps_inference=eps_test, ) algorithm: IQN = IQN( policy=policy, optim=optim, gamma=gamma, n_step_return_horizon=n_step, target_update_freq=target_update_freq, ).to(device) # load previous model if resume_path: algorithm.load_state_dict(torch.load(resume_path, map_location=device)) log.info(f"Loaded agent from: {resume_path}") # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( buffer_size, buffer_num=len(training_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=frames_stack, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "iqn" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec.reward_threshold: # type: ignore return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in task: return mean_rewards >= 20 return False def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 
1M steps if env_step <= 1e6: eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final) else: eps = eps_train_final policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) # watch agent's performance def watch_fn() -> None: log.info("Setup test envs ...") test_envs.seed(seed) if save_buffer_name: log.info(f"Generate buffer with size {buffer_size}") buffer = VectorReplayBuffer( buffer_size, buffer_num=len(test_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=frames_stack, ) collector = Collector[CollectStats]( algorithm, test_envs, buffer, exploration_noise=True ) result = collector.collect(n_step=buffer_size, reset_before_collect=True) log.info(f"Save buffer into {save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(save_buffer_name) else: log.info("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=num_test_envs, render=render) result.pprint_asdict() if watch: watch_fn() sys.exit(0) # test training_collector and start filling replay buffer training_collector.reset() training_collector.collect(n_step=batch_size * num_training_envs) # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=num_test_envs, batch_size=batch_size, training_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=update_per_step, test_in_training=False, ) ) pprint.pprint(result) watch_fn() if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/atari/atari_iqn_hl.py ================================================ #!/usr/bin/env python3 import os from typing import Literal from 
sensai.util import logging from tianshou.env.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, ) from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, IQNExperimentBuilder, ) from tianshou.highlevel.params.algorithm_params import IQNParams from tianshou.highlevel.trainer import ( EpochTrainCallbackDQNEpsLinearDecay, ) def main( task: str = "PongNoFrameskip-v4", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "sequential", max_epochs: int = 100, epoch_num_steps: int = 100000, ) -> None: """ Train an agent using IQN on a specified Atari task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the Atari task to train on. :param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. :param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. 
""" persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OffPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, batch_size=32, num_training_envs=10, num_test_envs=10, buffer_size=100000, collection_step_num_env_steps=10, update_step_num_gradient_steps_per_sample=0.1, replay_buffer_stack_num=4, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, ) env_factory = AtariEnvFactory(task, 4, scale=False) experiment_builder = ( IQNExperimentBuilder(env_factory, experiment_config, training_config) .with_iqn_params( IQNParams( gamma=0.99, n_step_return_horizon=3, lr=0.0001, sample_size=32, online_sample_size=8, target_update_freq=500, target_sample_size=8, hidden_sizes=(512,), num_cosines=64, eps_training=1.0, eps_inference=0.005, ), ) .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True)) .with_epoch_train_callback( EpochTrainCallbackDQNEpsLinearDecay(1.0, 0.05), ) .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/atari/atari_ppo.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint import sys from collections.abc import Sequence from typing import cast import numpy as np import torch from sensai.util import logging from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import 
Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import ( DQNet, ScaledObsInputActionReprNet, layer_init, ) from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import ( DiscreteActor, DiscreteCritic, IntrinsicCuriosityModule, ) log = logging.getLogger(__name__) def main( task: str = "PongNoFrameskip-v4", seed: int = 4213, scale_obs: int = 1, buffer_size: int = 100000, lr: float = 2.5e-4, gamma: float = 0.99, epoch: int = 100, epoch_num_steps: int = 100000, collection_step_num_env_steps: int = 1000, update_step_num_repetitions: int = 4, batch_size: int = 256, hidden_size: int = 512, num_training_envs: int = 10, num_test_envs: int = 10, return_scaling: bool = False, vf_coef: float = 0.25, ent_coef: float = 0.01, gae_lambda: float = 0.95, lr_decay: int = True, max_grad_norm: float = 0.5, eps_clip: float = 0.1, dual_clip: float | None = None, value_clip: bool = True, advantage_normalization: bool = True, recompute_adv: bool = False, persistence_base_dir: str = "log", render: float = 0.0, device: str | None = None, frames_stack: int = 4, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = "tensorboard", wandb_project: str = "atari.benchmark", watch: bool = False, save_buffer_name: str | None = None, icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, ) -> None: # Set defaults for mutable arguments if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config (excluding internal/temporary ones) params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") state_shape: tuple[int, ...] 
| int action_shape: Sequence[int] | int env, training_envs, test_envs = make_atari_env( task, seed, num_training_envs, num_test_envs, scale=scale_obs, frame_stack=frames_stack, ) state_shape = cast(tuple[int, ...], env.observation_space.shape) action_shape = cast(Sequence[int] | int, env.action_space.shape or env.action_space.n) # type: ignore # should be N_FRAMES x H x W log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") # seed np.random.seed(seed) torch.manual_seed(seed) # define model c, h, w = state_shape net: ScaledObsInputActionReprNet | DQNet net = DQNet( c=c, h=h, w=w, action_shape=action_shape, features_only=True, output_dim_added_layer=hidden_size, layer_init=layer_init, ) if scale_obs: net = ScaledObsInputActionReprNet(net) actor = DiscreteActor(preprocess_net=net, action_shape=action_shape, softmax_output=False) critic = DiscreteCritic(preprocess_net=net) optim = AdamOptimizerFactory(lr=lr, eps=1e-5) if lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, ) ) # define algorithm policy = DiscreteActorPolicy( actor=actor, action_space=env.action_space, ) algorithm: PPO = PPO( policy=policy, critic=critic, optim=optim, gamma=gamma, gae_lambda=gae_lambda, max_grad_norm=max_grad_norm, vf_coef=vf_coef, ent_coef=ent_coef, return_scaling=return_scaling, eps_clip=eps_clip, value_clip=value_clip, dual_clip=dual_clip, advantage_normalization=advantage_normalization, recompute_advantage=recompute_adv, ).to(device) if icm_lr_scale > 0: c, h, w = state_shape feature_net = DQNet(c=c, h=h, w=w, action_shape=action_shape, features_only=True) action_dim = int(np.prod(action_shape)) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( feature_net=feature_net.net, feature_dim=feature_dim, action_dim=action_dim, hidden_sizes=[hidden_size], ) icm_optim = AdamOptimizerFactory(lr=lr) algorithm 
= ICMOnPolicyWrapper( # type: ignore[assignment] wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=icm_lr_scale, reward_scale=icm_reward_scale, forward_loss_weight=icm_forward_loss_weight, ).to(device) # load a previous policy if resume_path: algorithm.load_state_dict(torch.load(resume_path, map_location=device)) log.info(f"Loaded agent from: {resume_path}") # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( buffer_size, buffer_num=len(training_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=frames_stack, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "ppo_icm" if icm_lr_scale > 0 else "ppo" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec.reward_threshold: # type: ignore return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in task: return mean_rewards >= 20 return False def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save({"model": algorithm.state_dict()}, ckpt_path) 
return ckpt_path def watch_fn() -> None: log.info("Setup test envs ...") test_envs.seed(seed) if save_buffer_name: log.info(f"Generate buffer with size {buffer_size}") buffer = VectorReplayBuffer( buffer_size, buffer_num=len(test_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=frames_stack, ) collector = Collector[CollectStats]( algorithm, test_envs, buffer, exploration_noise=True ) result = collector.collect(n_step=buffer_size, reset_before_collect=True) log.info(f"Save buffer into {save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(save_buffer_name) else: log.info("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=num_test_envs, render=render) result.pprint_asdict() if watch: watch_fn() sys.exit(0) # test training_collector and start filling replay buffer training_collector.reset() training_collector.collect(n_step=batch_size * num_training_envs) # train result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, update_step_num_repetitions=update_step_num_repetitions, test_step_num_episodes=num_test_envs, batch_size=batch_size, collection_step_num_env_steps=collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=False, resume_from_log=resume_id is not None, save_checkpoint_fn=save_checkpoint_fn, ) ) pprint.pprint(result) watch_fn() if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/atari/atari_ppo_hl.py ================================================ #!/usr/bin/env python3 import os from typing import Literal from sensai.util import logging from tianshou.env.atari.atari_network import ( ActorFactoryAtariDQN, ) from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from 
tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, ) from tianshou.highlevel.params.algorithm_params import PPOParams from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( task: str = "PongNoFrameskip-v4", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "sequential", max_epochs: int = 100, epoch_num_steps: int = 100000, ) -> None: """ Train an agent using PPO on a specified Atari task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the Atari task to train on. :param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. :param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. 
""" persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OnPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, batch_size=256, num_training_envs=10, num_test_envs=10, buffer_size=100000, collection_step_num_env_steps=1000, update_step_num_repetitions=4, replay_buffer_stack_num=4, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, ) env_factory = AtariEnvFactory(task, 4, scale=True) experiment_builder = ( PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( gamma=0.99, gae_lambda=0.95, return_scaling=False, ent_coef=0.01, vf_coef=0.25, max_grad_norm=0.5, value_clip=True, advantage_normalization=True, eps_clip=0.1, dual_clip=None, recompute_advantage=False, lr=2.5e-4, lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config), ), ) .with_actor_factory(ActorFactoryAtariDQN(scale_obs=True, features_only=True)) .with_critic_factory_use_actor() .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/atari/atari_qrdqn.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint import sys import numpy as np import torch from sensai.util import logging from tianshou.algorithm import QRDQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import QRDQNet from tianshou.env.atari.atari_wrapper import make_atari_env from 
tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OffPolicyTrainerParams

log = logging.getLogger(__name__)


def main(
    task: str = "PongNoFrameskip-v4",
    seed: int = 0,
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    num_quantiles: int = 200,
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    epoch_num_steps: int = 100000,
    collection_step_num_env_steps: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    num_training_envs: int = 10,
    num_test_envs: int = 10,
    persistence_base_dir: str = "log",
    render: float = 0.0,
    device: str | None = None,
    frames_stack: int = 4,
    resume_path: str | None = None,
    resume_id: str | None = None,
    logger_type: str = "tensorboard",
    wandb_project: str = "atari.benchmark",
    watch: bool = False,
    save_buffer_name: str | None = None,
) -> None:
    """
    Train a QRDQN agent on an Atari task with the procedural API, or only evaluate it (``watch``).

    During training, epsilon-greedy exploration decays linearly from ``eps_train`` to
    ``eps_train_final`` over the first 1M environment steps (see ``train_fn``).

    :param task: environment id passed to ``make_atari_env``; also part of the log path.
    :param seed: seed for env creation, ``np.random`` and ``torch``; also part of the log path.
    :param scale_obs: forwarded as ``scale`` to ``make_atari_env``.
    :param eps_test: epsilon of the policy at inference time (``eps_inference``).
    :param eps_train: initial training epsilon.
    :param eps_train_final: training epsilon reached after 1M environment steps.
    :param buffer_size: replay buffer capacity; also the number of steps collected when
        ``save_buffer_name`` is set.
    :param lr: Adam learning rate.
    :param gamma: discount factor.
    :param num_quantiles: number of quantiles, given to both ``QRDQNet`` and ``QRDQN``.
    :param n_step: n-step return horizon.
    :param target_update_freq: forwarded to ``QRDQN`` as ``target_update_freq``.
    :param epoch: maximum number of training epochs.
    :param epoch_num_steps: number of environment steps per epoch.
    :param collection_step_num_env_steps: forwarded to the trainer params.
    :param update_per_step: forwarded as ``update_step_num_gradient_steps_per_sample``.
    :param batch_size: training batch size; before training the buffer is pre-filled with
        ``batch_size * num_training_envs`` steps.
    :param num_training_envs: number of training environments.
    :param num_test_envs: number of test environments; also the number of episodes per test.
    :param persistence_base_dir: root directory; logs go to
        ``<persistence_base_dir>/<task>/qrdqn/<seed>/<timestamp>``.
    :param render: forwarded as ``render`` when collecting test episodes in ``watch_fn``.
    :param device: torch device; ``None`` selects CUDA if available, otherwise CPU.
    :param frames_stack: number of stacked frames, used by the env and as the buffer's
        ``stack_num``.
    :param resume_path: optional path to a saved ``state_dict`` that is loaded into the algorithm.
    :param resume_id: run id forwarded to the logger.
    :param logger_type: ``"wandb"`` selects Weights & Biases; any other value selects tensorboard.
    :param wandb_project: W&B project name (only used when ``logger_type == "wandb"``).
    :param watch: if set, only run ``watch_fn`` and exit without training.
    :param save_buffer_name: if set, ``watch_fn`` collects ``buffer_size`` steps and saves the
        buffer as HDF5 to this path instead of evaluating the agent.
    """
    # Resolve the device at run time: CUDA when available, otherwise CPU.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Snapshot of all arguments (with `device` resolved). It is taken before any other local is
    # created, so it contains only the parameters; it is logged and passed to the logger as config.
    params_log_info = locals()
    log.info(f"Starting training with config:\n{params_log_info}")
    env, training_envs, test_envs = make_atari_env(
        task,
        seed,
        num_training_envs,
        num_test_envs,
        scale=scale_obs,
        frame_stack=frames_stack,
    )
    state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
    action_shape = env.action_space.shape or env.action_space.n  # type: ignore
    # should be N_FRAMES x H x W (it is unpacked as c, h, w below)
    log.info(f"Observations shape: {state_shape}")
    log.info(f"Actions shape: {action_shape}")
    # seed numpy and torch before the network is constructed
    np.random.seed(seed)
    torch.manual_seed(seed)
    # define model
    c, h, w = state_shape
    net = QRDQNet(
        c=c,
        h=h,
        w=w,
        action_shape=action_shape,
        num_quantiles=num_quantiles,
    )
    # define policy and algorithm
    optim = AdamOptimizerFactory(lr=lr)
    policy = QRDQNPolicy(
        model=net,
        action_space=env.action_space,
        eps_training=eps_train,
        eps_inference=eps_test,
    )
    algorithm: QRDQN = QRDQN(
        policy=policy,
        optim=optim,
        gamma=gamma,
        num_quantiles=num_quantiles,
        n_step_return_horizon=n_step,
        target_update_freq=target_update_freq,
    ).to(device)
    # load a previous policy
    if resume_path:
        algorithm.load_state_dict(torch.load(resume_path, map_location=device))
        log.info(f"Loaded agent from: {resume_path}")
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        buffer_size,
        buffer_num=len(training_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=frames_stack,
    )
    # collector
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
    # log: <persistence_base_dir>/<task>/qrdqn/<seed>/<timestamp>
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    algo_name = "qrdqn"
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(persistence_base_dir, log_name)
    # logger: W&B only when explicitly requested, tensorboard otherwise
    logger_factory = LoggerFactoryDefault()
    if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"
    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=resume_id,
        config_dict=params_log_info,
    )

    def save_best_fn(policy: Algorithm) -> None:
        # Receives the algorithm (despite the parameter name) and stores its state dict
        # as <log_path>/policy.pth.
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        # Stop once the env spec's reward threshold is reached (if one is set); otherwise
        # Pong tasks stop at a mean reward of 20 and all other tasks never stop early.
        # NOTE(review): assumes env.spec is not None.
        if env.spec.reward_threshold:  # type: ignore
            return mean_rewards >= env.spec.reward_threshold  # type: ignore
        if "Pong" in task:
            return mean_rewards >= 20
        return False

    def train_fn(epoch: int, env_step: int) -> None:
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
        else:
            eps = eps_train_final
        policy.set_eps_training(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    # watch agent's performance
    def watch_fn() -> None:
        # Either dump a freshly collected buffer (save_buffer_name) or run test episodes.
        log.info("Setup test envs ...")
        test_envs.seed(seed)
        if save_buffer_name:
            log.info(f"Generate buffer with size {buffer_size}")
            buffer = VectorReplayBuffer(
                buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=frames_stack,
            )
            collector = Collector[CollectStats](
                algorithm, test_envs, buffer, exploration_noise=True
            )
            result = collector.collect(n_step=buffer_size, reset_before_collect=True)
            log.info(f"Save buffer into {save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(save_buffer_name)
        else:
            log.info("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=num_test_envs, render=render)
        result.pprint_asdict()

    if watch:
        watch_fn()
        sys.exit(0)

    # test training_collector and start filling replay buffer
    training_collector.reset()
    training_collector.collect(n_step=batch_size * num_training_envs)
    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=epoch,
            epoch_num_steps=epoch_num_steps,
            collection_step_num_env_steps=collection_step_num_env_steps,
            test_step_num_episodes=num_test_envs,
            batch_size=batch_size,
            training_fn=train_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            update_step_num_gradient_steps_per_sample=update_per_step,
            test_in_training=False,
        )
    )

    pprint.pprint(result)
    watch_fn()


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/atari/atari_rainbow.py
================================================
#!/usr/bin/env python3
import datetime
import os
import pprint
import sys

import numpy as np
import torch
from sensai.util import logging

from tianshou.algorithm import C51, RainbowDQN
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.c51 import C51Policy
from
tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import (
    Collector,
    CollectStats,
    PrioritizedVectorReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env.atari.atari_network import RainbowNet
from tianshou.env.atari.atari_wrapper import make_atari_env
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OffPolicyTrainerParams

log = logging.getLogger(__name__)


def main(
    task: str = "PongNoFrameskip-v4",
    seed: int = 0,
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0000625,
    gamma: float = 0.99,
    num_atoms: int = 51,
    v_min: float = -10.0,
    v_max: float = 10.0,
    noisy_std: float = 0.1,
    no_dueling: bool = False,
    no_noisy: bool = False,
    no_priority: bool = False,
    alpha: float = 0.5,
    beta: float = 0.4,
    beta_final: float = 1.0,
    beta_anneal_step: int = 5000000,
    no_weight_norm: bool = False,
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    epoch_num_steps: int = 100000,
    collection_step_num_env_steps: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    num_training_envs: int = 10,
    num_test_envs: int = 10,
    persistence_base_dir: str = "log",
    render: float = 0.0,
    device: str | None = None,
    frames_stack: int = 4,
    resume_path: str | None = None,
    resume_id: str | None = None,
    logger_type: str = "tensorboard",
    wandb_project: str = "atari.benchmark",
    watch: bool = False,
    save_buffer_name: str | None = None,
) -> None:
    """
    Train a Rainbow agent on an Atari task with the procedural API, or only evaluate it (``watch``).

    The agent is a ``RainbowDQN`` algorithm wrapping a ``C51Policy`` over a ``RainbowNet``.

    During training (see ``train_fn``):

    * epsilon decays linearly from ``eps_train`` to ``eps_train_final`` over the first 1M env steps;
    * unless ``no_priority`` is set, the prioritized buffer's beta is annealed linearly from
      ``beta`` to ``beta_final`` over ``beta_anneal_step`` env steps.

    :param task: environment id passed to ``make_atari_env``; also part of the log path.
    :param seed: seed for env creation, ``np.random`` and ``torch``; also part of the log path.
    :param scale_obs: forwarded as ``scale`` to ``make_atari_env``.
    :param eps_test: epsilon of the policy at inference time (``eps_inference``).
    :param eps_train: initial training epsilon.
    :param eps_train_final: training epsilon reached after 1M environment steps.
    :param buffer_size: replay buffer capacity; also the number of steps collected when
        ``save_buffer_name`` is set.
    :param lr: Adam learning rate.
    :param gamma: discount factor.
    :param num_atoms: forwarded to both ``RainbowNet`` and ``C51Policy`` as ``num_atoms``.
    :param v_min: forwarded to ``C51Policy`` as ``v_min``.
    :param v_max: forwarded to ``C51Policy`` as ``v_max``.
    :param noisy_std: forwarded to ``RainbowNet`` as ``noisy_std``.
    :param no_dueling: passes ``is_dueling=False`` to ``RainbowNet`` when set.
    :param no_noisy: passes ``is_noisy=False`` to ``RainbowNet`` when set.
    :param no_priority: use a plain ``VectorReplayBuffer`` for training instead of a
        ``PrioritizedVectorReplayBuffer`` (also disables beta annealing).
    :param alpha: ``alpha`` of the prioritized buffer.
    :param beta: initial ``beta`` of the prioritized buffer.
    :param beta_final: ``beta`` reached after ``beta_anneal_step`` env steps.
    :param beta_anneal_step: number of env steps over which ``beta`` is annealed.
    :param no_weight_norm: passes ``weight_norm=False`` to the training prioritized buffer.
    :param n_step: n-step return horizon.
    :param target_update_freq: forwarded to ``RainbowDQN`` as ``target_update_freq``.
    :param epoch: maximum number of training epochs.
    :param epoch_num_steps: number of environment steps per epoch.
    :param collection_step_num_env_steps: forwarded to the trainer params.
    :param update_per_step: forwarded as ``update_step_num_gradient_steps_per_sample``.
    :param batch_size: training batch size; before training the buffer is pre-filled with
        ``batch_size * num_training_envs`` steps.
    :param num_training_envs: number of training environments.
    :param num_test_envs: number of test environments; also the number of episodes per test.
    :param persistence_base_dir: root directory; logs go to
        ``<persistence_base_dir>/<task>/rainbow/<seed>/<timestamp>``.
    :param render: forwarded as ``render`` when collecting test episodes in ``watch_fn``.
    :param device: torch device; ``None`` selects CUDA if available, otherwise CPU.
    :param frames_stack: number of stacked frames, used by the env and as the buffers'
        ``stack_num``.
    :param resume_path: optional path to a saved ``state_dict`` that is loaded into the algorithm.
    :param resume_id: run id forwarded to the logger.
    :param logger_type: ``"wandb"`` selects Weights & Biases; any other value selects tensorboard.
    :param wandb_project: W&B project name (only used when ``logger_type == "wandb"``).
    :param watch: if set, only run ``watch_fn`` and exit without training.
    :param save_buffer_name: if set, ``watch_fn`` collects ``buffer_size`` steps into a
        prioritized buffer and saves it as HDF5 to this path instead of evaluating the agent.
    """
    # Resolve the device at run time: CUDA when available, otherwise CPU.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Snapshot of all arguments (with `device` resolved). It is taken before any other local is
    # created, so it contains only the parameters; it is logged and passed to the logger as config.
    params_log_info = locals()
    log.info(f"Starting training with config:\n{params_log_info}")
    env, training_envs, test_envs = make_atari_env(
        task,
        seed,
        num_training_envs,
        num_test_envs,
        scale=scale_obs,
        frame_stack=frames_stack,
    )
    state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
    action_shape = env.action_space.shape or env.action_space.n  # type: ignore
    # should be N_FRAMES x H x W (it is unpacked as c, h, w below)
    log.info(f"Observations shape: {state_shape}")
    log.info(f"Actions shape: {action_shape}")
    # seed numpy and torch before the network is constructed
    np.random.seed(seed)
    torch.manual_seed(seed)
    # define model
    c, h, w = state_shape
    net = RainbowNet(
        c=c,
        h=h,
        w=w,
        action_shape=action_shape,
        num_atoms=num_atoms,
        noisy_std=noisy_std,
        is_dueling=not no_dueling,
        is_noisy=not no_noisy,
    )
    # define policy and algorithm
    policy = C51Policy(
        model=net,
        action_space=env.action_space,
        num_atoms=num_atoms,
        v_min=v_min,
        v_max=v_max,
        eps_training=eps_train,
        eps_inference=eps_test,
    )
    optim = AdamOptimizerFactory(lr=lr)
    algorithm: C51 = RainbowDQN(
        policy=policy,
        optim=optim,
        gamma=gamma,
        n_step_return_horizon=n_step,
        target_update_freq=target_update_freq,
    ).to(device)
    # load a previous policy
    if resume_path:
        algorithm.load_state_dict(torch.load(resume_path, map_location=device))
        log.info(f"Loaded agent from: {resume_path}")
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
    if no_priority:
        buffer = VectorReplayBuffer(
            buffer_size,
            buffer_num=len(training_envs),
            ignore_obs_next=True,
            save_only_last_obs=True,
            stack_num=frames_stack,
        )
    else:
        buffer = PrioritizedVectorReplayBuffer(
            buffer_size,
            buffer_num=len(training_envs),
            ignore_obs_next=True,
            save_only_last_obs=True,
            stack_num=frames_stack,
            alpha=alpha,
            beta=beta,
            weight_norm=not no_weight_norm,
        )
    # collector
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
    # log: <persistence_base_dir>/<task>/rainbow/<seed>/<timestamp>
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    algo_name = "rainbow"
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(persistence_base_dir, log_name)
    # logger: W&B only when explicitly requested, tensorboard otherwise
    logger_factory = LoggerFactoryDefault()
    if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"
    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=resume_id,
        config_dict=params_log_info,
    )

    def save_best_fn(policy: Algorithm) -> None:
        # Receives the algorithm (despite the parameter name) and stores its state dict
        # as <log_path>/policy.pth.
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        # Stop once the env spec's reward threshold is reached (if one is set); otherwise
        # Pong tasks stop at a mean reward of 20 and all other tasks never stop early.
        # NOTE(review): assumes env.spec is not None.
        if env.spec.reward_threshold:  # type: ignore
            return mean_rewards >= env.spec.reward_threshold  # type: ignore
        if "Pong" in task:
            return mean_rewards >= 20
        return False

    def train_fn(epoch: int, env_step: int) -> None:
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
        else:
            eps = eps_train_final
        policy.set_eps_training(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})
        # Linear beta annealing; only reached when the training buffer above is prioritized,
        # so `buffer` (main's training buffer) has `set_beta`.
        if not no_priority:
            if env_step <= beta_anneal_step:
                beta_value = beta - env_step / beta_anneal_step * (beta - beta_final)
            else:
                beta_value = beta_final
            buffer.set_beta(beta_value)
            if env_step % 1000 == 0:
                logger.write("train/env_step", env_step, {"train/beta": beta_value})

    def watch_fn() -> None:
        # Either dump a freshly collected buffer (save_buffer_name) or run test episodes.
        log.info("Setup test envs ...")
        test_envs.seed(seed)
        if save_buffer_name:
            log.info(f"Generate buffer with size {buffer_size}")
            # NOTE(review): always a prioritized buffer here, independent of `no_priority`,
            # and `weight_norm` is left at its default (unlike the training buffer).
            buffer = PrioritizedVectorReplayBuffer(
                buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=frames_stack,
                alpha=alpha,
                beta=beta,
            )
            collector = Collector[CollectStats](
                algorithm, test_envs, buffer, exploration_noise=True
            )
            result = collector.collect(n_step=buffer_size, reset_before_collect=True)
            log.info(f"Save buffer into {save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(save_buffer_name)
        else:
            log.info("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=num_test_envs, render=render)
        result.pprint_asdict()

    if watch:
        watch_fn()
        sys.exit(0)

    # test training_collector and start filling replay buffer
    training_collector.reset()
    training_collector.collect(n_step=batch_size * num_training_envs)
    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=epoch,
            epoch_num_steps=epoch_num_steps,
            collection_step_num_env_steps=collection_step_num_env_steps,
            test_step_num_episodes=num_test_envs,
            batch_size=batch_size,
            training_fn=train_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            update_step_num_gradient_steps_per_sample=update_per_step,
            test_in_training=False,
        )
    )

    pprint.pprint(result)
    watch_fn()


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/atari/atari_sac.py
================================================
#!/usr/bin/env python3
import datetime
import os
import pprint
import sys

import numpy as np
import torch
from sensai.util import logging

from tianshou.algorithm import DiscreteSAC, ICMOffPolicyWrapper
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.algorithm.modelfree.sac import AutoAlpha
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env.atari.atari_network import DQNet
from tianshou.env.atari.atari_wrapper import make_atari_env
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils.net.discrete import (
    DiscreteActor,
    DiscreteCritic,
    IntrinsicCuriosityModule,
)

log = logging.getLogger(__name__)


def main(
    task: str = "PongNoFrameskip-v4",
    seed: int = 4213,
    scale_obs: int = 0,
    buffer_size: int = 100000,
    actor_lr: float = 1e-5,
    critic_lr: float = 1e-5,
    gamma: float = 0.99,
    n_step: int = 3,
    tau: float = 0.005,
    alpha: float = 0.05,
    auto_alpha: bool = False,
    alpha_lr: float = 3e-4,
    epoch: int = 100,
    epoch_num_steps: int = 100000,
    collection_step_num_env_steps: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 64,
    hidden_size: int = 512,
    num_training_envs: int = 10,
    num_test_envs: int = 10,
    return_scaling: int = False,
    persistence_base_dir: str = "log",
    render: float = 0.0,
    device: str | None = None,
    frames_stack: int = 4,
    resume_path: str | None = None,
    resume_id: str | None = None,
    logger_type: str = "tensorboard",
    wandb_project: str = "atari.benchmark",
    watch: bool = False,
    save_buffer_name: str | None = None,
    icm_lr_scale: float = 0.0,
    icm_reward_scale: float = 0.01,
    icm_forward_loss_weight: float = 0.2,
) -> None:
    """
    Train a discrete SAC agent on an Atari task with the procedural API, or only evaluate it
    (``watch``).

    Curiosity can optionally be enabled: the algorithm is wrapped in an ``ICMOffPolicyWrapper``
    when ``icm_lr_scale > 0``.

    :param task: environment id passed to ``make_atari_env``; also part of the log path.
    :param seed: seed for env creation, ``np.random`` and ``torch``; also part of the log path.
    :param scale_obs: forwarded as ``scale`` to ``make_atari_env``.
    :param buffer_size: replay buffer capacity; also the number of steps collected when
        ``save_buffer_name`` is set.
    :param actor_lr: Adam learning rate of the actor; also used for the ICM optimizer.
    :param critic_lr: Adam learning rate of both critics.
    :param gamma: discount factor.
    :param n_step: n-step return horizon.
    :param tau: forwarded to ``DiscreteSAC`` as ``tau``.
    :param alpha: fixed entropy coefficient, used unless ``auto_alpha`` is set.
    :param auto_alpha: if set, use ``AutoAlpha`` with target entropy
        ``0.98 * log(prod(action_shape))`` and initial ``log_alpha`` 0.0.
    :param alpha_lr: Adam learning rate of the alpha optimizer (only with ``auto_alpha``).
    :param epoch: maximum number of training epochs.
    :param epoch_num_steps: number of environment steps per epoch.
    :param collection_step_num_env_steps: forwarded to the trainer params.
    :param update_per_step: forwarded as ``update_step_num_gradient_steps_per_sample``.
    :param batch_size: training batch size; before training the buffer is pre-filled with
        ``batch_size * num_training_envs`` steps.
    :param hidden_size: ``output_dim_added_layer`` of the shared ``DQNet``; also the ICM hidden
        size.
    :param num_training_envs: number of training environments.
    :param num_test_envs: number of test environments; also the number of episodes per test.
    :param return_scaling: NOTE(review): not used anywhere in this function, and annotated
        ``int`` despite a boolean default; left as is to keep the CLI interface unchanged.
    :param persistence_base_dir: root directory; logs go to
        ``<persistence_base_dir>/<task>/<discrete_sac|discrete_sac_icm>/<seed>/<timestamp>``.
    :param render: forwarded as ``render`` when collecting test episodes in ``watch_fn``.
    :param device: torch device; ``None`` selects CUDA if available, otherwise CPU.
    :param frames_stack: number of stacked frames, used by the env and as the buffer's
        ``stack_num``.
    :param resume_path: optional path to a saved ``state_dict`` that is loaded into the algorithm.
    :param resume_id: run id forwarded to the logger; when given, the trainer is also told to
        resume from the log.
    :param logger_type: ``"wandb"`` selects Weights & Biases; any other value selects tensorboard.
    :param wandb_project: W&B project name (only used when ``logger_type == "wandb"``).
    :param watch: if set, only run ``watch_fn`` and exit without training.
    :param save_buffer_name: if set, ``watch_fn`` collects ``buffer_size`` steps and saves the
        buffer as HDF5 to this path instead of evaluating the agent.
    :param icm_lr_scale: values > 0 enable the ICM wrapper (forwarded as ``lr_scale``).
    :param icm_reward_scale: forwarded to the ICM wrapper as ``reward_scale``.
    :param icm_forward_loss_weight: forwarded to the ICM wrapper as ``forward_loss_weight``.
    """
    # Resolve the device at run time: CUDA when available, otherwise CPU.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Snapshot of all arguments (with `device` resolved). It is taken before any other local is
    # created, so it contains only the parameters; it is logged and passed to the logger as config.
    params_log_info = locals()
    log.info(f"Starting training with config:\n{params_log_info}")
    env, training_envs, test_envs = make_atari_env(
        task,
        seed,
        num_training_envs,
        num_test_envs,
        scale=scale_obs,
        frame_stack=frames_stack,
    )
    c, h, w = env.observation_space.shape  # type: ignore
    action_shape = env.action_space.n  # type: ignore
    # should be N_FRAMES x H x W
    log.info(f"Observations shape: {(c, h, w)}")
    log.info(f"Actions shape: {action_shape}")
    # seed numpy and torch before the networks are constructed
    np.random.seed(seed)
    torch.manual_seed(seed)
    # define model; the actor and both critics are built on the same `net` instance
    net = DQNet(
        c,
        h,
        w,
        action_shape=action_shape,
        features_only=True,
        output_dim_added_layer=hidden_size,
    )
    actor = DiscreteActor(preprocess_net=net, action_shape=action_shape, softmax_output=False)
    actor_optim = AdamOptimizerFactory(lr=actor_lr)
    critic1 = DiscreteCritic(preprocess_net=net, last_size=action_shape)
    critic1_optim = AdamOptimizerFactory(lr=critic_lr)
    critic2 = DiscreteCritic(preprocess_net=net, last_size=action_shape)
    critic2_optim = AdamOptimizerFactory(lr=critic_lr)

    # define policy and algorithm
    # entropy coefficient: fixed `alpha`, or learned via AutoAlpha when `auto_alpha` is set
    alpha_param: float | AutoAlpha = alpha
    if auto_alpha:
        target_entropy = 0.98 * np.log(np.prod(action_shape))
        log_alpha = 0.0
        alpha_optim = AdamOptimizerFactory(lr=alpha_lr)
        alpha_param = AutoAlpha(target_entropy, log_alpha, alpha_optim)

    algorithm: DiscreteSAC | ICMOffPolicyWrapper
    policy = DiscreteSACPolicy(
        actor=actor,
        action_space=env.action_space,
    )
    algorithm = DiscreteSAC(
        policy=policy,
        policy_optim=actor_optim,
        critic=critic1,
        critic_optim=critic1_optim,
        critic2=critic2,
        critic2_optim=critic2_optim,
        tau=tau,
        gamma=gamma,
        alpha=alpha_param,
        n_step_return_horizon=n_step,
    ).to(device)
    # optional intrinsic curiosity module, with its own (separate) DQNet feature extractor
    if icm_lr_scale > 0:
        feature_net = DQNet(c=c, h=h, w=w, action_shape=action_shape, features_only=True)
        action_dim = np.prod(action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net=feature_net.net,
            feature_dim=feature_dim,
            action_dim=int(action_dim),
            hidden_sizes=[hidden_size],
        )
        icm_optim = AdamOptimizerFactory(lr=actor_lr)
        algorithm = ICMOffPolicyWrapper(
            wrapped_algorithm=algorithm,
            model=icm_net,
            optim=icm_optim,
            lr_scale=icm_lr_scale,
            reward_scale=icm_reward_scale,
            forward_loss_weight=icm_forward_loss_weight,
        ).to(device)

    # load a previous model
    if resume_path:
        algorithm.load_state_dict(torch.load(resume_path, map_location=device))
        log.info(f"Loaded agent from: {resume_path}")

    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        buffer_size,
        buffer_num=len(training_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=frames_stack,
    )
    # collector
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)

    # log: <persistence_base_dir>/<task>/<algo_name>/<seed>/<timestamp>
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    algo_name = "discrete_sac_icm" if icm_lr_scale > 0 else "discrete_sac"
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(persistence_base_dir, log_name)

    # logger: W&B only when explicitly requested, tensorboard otherwise
    logger_factory = LoggerFactoryDefault()
    if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=resume_id,
        config_dict=params_log_info,
    )

    def save_best_fn(policy: Algorithm) -> None:
        # Receives the algorithm (despite the parameter name) and stores its state dict
        # as <log_path>/policy.pth.
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        # Stop once the env spec's reward threshold is reached (if one is set); otherwise
        # Pong tasks stop at a mean reward of 20 and all other tasks never stop early.
        # NOTE(review): assumes env.spec is not None.
        if env.spec.reward_threshold:  # type: ignore
            return mean_rewards >= env.spec.reward_threshold  # type: ignore
        if "Pong" in task:
            return mean_rewards >= 20
        return False

    def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
        # Writes {"model": state_dict} to a single <log_path>/checkpoint.pth, overwriting the
        # previous checkpoint on every call, and returns that path.
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        torch.save({"model": algorithm.state_dict()}, ckpt_path)
        return ckpt_path

    def watch_fn() -> None:
        # Either dump a freshly collected buffer (save_buffer_name) or run test episodes.
        log.info("Setup test envs ...")
        test_envs.seed(seed)
        if save_buffer_name:
            log.info(f"Generate buffer with size {buffer_size}")
            buffer = VectorReplayBuffer(
                buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=frames_stack,
            )
            collector = Collector[CollectStats](
                algorithm, test_envs, buffer, exploration_noise=True
            )
            result = collector.collect(n_step=buffer_size, reset_before_collect=True)
            log.info(f"Save buffer into {save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(save_buffer_name)
        else:
            log.info("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=num_test_envs, render=render)
        result.pprint_asdict()

    if watch:
        watch_fn()
        sys.exit(0)

    # test training_collector and start filling replay buffer
    training_collector.reset()
    training_collector.collect(n_step=batch_size * num_training_envs)

    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=epoch,
            epoch_num_steps=epoch_num_steps,
            collection_step_num_env_steps=collection_step_num_env_steps,
            test_step_num_episodes=num_test_envs,
            batch_size=batch_size,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            update_step_num_gradient_steps_per_sample=update_per_step,
            test_in_training=False,
            resume_from_log=resume_id is not None,
            save_checkpoint_fn=save_checkpoint_fn,
        )
    )

    pprint.pprint(result)
    watch_fn()


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/atari/atari_sac_hl.py
================================================
#!/usr/bin/env python3

import os
from typing import Literal

from sensai.util import logging

from tianshou.env.atari.atari_network import (
    ActorFactoryAtariDQN,
)
from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import OffPolicyTrainingConfig
from tianshou.highlevel.experiment import (
    DiscreteSACExperimentBuilder,
    ExperimentConfig,
)
from tianshou.highlevel.params.algorithm_params import DiscreteSACParams


# Entry point: configures discrete SAC through the high-level experiment builder and runs it.
def main(
    task: str = "PongNoFrameskip-v4",
    persistence_base_dir: str = "log",
    num_experiments: int = 1,
    experiment_launcher: Literal["sequential", "joblib"] = "sequential",
    max_epochs: int = 100,
    epoch_num_steps: int = 100000,
) -> None:
    """
    Train an agent using SAC on a specified Atari task, potentially running multiple experiments with different seeds
    and evaluating the results using rliable.

    :param task: the Atari task to train on.
    :param persistence_base_dir: the base directory for logging and saving experiment data,
        the task name will be appended to it.
    :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
    :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`.
        You can use "joblib" for parallel execution of whole experiments.
:param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. """ persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OffPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, update_step_num_gradient_steps_per_sample=0.1, batch_size=64, num_training_envs=10, num_test_envs=10, buffer_size=100000, collection_step_num_env_steps=10, replay_buffer_stack_num=4, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, ) env_factory = AtariEnvFactory(task, 4, scale=False) experiment_builder = ( DiscreteSACExperimentBuilder(env_factory, experiment_config, training_config) .with_sac_params( DiscreteSACParams( actor_lr=1e-5, critic1_lr=1e-5, critic2_lr=1e-5, gamma=0.99, tau=0.005, alpha=0.05, n_step_return_horizon=3, ), ) .with_actor_factory(ActorFactoryAtariDQN(scale_obs=False, features_only=True)) .with_common_critic_factory_use_actor() .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/box2d/README.md ================================================ # Bipedal-Hardcore-SAC - Our default choice is to remove the done-flag penalty; the agent then converges to \~280 reward within 100 epochs (10M env steps, 3~4 hours, see the image below) - If the done penalty is not removed, it converges much more slowly, taking about 200 epochs (20M env steps) to reach the same performance (\~200 reward) ![](results/sac/BipedalHardcore.png) # BipedalWalker-BDQ - To demonstrate the ability of BDQ to scale to large discrete action spaces, we run it on a discretized version of the BipedalWalker-v3 environment, where
the number of possible actions in each dimension is 25, for a total of 25^4 = 390,625 possible actions. A usual DQN architecture would need 25^4 output neurons for the Q-network, thus scaling exponentially with the number of action-space dimensions, while the Branching architecture scales linearly and uses only 25*4 output neurons. ![](results/bdq/BipedalWalker.png) ================================================ FILE: examples/box2d/acrobot_dualdqn.py ================================================ import argparse import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import DQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Acrobot-v1") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--eps_test", type=float, default=0.05) parser.add_argument("--eps_train", type=float, default=0.5) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=100)
parser.add_argument("--update_per_step", type=float, default=0.01) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128]) parser.add_argument("--dueling_q_hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--dueling_v_hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_args() def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape # training_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, dueling_param=(Q_param, V_param), ) optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, eps_inference=args.eps_test, ) algorithm: DQN = DQN( policy=policy, optim=optim, gamma=args.gamma, 
n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec: if not env.spec.reward_threshold: return False else: return mean_rewards >= env.spec.reward_threshold return False def train_fn(epoch: int, env_step: int) -> None: if env_step <= 100000: policy.set_eps_training(args.eps_train) elif env_step <= 500000: eps = args.eps_train - (env_step - 100000) / 400000 * (0.5 * args.eps_train) policy.set_eps_training(eps) else: policy.set_eps_training(0.5 * args.eps_train) # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, training_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_dqn(get_args()) ================================================ FILE: examples/box2d/bipedal_bdq.py ================================================ import argparse import datetime import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import BDQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.bdqn import BDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import BranchingNet def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() # task parser.add_argument("--task", type=str, default="BipedalWalker-v3") # network architecture parser.add_argument("--common_hidden_sizes", type=int, nargs="*", default=[512, 256]) parser.add_argument("--action_hidden_sizes", type=int, nargs="*", default=[128]) parser.add_argument("--value_hidden_sizes", type=int, nargs="*", default=[128]) parser.add_argument("--action_per_branch", type=int, default=25) # training hyperparameters parser.add_argument("--seed", type=int, default=0) parser.add_argument("--eps_test", type=float, default=0.0) parser.add_argument("--eps_train", type=float, default=0.73) parser.add_argument("--eps_decay", type=float, default=5e-6) parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--target_update_freq", type=int, default=1000) 
parser.add_argument("--epoch", type=int, default=25) parser.add_argument("--epoch_num_steps", type=int, default=80000) parser.add_argument("--collection_step_num_env_steps", type=int, default=16) parser.add_argument("--update_per_step", type=float, default=0.0625) parser.add_argument("--batch_size", type=int, default=512) parser.add_argument("--num_training_envs", type=int, default=20) parser.add_argument("--num_test_envs", type=int, default=10) # other parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_args() def run_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) assert isinstance(env.action_space, gym.spaces.MultiDiscrete) assert isinstance( env.observation_space, gym.spaces.Box, ) # BipedalWalker-v3 has `Box` observation space by design args.state_shape = env.observation_space.shape args.action_shape = env.action_space.shape args.num_branches = args.action_shape[0] print("Observations shape:", args.state_shape) print("Num branches:", args.num_branches) print("Actions per branch:", args.action_per_branch) # training_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) # you can also use tianshou.env.SubprocVectorEnv training_envs = SubprocVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) for _ in range(args.num_training_envs) ], ) # test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) test_envs = SubprocVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) for _ in range(args.num_test_envs) ], ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = BranchingNet( state_shape=args.state_shape, 
num_branches=args.num_branches, action_per_branch=args.action_per_branch, common_hidden_sizes=args.common_hidden_sizes, value_hidden_sizes=args.value_hidden_sizes, action_hidden_sizes=args.action_hidden_sizes, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = BDQNPolicy( model=net, # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? action_space=env.action_space, # type: ignore[arg-type] eps_training=args.eps_train, eps_inference=args.eps_test, ) algorithm: BDQN = BDQN( policy=policy, optim=optim, gamma=args.gamma, target_update_freq=args.target_update_freq, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # log current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") log_path = os.path.join(args.logdir, "bdq", args.task, current_time) writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec and env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold return False def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) policy.set_eps_training(eps) # trainer result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, 
update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, training_fn=train_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! policy.set_eps_training(args.eps_test) test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": run_bdq(get_args()) ================================================ FILE: examples/box2d/bipedal_hardcore_sac.py ================================================ import argparse import os import pprint from typing import Any import gymnasium as gym import numpy as np import torch from gymnasium.core import WrapperActType, WrapperObsType from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import SAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="BipedalWalkerHardcore-v3") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--buffer_size", type=int, default=1000000) parser.add_argument("--actor_lr", type=float, default=3e-4) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, 
default=0.005) parser.add_argument("--alpha", type=float, default=0.1) parser.add_argument("--auto_alpha", type=int, default=1) parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--n_step", type=int, default=4) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) return parser.parse_args() class Wrapper(gym.Wrapper): """Env wrapper for reward scale, action repeat and removing done penalty.""" def __init__( self, env: gym.Env, action_repeat: int = 3, reward_scale: int = 5, rm_done: bool = True, ) -> None: super().__init__(env) self.action_repeat = action_repeat self.reward_scale = reward_scale self.rm_done = rm_done def step( self, action: WrapperActType, ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]: rew_sum = 0.0 for _ in range(self.action_repeat): obs, rew, terminated, truncated, info = self.env.step(action) done = terminated | truncated # remove done reward penalty if not done or not self.rm_done: rew_sum = rew_sum + float(rew) if done: break # scale reward return obs, self.reward_scale * rew_sum, terminated, truncated, info def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: env = Wrapper(gym.make(args.task)) space_info = SpaceInfo.from_env(env) args.state_shape 
= space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action training_envs = SubprocVectorEnv( [lambda: Wrapper(gym.make(args.task)) for _ in range(args.num_training_envs)], ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [ lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False) for _ in range(args.num_test_envs) ], ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim log_alpha = 0.0 alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = SACPolicy( actor=actor, action_space=env.action_space, ) algorithm: SAC = SAC( policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, alpha=args.alpha, n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path)) print("Loaded agent from: ", 
args.resume_path) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs) # training_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec: if not env.spec.reward_threshold: return False else: return mean_rewards >= env.spec.reward_threshold return False # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_training=False, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, ) ) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_sac_bipedal() ================================================ FILE: examples/box2d/lunarlander_dqn.py ================================================ import argparse import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import DQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() # the parameters are found by Optuna parser.add_argument("--task", type=str, default="LunarLander-v2") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--eps_test", type=float, default=0.01) parser.add_argument("--eps_train", type=float, default=0.73) parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.013) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--n_step", type=int, default=4) parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=80000) parser.add_argument("--collection_step_num_env_steps", type=int, default=16) parser.add_argument("--update_per_step", type=float, default=0.0625) parser.add_argument("--batch_size", type=int, default=128) 
parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--dueling_q_hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--dueling_v_hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_training_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_args() def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action # training_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, dueling_param=(Q_param, V_param), ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, eps_inference=args.eps_test, ) algorithm: DQN = DQN( policy=policy, optim=optim, gamma=args.gamma, n_step_return_horizon=args.n_step, 
target_update_freq=args.target_update_freq, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec: if not env.spec.reward_threshold: return False else: return mean_rewards >= env.spec.reward_threshold return False def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) policy.set_eps_training(eps) # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, training_fn=train_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_dqn(get_args()) ================================================ FILE: examples/box2d/mcc_sac.py ================================================ import argparse import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import SAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="MountainCarContinuous-v0") parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--actor_lr", type=float, default=3e-4) parser.add_argument("--critic_lr", type=float, default=3e-4) parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--noise_std", type=float, default=1.2) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--auto_alpha", type=int, default=1) parser.add_argument("--alpha", type=float, default=0.2) parser.add_argument("--epoch", type=int, default=20) parser.add_argument("--epoch_num_steps", type=int, default=12000) 
parser.add_argument("--collection_step_num_env_steps", type=int, default=5) parser.add_argument("--update_per_step", type=float, default=0.2) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_training_envs", type=int, default=5) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_args() def test_sac(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action # training_envs = gym.make(args.task) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic2 = 
ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim log_alpha = 0.0 alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = SACPolicy( actor=actor, exploration_noise=OUNoise(0.0, args.noise_std), action_space=env.action_space, ) algorithm: SAC = SAC( policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, alpha=args.alpha, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs) # training_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec: if not env.spec.reward_threshold: return False else: return mean_rewards >= env.spec.reward_threshold return False # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_sac() ================================================ FILE: examples/discrete/discrete_dqn.py ================================================ import gymnasium as gym from torch.utils.tensorboard import SummaryWriter import tianshou as ts from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import CollectStats from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo def main() -> None: task = "CartPole-v1" lr, epoch, batch_size = 1e-3, 10, 64 num_training_envs, num_test_envs = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 epoch_num_steps, collection_step_num_env_steps = 10000, 10 logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn")) # TensorBoard is supported! # For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html # Create the environments # You can also try SubprocVectorEnv, which will use parallelization training_envs = ts.env.DummyVectorEnv( [lambda: gym.make(task) for _ in range(num_training_envs)] ) test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) # Create the network and optimizer # Note: You can easily define other networks. 
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network env = gym.make(task, render_mode="human") assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) state_shape = space_info.observation_info.obs_shape action_shape = space_info.action_info.action_shape net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = AdamOptimizerFactory(lr=lr) policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=eps_train, eps_inference=eps_test, ) algorithm = ts.algorithm.DQN( policy=policy, optim=optim, gamma=gamma, n_step_return_horizon=n_step, target_update_freq=target_freq, ) training_collector = ts.data.Collector[CollectStats]( algorithm, training_envs, ts.data.VectorReplayBuffer(buffer_size, num_training_envs), exploration_noise=True, ) test_collector = ts.data.Collector[CollectStats]( algorithm, test_envs, exploration_noise=True, ) # because DQN uses epsilon-greedy method def stop_fn(mean_rewards: float) -> bool: if env.spec: if not env.spec.reward_threshold: return False else: return mean_rewards >= env.spec.reward_threshold return False result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=num_test_envs, batch_size=batch_size, update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps, stop_fn=stop_fn, logger=logger, test_in_training=True, ) ) print(f"Finished training in {result.timing.total_time} seconds") # watch performance collector = ts.data.Collector[CollectStats](algorithm, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) if __name__ == "__main__": main() ================================================ FILE: examples/discrete/discrete_dqn_hl.py 
================================================ #!/usr/bin/env python3 import os from typing import Literal from sensai.util import logging from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.env import ( EnvFactoryRegistered, VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig from tianshou.highlevel.params.algorithm_params import DQNParams def main( task: str = "CartPole-v1", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "sequential", ) -> None: """ Train an agent using DQN on a specified discrete task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the discrete task to train on. :param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. 
""" persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OffPolicyTrainingConfig( max_epochs=10, epoch_num_steps=10000, num_training_envs=10, num_test_envs=100, buffer_size=20000, batch_size=64, collection_step_num_env_steps=10, update_step_num_gradient_steps_per_sample=1 / 10, start_timesteps=0, start_timesteps_random=False, ) env_factory = EnvFactoryRegistered( task=task, venv_type=VectorEnvType.DUMMY, training_seed=0, test_seed=10 ) hidden_sizes = (64, 64) experiment_builder = ( DQNExperimentBuilder(env_factory, experiment_config, training_config) .with_dqn_params( DQNParams( lr=1e-3, gamma=0.9, n_step_return_horizon=3, target_update_freq=320, eps_training=0.3, eps_inference=0.0, ), ) .with_model_factory_default(hidden_sizes=hidden_sizes) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/inverse/README.md ================================================ # Inverse Reinforcement Learning In inverse reinforcement learning setting, the agent learns a policy from interaction with an environment without reward and a fixed dataset which is collected with an expert policy. ## Continuous control Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets. We provide implementation of GAIL algorithm for continuous control. ### Train You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `expert_buffer` of `GAILPolicy`. `irl_gail.py` is an example of inverse RL using the d4rl dataset. 
To train an agent with the GAIL algorithm: ```bash python irl_gail.py --task HalfCheetah-v2 --expert_data_task halfcheetah-expert-v2 ``` ## GAIL (single run) | task | best reward | reward curve | parameters | |----------------|-------------|------------------------------------------|------------------------------------------------------------------------------------------| | HalfCheetah-v2 | 5177.07 | ![](results/gail/HalfCheetah-v2_rew.png) | `python3 irl_gail.py --task "HalfCheetah-v2" --expert_data_task "halfcheetah-expert-v2"` | | Hopper-v2 | 1761.44 | ![](results/gail/Hopper-v2_rew.png) | `python3 irl_gail.py --task "Hopper-v2" --expert_data_task "hopper-expert-v2"` | | Walker2d-v2 | 2020.77 | ![](results/gail/Walker2d-v2_rew.png) | `python3 irl_gail.py --task "Walker2d-v2" --expert_data_task "walker2d-expert-v2"` | ================================================ FILE: examples/inverse/irl_gail.py ================================================ #!/usr/bin/env python3 import argparse import datetime import os import pprint from typing import SupportsFloat, cast import d4rl import gymnasium as gym import numpy as np import torch from torch import nn from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import GAIL from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import ( Batch, Collector, CollectStats, ReplayBuffer, VectorReplayBuffer, ) from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from
tianshou.utils.space_info import SpaceInfo class NoRewardEnv(gym.RewardWrapper): """sets the reward to 0. :param gym.Env env: the environment to wrap. """ def __init__(self, env: gym.Env) -> None: super().__init__(env) def reward(self, reward: SupportsFloat) -> np.ndarray: """Set reward to 0.""" return np.zeros_like(reward) def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") parser.add_argument("--buffer_size", type=int, default=4096) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--disc_lr", type=float, default=2.5e-5) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=30000) parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) parser.add_argument("--update_step_num_repetitions", type=int, default=10) parser.add_argument("--disc_update_num", type=int, default=2) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--num_training_envs", type=int, default=64) parser.add_argument("--num_test_envs", type=int, default=10) # ppo special parser.add_argument("--return_scaling", type=int, default=True) # In theory, `vf-coef` will not make any difference if using Adam optimizer. 
parser.add_argument("--vf_coef", type=float, default=0.25) parser.add_argument("--ent_coef", type=float, default=0.001) parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--bound_action_method", type=str, default="clip") parser.add_argument("--lr_decay", type=int, default=True) parser.add_argument("--max_grad_norm", type=float, default=0.5) parser.add_argument("--eps_clip", type=float, default=0.2) parser.add_argument("--dual_clip", type=float, default=None) parser.add_argument("--value_clip", type=int, default=0) parser.add_argument("--advantage_normalization", type=int, default=0) parser.add_argument("--recompute_adv", type=int, default=1) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) return parser.parse_args() def test_gail(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) print("Action range:", args.min_action, args.max_action) # training_envs = gym.make(args.task) training_envs = SubprocVectorEnv( [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.num_training_envs)], ) training_envs = VectorEnvNormObs(training_envs) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) test_envs.set_obs_rms(training_envs.get_obs_rms()) # seed 
np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, ).to(args.device) net_c = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) critic = ContinuousCritic(preprocess_net=net_c).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) # do last policy layer scaling, this will make initial actions have (close to) # 0 mean and std, and will help boost performances, # see https://arxiv.org/abs/2006.05990, Fig.24 for details for m in actor.mu.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) optim = AdamOptimizerFactory(lr=args.lr) # discriminator net_d = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, concat=True, ) disc_net = ContinuousCritic(preprocess_net=net_d).to(args.device) for m in disc_net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) disc_optim = AdamOptimizerFactory(lr=args.disc_lr) if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) # expert replay buffer dataset = 
d4rl.qlearning_dataset(gym.make(args.expert_data_task)) dataset_size = dataset["rewards"].size print("dataset_size", dataset_size) expert_buffer = ReplayBuffer(dataset_size) for i in range(dataset_size): expert_buffer.add( cast( RolloutBatchProtocol, Batch( obs=dataset["observations"][i], act=dataset["actions"][i], rew=dataset["rewards"][i], done=dataset["terminals"][i], obs_next=dataset["next_observations"][i], ), ), ) print("dataset loaded") policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=True, action_bound_method=args.bound_action_method, action_space=env.action_space, ) algorithm: GAIL = GAIL( policy=policy, critic=critic, optim=optim, expert_buffer=expert_buffer, disc_net=disc_net, disc_optim=disc_optim, disc_update_num=args.disc_update_num, gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector buffer: ReplayBuffer if args.num_training_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(training_envs)) else: buffer = ReplayBuffer(args.buffer_size) training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_gail" log_path = os.path.join(args.logdir, args.task, "gail", log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=100, training_interval=100) def 
save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: # train result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_training=False, ) ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_gail() ================================================ FILE: examples/modelbased/README.md ================================================ # PSRL `NChain-v0`: `python3 psrl.py --task NChain-v0 --epoch_num_steps 10 --rew-mean-prior 0 --rew-std-prior 1` `FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch_num_steps 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20` `Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch_num_steps 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20` ================================================ FILE: examples/mujoco/README.md ================================================ # Tianshou's Mujoco Benchmark We benchmarked Tianshou algorithm implementations in 9 out of 13 environments from the MuJoCo Gym task suite[[1]](#footnote1). 
For each supported algorithm and supported mujoco environments, we provide: - Default hyperparameters used for benchmark and scripts to reproduce the benchmark; - A comparison of performance (or code level details) with other open source implementations or classic papers; - Graphs and raw data that can be used for research purposes[[2]](#footnote2); - Log details obtained during training[[2]](#footnote2); - Pretrained agents[[2]](#footnote2); - Some hints on how to tune the algorithm. Supported algorithms are listed below: - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) - [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9) - [Natural Policy Gradient](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/844d7703c313009c4c364edb4018c91de93439ca) - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc) - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32) - [Hindsight Experience Replay 
(HER)](https://arxiv.org/abs/1707.01495) ## EnvPool We highly recommend using envpool to run the following experiments. To install it on a Linux machine, type: ```bash pip install envpool ``` After that, `make_mujoco_env` will automatically switch to envpool's Mujoco env. EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for the overall RL training pipeline on average) than the Python vectorized env implementation, and its behavior is consistent with gym's Mujoco env. For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/) and [Docs](https://envpool.readthedocs.io/en/latest/api/mujoco.html). ## Usage Run ```bash $ python mujoco_sac.py --task Ant-v3 ``` Logs are saved in `./log/` and can be monitored with tensorboard. ```bash $ tensorboard --logdir log ``` You can also reproduce the benchmark (e.g. SAC in Ant-v3) with the example script we provide under `examples/mujoco/`: ```bash $ ./run_experiments.sh Ant-v3 sac ``` This will start 10 experiments with different seeds. Once all the experiments are finished, we can convert all tfevent files into csv files and then try plotting the results. ```bash # generate csv $ ./tools.py --root-dir ./results/Ant-v3/sac # generate figures $ ./plotter.py --root-dir ./results/Ant-v3 --shaded-std --legend-pattern "\\w+" # generate numerical result (support multiple groups: `--root-dir ./` instead of single dir) $ ./analysis.py --root-dir ./results --norm ``` ## Example benchmark Other graphs can be found under `examples/mujoco/benchmark/`. For pretrained agents, detailed graphs (single agent, single game) and log details, please refer to [https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/](https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/). ## Offpolicy algorithms ### Notes 1.
In offpolicy algorithms (DDPG, TD3, SAC), the shared hyperparameters are almost the same, and unless otherwise stated, hyperparameters are consistent with those used for benchmark in SpinningUp's implementations (e.g. we use batchsize 256 in DDPG/TD3/SAC while SpinningUp uses 100. Minor differences also lie in `start-timesteps`, data loop method `collection_step_num_env_steps`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (which contributes to performance improvement), etc.). 2. By comparison to both classic literature and open source implementations (e.g., SpinningUp)[[1]](#footnote1)[[2]](#footnote2), Tianshou's implementations of DDPG, TD3, and SAC are roughly at-parity with or better than the best reported results for these algorithms, so you can definitely use Tianshou's benchmark for research purposes. 3. We didn't compare offpolicy algorithms to OpenAI baselines [benchmark](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm), because for now it seems that they haven't provided benchmarks for offpolicy algorithms, but in [SpinningUp docs](https://spinningup.openai.com/en/latest/spinningup/bench.html) they stated that "SpinningUp implementations of DDPG, TD3, and SAC are roughly at-parity with the best-reported results for these algorithms", so we think the lack of comparisons with OpenAI baselines is okay.
### DDPG | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper (DDPG)](https://arxiv.org/abs/1802.09477) | [TD3 paper (OurDDPG)](https://arxiv.org/abs/1802.09477) | |:----------------------:|:-----------------:|:--------------------------------------------------------------------------------------:|:----------------------------------------------------:|:-------------------------------------------------------:| | Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 | | HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 | | Hopper | **2197.0±971.6** | ~1800 | **2020.5** | 1860.0 | | Walker2d | 1400.6±905.0 | ~1950 | 1843.6 | **3098.1** | | Swimmer | **144.1±6.5** | ~137 | N | N | | Humanoid | **177.3±77.6** | N | N | N | | Reacher | **-3.3±0.3** | N | -6.51 | -4.01 | | InvertedPendulum | **1000.0±0.0** | N | **1000.0** | **1000.0** | | InvertedDoublePendulum | 8364.3±2778.9 | N | **9355.5** | 8370.0 | \* details[[4]](#footnote4)[[5]](#footnote5)[[6]](#footnote6) ### TD3 | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper](https://arxiv.org/abs/1802.09477) | |:----------------------:|:-----------------:|:--------------------------------------------------------------------------------------:|:---------------------------------------------:| | Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 | | HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 | | Hopper | 3472.2±116.8 | ~2860 | **3564.1±114.7** | | Walker2d | 3982.4±274.5 | ~4000 | **4682.8±539.6** | | Swimmer | **104.2±34.2** | ~78 | N | | Humanoid | **5189.5±178.5** | N | N | | Reacher | **-2.7±0.2** | N | -3.6±0.6 | | InvertedPendulum | **1000.0±0.0** | N | **1000.0±0.0** | | InvertedDoublePendulum | **9349.2±14.3** | N | **9337.5±15.0** | \* details[[4]](#footnote4)[[5]](#footnote5)[[6]](#footnote6) #### Hints for TD3 1. 
TD3's learning rate is set to 3e-4 while it is 1e-3 for DDPG/SAC. However, there is NOT enough evidence to support our choice of such hyperparameters (we simply choose them because SpinningUp does so) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can! ### SAC | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [SAC paper](https://arxiv.org/abs/1801.01290) | |:----------------------:|:------------------:|:--------------------------------------------------------------------------------------:|:---------------------------------------------:| | Ant | **5850.2±475.7** | ~3980 | ~3720 | | HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 | | Hopper | **3542.2±51.5** | ~3150 | ~3370 | | Walker2d | **5007.0±251.5** | ~4250 | ~3740 | | Swimmer | **44.4±0.5** | ~41.7 | N | | Humanoid | **5488.5±81.2** | N | ~5200 | | Reacher | **-2.6±0.2** | N | N | | InvertedPendulum | **1000.0±0.0** | N | N | | InvertedDoublePendulum | **9359.5±0.4** | N | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for SAC 1. SAC's start-timesteps is set to 10000 by default while it is 25000 in DDPG/TD3. However, there is NOT enough evidence to support our choice of such hyperparameters (we simply choose them because SpinningUp does so) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can! 2. DO NOT share the same network between the two critic networks. 3. The sigma (of the Gaussian policy) should be conditioned on input. 4. The deterministic evaluation helps a lot :) ## Onpolicy Algorithms ### Notes 1. In A2C and PPO, unless otherwise stated, most hyperparameters are consistent with those used for benchmark in [ikostrikov/pytorch-a2c-ppo-acktr-gail](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail). 2.
Generally speaking, by comparison to both classic literature and open source implementations (e.g., OPENAI Baselines)[[1]](#footnote1)[[2]](#footnote2), Tianshou's implementations of REINFORCE, A2C, PPO are better than the best reported results for these algorithms, so you can definitely use Tianshou's benchmark for research purposes. ### REINFORCE | Environment | Tianshou (10M) | |:----------------------:|:-----------------:| | Ant | **1108.1±323.1** | | HalfCheetah | **1138.8±104.7** | | Hopper | **416.0±104.7** | | Walker2d | **440.9±148.2** | | Swimmer | **35.6±2.6** | | Humanoid | **464.3±58.4** | | Reacher | **-5.5±0.2** | | InvertedPendulum | **1000.0±0.0** | | InvertedDoublePendulum | **7726.2±1287.3** | | Environment | Tianshou (3M) | [Spinning Up (VPG PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html)[[7]](#footnote7) | |:----------------------:|:-----------------:|:--------------------------------------------------------------------------------------------------------------------------:| | Ant | **474.9+-133.5** | ~5 | | HalfCheetah | **884.0+-41.0** | ~600 | | Hopper | 395.8+-64.5\* | **~800** | | Walker2d | 412.0+-52.4 | **~460** | | Swimmer | 35.3+-1.4 | **~51** | | Humanoid | **438.2+-47.8** | N | | Reacher | **-10.5+-0.7** | N | | InvertedPendulum | **999.2+-2.4** | N | | InvertedDoublePendulum | **1059.7+-307.7** | N | \* details[[4]](#footnote4)[[5]](#footnote5) ### Hints for REINFORCE 1. Following [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990), we downscale the last layer of the policy network by a factor of 0.01 after orthogonal initialization. 2. We choose the "tanh" function to squash sampled actions from range (-inf, inf) to (-1, 1) rather than the usually used clipping method (as in StableBaselines3). We did full scale ablation studies and results show that tanh squashing performs a tiny little bit better than clipping overall, and is much better than no action bounding.
However, the "clip" method is still a very good method, considering its simplicity. 3. We use global observation normalization and global rew-to-go (value) normalization by default. Both are crucial to good performance of the REINFORCE algorithm. Since we subtract the mean when doing rew-to-go normalization, you can treat the global mean of rew-to-go as a naive version of "baseline". 4. Since we do not have a value estimator, we use the global rew-to-go mean to bootstrap steps truncated because of time limits and unfinished collecting, while most other implementations use 0. We feel this would help because the mean is more likely a better estimate than 0 (no ablation study has been done). 5. We have done a full scale ablation study on learning rate and lr decay strategy. We experiment with lr of 3e-4, 5e-4, 1e-3, each with 2 options: no lr decay or linear decay to 0. Experiments show that a 3e-4 learning rate causes slow learning and makes the agent get stuck in local optima easily for certain environments like InvertedDoublePendulum, Ant, HalfCheetah, and 1e-3 lr helps a lot. However, after training agents with lr 1e-3 for 5M steps or so, agents in certain environments like InvertedPendulum will become unstable. The conclusion is that we should start with a large learning rate and linearly decay it, but for a small initial learning rate or if you only train agents for limited timesteps, DO NOT decay it. 6. We didn't tune `step-per-collect` option and `training-num` option. Default values are finetuned with PPO algorithm so we assume they are also good for REINFORCE. You can play with them if you want, but remember that `buffer-size` should always be larger than `step-per-collect`, and if `step-per-collect` is too small and `training-num` too large, episodes will be truncated and bootstrapped very often, which will harm performance. If `training-num` is too small (e.g., less than 8), speed will go down. 7.
Sigma of action is not fixed (normally seen in other implementation) or conditioned on observation, but is an independent parameter which can be updated by gradient descent. We choose this setting because it works well in PPO, and is recommended by [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990). See Fig. 23. ### A2C | Environment | Tianshou (3M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html) | |:----------------------:|:------------------:|:------------------------------------------------------------------------------------------:| | Ant | **5236.8+-236.7** | ~5 | | HalfCheetah | **2377.3+-1363.7** | ~600 | | Hopper | **1608.6+-529.5** | ~800 | | Walker2d | **1805.4+-1055.9** | ~460 | | Swimmer | 40.2+-1.8 | **~51** | | Humanoid | **5316.6+-554.8** | N | | Reacher | **-5.2+-0.5** | N | | InvertedPendulum | **1000.0+-0.0** | N | | InvertedDoublePendulum | **9351.3+-12.8** | N | | Environment | Tianshou (1M) | [PPO paper](https://arxiv.org/abs/1707.06347) A2C | [PPO paper](https://arxiv.org/abs/1707.06347) A2C + Trust Region | |:----------------------:|:------------------:|:-------------------------------------------------:|:----------------------------------------------------------------:| | Ant | **3485.4+-433.1** | N | N | | HalfCheetah | **1829.9+-1068.3** | ~1000 | ~930 | | Hopper | **1253.2+-458.0** | ~900 | ~1220 | | Walker2d | **1091.6+-709.2** | ~850 | ~700 | | Swimmer | **36.6+-2.1** | ~31 | **~36** | | Humanoid | **1726.0+-1070.1** | N | N | | Reacher | **-6.7+-2.3** | ~-24 | ~-27 | | InvertedPendulum | **1000.0+-0.0** | **~1000** | **~1000** | | InvertedDoublePendulum | **9257.7+-277.4** | ~7100 | ~8100 | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for A2C 1. We choose `clip` action method in A2C instead of `tanh` option as used in REINFORCE simply to be consistent with original implementation. `tanh` may be better or equally well but we didn't have a try. 2. 
(Initial) learning rate, lr_decay, `step-per-collect` and `training-num` affect the performance of A2C to a great extent. These 4 hyperparameters also affect each other and should be tuned together. We have done full scale ablation studies on these 4 hyperparameters (more than 800 agents have been trained). Below are our findings. 3. The ratio `step-per-collect` / `training-num` equals `bootstrap-length`, which is the max length of an "episode" used in the GAE estimator and 80/16=5 in default settings. When `bootstrap-length` is small, (maybe) because GAE can look forward at most 5 steps and uses the bootstrap strategy very often, the critic is less well-trained, leading the actor to a not very high score. However, if we increase `step-per-collect` to increase `bootstrap-length` (e.g. 256/16=16), actor/critic will be updated less often, resulting in low sample efficiency and a slow training process. To conclude, if you don't restrict env timesteps, you can try using a larger `bootstrap-length` and train with more steps to get a better converged score. Train slower, achieve higher. 4. The learning rate 7e-4 with decay strategy is appropriate for `step-per-collect=80` and `training-num=16`. But if you use a larger `step-per-collect` (e.g. 256 - 2048), 7e-4 is a little bit small for `lr` because each update will have more data, less noise and thus smaller deviation in this case. So it is more appropriate to use a higher learning rate (e.g. 1e-3) to boost performance in this setting. If the plotted results rise fast in early stages and become unstable later, consider lr decay first before decreasing lr. 5. `max-grad-norm` didn't really help in our experiments. We simply keep it for consistency with other open-source implementations (e.g. SB3). 6. Although the original paper of A3C uses the RMSprop optimizer, we found that Adam with the same learning rate worked equally well. We use RMSprop anyway. Again, for consistency. 7.
We noticed that the implementation of A2C in SB3 sets `gae-lambda` to 1 by default for no reason, and our experiments showed better results overall when `gae-lambda` was set to 0.95. 8. We found out that `step-per-collect=256` and `training-num=8` are also good settings. You can have a try. ### PPO | Environment | Tianshou (1M) | [ikostrikov/pytorch-a2c-ppo-acktr-gail](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) | |:----------------------:|:------------------:|:-------------------------------------------------------------------------------------------------:|:-------------------------------------------------:|:-------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------:| | Ant | **3258.4+-1079.3** | N | N | N | ~650 | | HalfCheetah | **5783.9+-1244.0** | ~3120 | ~1800 | ~1700 | ~1670 | | Hopper | **2609.3+-700.8** | ~2300 | ~2330 | ~2400 | ~1850 | | Walker2d | 3588.5+-756.6 | **~4000** | ~3460 | ~3510 | ~1230 | | Swimmer | 66.7+-99.1 | N | ~108 | ~111 | **~120** | | Humanoid | **787.1+-193.5** | N | N | N | N | | Reacher | **-4.1+-0.3** | ~-5 | ~-7 | ~-6 | N | | InvertedPendulum | **1000.0+-0.0** | N | **~1000** | ~940 | N | | InvertedDoublePendulum | **9231.3+-270.4** | N | ~8000 | ~7350 | N | | Environment | Tianshou (3M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) | |:----------------------:|:------------------:|:------------------------------------------------------------------------------------------:| | Ant | **4079.3+-880.2** | ~3000 | | HalfCheetah | **7337.4+-1508.2** | ~3130 | | Hopper | **3127.7+-413.0** | ~2460 | | Walker2d | 
**4895.6+-704.3** | ~2600 | | Swimmer | 81.4+-96.0 | **~120** | | Humanoid | **1359.7+-572.7** | N | | Reacher | **-3.7+-0.3** | N | | InvertedPendulum | **1000.0+-0.0** | N | | InvertedDoublePendulum | **9231.3+-270.4** | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for PPO 1. Following [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990) Sec 3.5, we use the "recompute advantage" strategy, which contributes a lot to our SOTA benchmark. However, I personally don't quite agree with their explanation about why "recompute advantage" helps. They stated that it's because the old strategy "makes it impossible to compute advantages as the temporal structure is broken", but PPO's update equation is designed to learn from slightly-outdated advantages. I think the only reason "recompute advantage" works is that it updates the critic several times rather than just one time per update, which leads to a better value function estimation. 2. We have done full scale ablation studies of the PPO algorithm's hyperparameters. Here are our findings: In Mujoco settings, `value-clip` and `norm-adv` may help a little bit in some games (e.g. `norm-adv` helps stabilize training in InvertedPendulum-v2), but they make no difference to overall performance. So in our benchmark we do not use such tricks. We verified that setting `ent-coef` to 0.0 rather than 0.01 will increase overall performance in mujoco environments. `max-grad-norm` still offers no help for the PPO algorithm, but we still keep it for consistency. 3. [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990)'s work indicates that using `gae-lambda` 0.9 and changing the policy network's width based on which game you play (e.g. use [16, 16] `hidden-sizes` for `actor` network in HalfCheetah and [256, 256] for Ant) may help boost performance. Our ablation studies say otherwise: both options may lead to equal or lower performance overall in our experiments.
We are not confident about this claim because we didn't change the learning rate and other possibly correlated factors in our experiments. So if you want, you can still give it a try. 4. `batch-size` 128 and 64 (default) work equally well. Changing `training-num` alone slightly (maybe in range [8, 128]) won't affect performance. For the action-bounding method, both `clip` and `tanh` work quite well. 5. In OpenAI's implementations of PPO, they multiply the value loss by a factor of 0.5 for no good reason (see this [issue](https://github.com/openai/baselines/issues/445#issuecomment-777988738)). We do not do so and therefore set our `vf-coef` to 0.25 (half of the standard 0.5). However, since the value loss is only used to optimize the `critic` network, setting a different `vf-coef` should in theory make no difference when using the Adam optimizer. ### TRPO | Environment | Tianshou (1M) | [ACKTR paper](https://arxiv.org/pdf/1708.05144.pdf) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (Tensorflow)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | |:----------------------:|:-----------------:|:---------------------------------------------------:|:-------------------------------------------------:|:-------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------:| | Ant | **2866.7±707.9** | ~0 | N | N | ~150 | | HalfCheetah | **4471.2±804.9** | ~400 | ~0 | ~1350 | ~850 | | Hopper | 2046.0±1037.9 | ~1400 | ~2100 | **~2200** | ~1200 | | Walker2d | **3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 | | Swimmer | 40.9±19.6 | ~40 | **~121** | ~95 | ~85 | | Humanoid | **810.1±126.1** | N | N | N | N | | Reacher | **-5.1±0.8** | -8 | ~-115 | **~-5** | N | | InvertedPendulum | **1000.0±0.0** | **~1000** | **~1000** | ~910 | N | | InvertedDoublePendulum | **8435.2±1073.3** | ~800
| ~200 | ~7000 | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for TRPO 1. We have tried `step-per-collect` in (80, 1024, 2048, 4096), and `training-num` in (4, 16, 32, 64), and found that 1024 for `step-per-collect` (same as OpenAI Baselines) and a smaller `training-num` (below 16) are good choices. Setting `training-num` to 4 is actually better, but we still use 16 because it speeds up training. 2. Advantage normalization is a standard trick in TRPO, but we found it of minor help, just like in PPO. 3. Larger `optim-critic-iters` (than 5, as used in OpenAI Baselines) helps in most environments. A smaller lr and the lr_decay strategy also help performance a tiny little bit. 4. `gae-lambda` 0.98 and 0.95 work equally well. 5. We use GAE returns (GAE advantage + value) as the target of the critic network when updating, while people usually tend to use reward-to-go (lambda = 1.) as the target. We found that they work equally well, although using GAE returns is mathematically a little bit inaccurate (biased). 6. Empirically, Swimmer-v3 usually requires larger bootstrap lengths and learning rates. Humanoid-v3 and InvertedPendulum-v2, however, are the opposite. 7. Contrary to the statement made in the TRPO paper, we found that backtracking in line search is rarely used, at least in Mujoco settings, so it is practically unimportant. This makes the TRPO algorithm effectively the same as the TNPG algorithm (described in this [paper](http://proceedings.mlr.press/v48/duan16.html)). This also explains why TNPG's and TRPO's plotted results look so similar in that paper. 8. "recompute advantage" is helpful in PPO but doesn't help in TRPO.
### NPG | Environment | Tianshou (1M) | |:----------------------:|:----------------:| | Ant | **2358.0±517.5** | | HalfCheetah | **3485.2±716.6** | | Hopper | **1915.2±550.5** | | Walker2d | **2503.2±963.3** | | Swimmer | **31.5±8.0** | | Humanoid | **765.1±91.3** | | Reacher | **-4.5±0.5** | | InvertedPendulum | **1000.0±0.0** | | InvertedDoublePendulum | **9243.2±276.0** | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for NPG 1. All shared hyperparameters are exactly the same as in TRPO, given how similar these two algorithms are. 2. We found that different games in Mujoco may require quite different `actor-step-size` values: Reacher/Swimmer are insensitive to the step size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to the step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general. ## Others ### HER | Environment | DDPG without HER | DDPG with HER | |:-----------:|:----------------:|:--------------:| | FetchReach | -49.9±0.2 | **-17.6±21.7** | #### Hints for HER 1. The HER technique was proposed for solving task-based environments, so it cannot be compared with non-task-based mujoco benchmarks. The environment used in this evaluation is `FetchReach-v3`, which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics). 2. Simple hyperparameter optimizations were done for both settings, DDPG with and without HER. However, since _DDPG without HER_ failed in every experiment, the best hyperparameters for _DDPG with HER_ are used in the evaluation of both settings. 3. The scores are the mean reward ± 1 standard deviation over 16 seeds. The minimum reward for `FetchReach-v3` is -50, from which we can infer that _DDPG without HER_ performs as well as a random policy. Although _DDPG with HER_ has a better mean reward, its standard deviation is quite high.
This is because in this setting, the agent will either fail completely (-50 reward) or successfully learn the task (close to 0 reward). This means that the agent successfully learned in about 70% of the 16 seeds. ## Note [1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in the literature. [2] Pretrained agents, detailed graphs (single agent, single game) and log details can all be found at [Google Drive](https://drive.google.com/drive/folders/1IycImzTmWcyEeD38viea5JHoboC4zmNP?usp=share_link). [3] We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but this is often not the case for other benchmarks. Please check the original papers for details. (Outcomes across different versions are usually similar, though.) [4] ~ means the number is approximated from the graph because accurate numbers are not provided in the paper. N means the graph is not provided. [5] Reward metric: each table value is the max average return over 10 trials (different seeds) ± a single standard deviation over trials. Each trial is averaged over another 10 test seeds. Only data from the first 1M steps is considered, unless otherwise stated. The shaded region on the graph also represents a single standard deviation. This is the same as the [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34). [6] In the TD3 paper, the shaded region represents only half a standard deviation. [7] Comparing Tianshou's REINFORCE algorithm with SpinningUp's VPG is quite unfair, because SpinningUp's VPG uses a generalized advantage estimator (GAE), which requires a DNN value predictor (critic network); this makes the so-called "VPG" more like the A2C (advantage actor-critic) algorithm.
Even so, you can see that we are roughly at-parity with each other even if tianshou's REINFORCE do not use a critic or GAE. ================================================ FILE: examples/mujoco/analysis.py ================================================ #!/usr/bin/env python3 import argparse import re from collections import defaultdict from os import PathLike import numpy as np from tabulate import tabulate from tools import csv2numpy, find_all_files, group_files def numerical_analysis(root_dir: str | PathLike, xlim: float, norm: bool = False) -> None: file_pattern = re.compile(r".*/test_reward_\d+seeds.csv$") norm_group_pattern = re.compile(r"(/|^)\w+?\-v(\d|$)") output_group_pattern = re.compile(r".*?(?=(/|^)\w+?\-v\d)") csv_files = find_all_files(root_dir, file_pattern) norm_group = group_files(csv_files, norm_group_pattern) output_group = group_files(csv_files, output_group_pattern) # calculate numerical outcome for each csv_file (y/std integration max_y, final_y) results = defaultdict(list) for f in csv_files: result = csv2numpy(f) if norm: result = np.stack( [ result["env_step"], result["reward"] - result["reward"][0], result["reward:shaded"], ], ) else: result = np.stack([result["env_step"], result["reward"], result["reward:shaded"]]) if result[0, -1] < xlim: continue final_rew = np.interp(xlim, result[0], result[1]) final_rew_std = np.interp(xlim, result[0], result[2]) result = result[:, result[0] <= xlim] if len(result) == 0: continue if result[0, -1] < xlim: last_line = np.array([xlim, final_rew, final_rew_std]).reshape(3, 1) result = np.concatenate([result, last_line], axis=-1) max_id = np.argmax(result[1]) results["name"].append(f) results["final_reward"].append(result[1, -1]) results["final_reward_std"].append(result[2, -1]) results["max_reward"].append(result[1, max_id]) results["max_std"].append(result[2, max_id]) results["reward_integration"].append(np.trapz(result[1], x=result[0])) results["reward_std_integration"].append(np.trapz(result[2], 
x=result[0])) results = {k: np.array(v) for k, v in results.items()} print(tabulate(results, headers="keys")) if norm: # calculate normalized numerical outcome for each csv_file group for _, fs in norm_group.items(): mask = np.isin(results["name"], fs) for k, v in results.items(): if k == "name": continue v[mask] = v[mask] / max(v[mask]) # Add all numerical results for each outcome group group_results = defaultdict(list) for g, fs in output_group.items(): group_results["name"].append(g) mask = np.isin(results["name"], fs) group_results["num"].append(sum(mask)) for k in results: if k == "name": continue group_results[k + ":norm"].append(results[k][mask].mean()) # print all outputs for each csv_file and each outcome group print() print(tabulate(group_results, headers="keys")) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--xlim", type=int, default=1000000, help="x-axis limitation (default: 1000000)", ) parser.add_argument("--root_dir", type=str) parser.add_argument( "--norm", action="store_true", help="Normalize all results according to environment.", ) args = parser.parse_args() numerical_analysis(args.root_dir, args.xlim, norm=args.norm) ================================================ FILE: examples/mujoco/fetch_her_ddpg.py ================================================ #!/usr/bin/env python3 # isort: skip_file import argparse import datetime import os import pprint import gymnasium as gym import numpy as np import torch from tianshou.algorithm import DDPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, HERReplayBuffer, HERVectorReplayBuffer, ReplayBuffer, VectorReplayBuffer, ) from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated from tianshou.env.venvs import BaseVectorEnv from tianshou.exploration import GaussianNoise 
from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import ActionSpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="FetchReach-v2") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--actor_lr", type=float, default=1e-3) parser.add_argument("--critic_lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--exploration_noise", type=float, default=0.1) parser.add_argument("--start_timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1) parser.add_argument("--update_per_step", type=int, default=1) parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--batch_size", type=int, default=512) parser.add_argument("--replay_buffer", type=str, default="her", choices=["normal", "her"]) parser.add_argument("--her_horizon", type=int, default=50) parser.add_argument("--her_future_k", type=int, default=8) parser.add_argument("--num_training_envs", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) 
parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="HER-benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) return parser.parse_args() def make_fetch_env( task: str, num_training_envs: int, num_test_envs: int, ) -> tuple[gym.Env, BaseVectorEnv, BaseVectorEnv]: env = TruncatedAsTerminated(gym.make(task)) training_envs = ShmemVectorEnv( [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_training_envs)], ) test_envs = ShmemVectorEnv( [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_test_envs)], ) return env, training_envs, test_envs def test_ddpg(args: argparse.Namespace = get_args()) -> None: # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "ddpg" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=args.resume_id, config_dict=vars(args), ) env, training_envs, test_envs = make_fetch_env( args.task, args.num_training_envs, args.num_test_envs ) # The method HER works with goal-based environments if not isinstance(env.observation_space, gym.spaces.Dict): raise ValueError( "`env.observation_space` must be of type `gym.spaces.Dict`. Make sure you're using a goal-based environment like `FetchReach-v2`.", ) if not hasattr(env, "compute_reward"): raise ValueError( "Atrribute `compute_reward` not found in `env`. " "HER-based algorithms typically require this attribute. 
Make sure you're using a goal-based environment like `FetchReach-v2`.", ) args.state_shape = { "observation": env.observation_space["observation"].shape, "achieved_goal": env.observation_space["achieved_goal"].shape, "desired_goal": env.observation_space["desired_goal"].shape, } action_info = ActionSpaceInfo.from_space(env.action_space) args.action_shape = action_info.action_shape args.max_action = action_info.max_action args.exploration_noise = args.exploration_noise * args.max_action print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) print("Action range:", action_info.min_action, action_info.max_action) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) # model dict_state_dec, flat_state_shape = get_dict_state_decorator( state_shape=args.state_shape, keys=["observation", "achieved_goal", "desired_goal"], ) net_a = dict_state_dec(Net)( flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device, ) actor = dict_state_dec(ContinuousActorDeterministic)( net_a, args.action_shape, max_action=args.max_action, device=args.device, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = dict_state_dec(Net)( flat_state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) critic = dict_state_dec(ContinuousCritic)(net_c, device=args.device).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, ) algorithm: DDPG = DDPG( policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector def compute_reward_fn(ag: 
np.ndarray, g: np.ndarray) -> np.ndarray: return env.compute_reward(ag, g, {}) buffer: VectorReplayBuffer | ReplayBuffer | HERReplayBuffer | HERVectorReplayBuffer if args.replay_buffer == "normal": if args.num_training_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(training_envs)) else: buffer = ReplayBuffer(args.buffer_size) else: if args.num_training_envs > 1: buffer = HERVectorReplayBuffer( args.buffer_size, len(training_envs), compute_reward_fn=compute_reward_fn, horizon=args.her_horizon, future_k=args.her_future_k, ) else: buffer = HERReplayBuffer( args.buffer_size, compute_reward_fn=compute_reward_fn, horizon=args.her_horizon, future_k=args.her_future_k, ) training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs) training_collector.reset() training_collector.collect(n_step=args.start_timesteps, random=True) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_training=False, ) ) pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) collector_stats.pprint_asdict() if __name__ == "__main__": test_ddpg() ================================================ FILE: examples/mujoco/mujoco_a2c.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint from typing import Literal import numpy as np import torch from mujoco_env import make_mujoco_env from sensai.util import logging from torch import nn from torch.distributions import Distribution, Independent, Normal from tianshou.algorithm import A2C from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic log = logging.getLogger(__name__) def main( task: str = "Ant-v4", persistence_base_dir: str = "log", seed: int = 0, buffer_size: int = 4096, hidden_sizes: list | None = None, lr: float = 7e-4, gamma: float = 0.99, epoch: int = 100, epoch_num_steps: int = 30000, collection_step_num_env_steps: int = 80, update_step_num_repetitions: int = 1, batch_size: int | None = None, num_training_envs: int = 16, num_test_envs: int = 10, return_scaling: bool = True, vf_coef: float = 0.5, ent_coef: float = 0.01, gae_lambda: float = 0.95, action_bound_method: Literal["clip", "tanh"] | None = "clip", lr_decay: bool = True, max_grad_norm: float = 0.5, render: float = 0.0, device: str | None = None, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = 
"tensorboard", wandb_project: str = "mujoco.benchmark", watch: bool = False, ) -> None: # Set defaults for mutable arguments if hidden_sizes is None: hidden_sizes = [64, 64] if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config (excluding internal/temporary ones) params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") env, training_envs, test_envs = make_mujoco_env( task, seed, num_training_envs, num_test_envs, obs_norm=True, ) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n max_action = env.action_space.high[0] log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}") # seed np.random.seed(seed) torch.manual_seed(seed) # model net_a = Net( state_shape=state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, ) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=action_shape, unbounded=True, ).to(device) net_c = Net( state_shape=state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, ) critic = ContinuousCritic(preprocess_net=net_c).to(device) actor_critic = ActorCritic(actor, critic) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) # do last policy layer scaling, this will make initial actions have (close to) # 0 mean and std, and will help boost performances, # see https://arxiv.org/abs/2006.05990, Fig.24 for details for m in actor.mu.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) optim = RMSpropOptimizerFactory( lr=lr, eps=1e-5, alpha=0.99, ) if lr_decay: optim.with_lr_scheduler_factory( 
LRSchedulerFactoryLinear( max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, ) ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=True, action_bound_method=action_bound_method, action_space=env.action_space, ) algorithm: A2C = A2C( policy=policy, critic=critic, optim=optim, gamma=gamma, gae_lambda=gae_lambda, max_grad_norm=max_grad_norm, vf_coef=vf_coef, ent_coef=ent_coef, return_scaling=return_scaling, ) # load a previous policy if resume_path: ckpt = torch.load(resume_path, map_location=device) algorithm.load_state_dict(ckpt["model"]) training_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer if num_training_envs > 1: buffer = VectorReplayBuffer(buffer_size, len(training_envs)) else: buffer = ReplayBuffer(buffer_size) training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "a2c" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": training_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) if not watch: # train result = 
algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, update_step_num_repetitions=update_step_num_repetitions, test_step_num_episodes=num_test_envs, batch_size=batch_size, collection_step_num_env_steps=collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_training=False, ) ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=num_test_envs, render=render) print(collector_stats) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_a2c_hl.py ================================================ #!/usr/bin/env python3 import os from typing import Literal import torch from sensai.util import logging from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, ExperimentConfig, ) from tianshou.highlevel.params.algorithm_params import A2CParams from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.optim import OptimizerFactoryFactoryRMSprop def main( task: str = "Ant-v4", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "sequential", max_epochs: int = 100, epoch_num_steps: int = 30000, ) -> None: """ Train an agent using A2C on a specified MuJoCo task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the MuJoCo task to train on. :param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. 
The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. :param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. """ persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OnPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, batch_size=None, num_training_envs=16, num_test_envs=10, buffer_size=4096, collection_step_num_env_steps=80, update_step_num_repetitions=1, ) env_factory = MujocoEnvFactory(task, obs_norm=True) hidden_sizes = (64, 64) experiment_builder = ( A2CExperimentBuilder(env_factory, experiment_config, training_config) .with_a2c_params( A2CParams( gamma=0.99, gae_lambda=0.95, action_bound_method="clip", return_scaling=True, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, optim=OptimizerFactoryFactoryRMSprop(eps=1e-5, alpha=0.99), lr=7e-4, lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config), ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_ddpg.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint import numpy as np import torch from mujoco_env import make_mujoco_env from sensai.util import logging from tianshou.algorithm import DDPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import 
ContinuousDeterministicPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import GaussianNoise
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic

log = logging.getLogger(__name__)


def main(
    task: str = "Ant-v4",
    persistence_base_dir: str = "log",
    seed: int = 0,
    hidden_sizes: list | None = None,
    actor_lr: float = 1e-3,
    critic_lr: float = 1e-3,
    gamma: float = 0.99,
    tau: float = 0.005,
    exploration_noise: float = 0.1,
    start_timesteps: int = 25000,
    epoch: int = 50,
    epoch_num_steps: int = 5000,
    buffer_size: int = 1000000,
    collection_step_num_env_steps: int = 1,
    update_per_step: int = 1,
    n_step: int = 1,
    batch_size: int = 256,
    num_training_envs: int = 1,
    num_test_envs: int = 10,
    device: str | None = None,
    resume_path: str | None = None,
    resume_id: str | None = None,
    logger_type: str = "tensorboard",
    wandb_project: str = "mujoco.benchmark",
    watch: bool = False,
    render: float = 0.0,
) -> None:
    """
    Train a DDPG agent on a MuJoCo task with the low-level API, then evaluate it.

    The run directory is ``<persistence_base_dir>/<task>/ddpg/<seed>/<timestamp>``; the
    best algorithm state is saved there as ``policy.pth``. Observation normalization is
    disabled for this script.

    :param task: the MuJoCo task id passed to `make_mujoco_env`.
    :param persistence_base_dir: the base directory for logs and checkpoints.
    :param seed: seed for numpy, torch, the environment creation and the final evaluation.
    :param hidden_sizes: hidden layer sizes of actor and critic networks; defaults to ``[256, 256]``.
    :param actor_lr: Adam learning rate of the actor.
    :param critic_lr: Adam learning rate of the critic.
    :param gamma: discount factor, forwarded to DDPG.
    :param tau: soft target-update coefficient, forwarded to DDPG.
    :param exploration_noise: sigma of the Gaussian exploration noise relative to the maximum
        action; it is multiplied by ``action_space.high[0]`` before use.
    :param start_timesteps: number of environment steps collected with random actions before training.
    :param epoch: the maximum number of training epochs.
    :param epoch_num_steps: the number of environment steps per epoch.
    :param buffer_size: the replay buffer capacity.
    :param collection_step_num_env_steps: environment steps collected per collection step.
    :param update_per_step: gradient steps per collected sample.
    :param n_step: the n-step return horizon.
    :param batch_size: the minibatch size for gradient updates.
    :param num_training_envs: number of training environments; more than one selects a `VectorReplayBuffer`.
    :param num_test_envs: number of test environments; also the number of test/evaluation episodes.
    :param device: torch device; defaults to "cuda" if available, else "cpu".
    :param resume_path: optional path of a saved algorithm state dict to load before training.
    :param resume_id: run id forwarded to the logger factory.
    :param logger_type: "wandb" selects the wandb logger; any other value selects tensorboard.
    :param wandb_project: the wandb project name (only used when ``logger_type == "wandb"``).
    :param watch: if True, skip training and only run the final evaluation.
    :param render: forwarded as ``render`` to the final evaluation's ``collect`` call.
    """
    # Set defaults for mutable arguments
    if hidden_sizes is None:
        hidden_sizes = [256, 256]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Get all local variables as config.
    # This snapshot is taken before any other local is assigned, so it contains exactly the
    # (defaulted) function arguments; it is later passed to the logger as `config_dict`.
    params_log_info = locals()
    log.info(f"Starting training with config:\n{params_log_info}")

    env, training_envs, test_envs = make_mujoco_env(
        task,
        seed,
        num_training_envs,
        num_test_envs,
        obs_norm=False,
    )
    # fall back to `.n` when the space has an empty/None shape
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    # NOTE(review): only the first upper bound is used, i.e. this assumes all action
    # dimensions share the same (symmetric) bound — confirm for the chosen task.
    max_action = env.action_space.high[0]
    # scale the relative noise level to the absolute action range
    exploration_noise = exploration_noise * max_action
    log.info(f"Observations shape: {state_shape}")
    log.info(f"Actions shape: {action_shape}")
    log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}")
    # seed
    # The network construction below consumes the torch RNG, so its order determines the
    # initial weights for a given seed.
    np.random.seed(seed)
    torch.manual_seed(seed)
    # model
    net_a = Net(state_shape=state_shape, hidden_sizes=hidden_sizes)
    actor = ContinuousActorDeterministic(
        preprocess_net=net_a, action_shape=action_shape, max_action=max_action
    ).to(device)
    actor_optim = AdamOptimizerFactory(lr=actor_lr)
    # the critic network receives state and action concatenated (concat=True)
    net_c = Net(
        state_shape=state_shape,
        action_shape=action_shape,
        hidden_sizes=hidden_sizes,
        concat=True,
    )
    critic = ContinuousCritic(preprocess_net=net_c).to(device)
    critic_optim = AdamOptimizerFactory(lr=critic_lr)
    policy = ContinuousDeterministicPolicy(
        actor=actor,
        exploration_noise=GaussianNoise(sigma=exploration_noise),
        action_space=env.action_space,
    )
    algorithm: DDPG = DDPG(
        policy=policy,
        policy_optim=actor_optim,
        critic=critic,
        critic_optim=critic_optim,
        tau=tau,
        gamma=gamma,
        n_step_return_horizon=n_step,
    )

    # load a previous policy
    if resume_path:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        algorithm.load_state_dict(torch.load(resume_path, map_location=device))
        log.info(f"Loaded agent from: {resume_path}")

    # collector
    # with several training envs, the buffer is sized for len(training_envs) envs
    buffer: VectorReplayBuffer | ReplayBuffer
    if num_training_envs > 1:
        buffer = VectorReplayBuffer(buffer_size, len(training_envs))
    else:
        buffer = ReplayBuffer(buffer_size)
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs)
    training_collector.reset()
    # warm-up: pre-fill the replay buffer with random-action transitions
    training_collector.collect(n_step=start_timesteps, random=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    algo_name = "ddpg"
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(persistence_base_dir, log_name)

    # logger
    logger_factory = LoggerFactoryDefault()
    if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=resume_id,
        config_dict=params_log_info,
    )

    def save_best_fn(policy: Algorithm) -> None:
        """Checkpoint callback for the trainer: save the algorithm's state dict to ``policy.pth``.

        The parameter is named ``policy`` but receives the `Algorithm`; its state dict is the
        format that ``resume_path`` loading above expects.
        """
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    if not watch:
        # train
        result = algorithm.run_training(
            OffPolicyTrainerParams(
                training_collector=training_collector,
                test_collector=test_collector,
                max_epochs=epoch,
                epoch_num_steps=epoch_num_steps,
                collection_step_num_env_steps=collection_step_num_env_steps,
                test_step_num_episodes=num_test_envs,
                batch_size=batch_size,
                save_best_fn=save_best_fn,
                logger=logger,
                update_step_num_gradient_steps_per_sample=update_per_step,
                test_in_training=False,
            )
        )
        pprint.pprint(result)

    # Let's watch its performance!
    # Runs in both modes; with watch=True and no resume_path this evaluates the untrained policy.
    test_envs.seed(seed)
    test_collector.reset()
    collector_stats = test_collector.collect(n_episode=num_test_envs, render=render)
    log.info(f"Collector stats: {collector_stats}")


if __name__ == "__main__":
    # entry point: run `main` through sensai's CLI helper with INFO-level logging
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/mujoco/mujoco_ddpg_hl.py
================================================
#!/usr/bin/env python3

import os
from typing import Literal

from sensai.util import logging

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import OffPolicyTrainingConfig
from tianshou.highlevel.experiment import (
    DDPGExperimentBuilder,
    ExperimentConfig,
)
from tianshou.highlevel.params.algorithm_params import DDPGParams
from tianshou.highlevel.params.noise import MaxActionScaledGaussian


def main(
    task: str = "Ant-v4",
    persistence_base_dir: str = "log",
    num_experiments: int = 1,
    experiment_launcher: Literal["sequential", "joblib"] = "joblib",
    max_epochs: int = 50,
    epoch_num_steps: int = 5000,
) -> None:
    """
    Train an agent using DDPG on a specified MuJoCo task, potentially running multiple experiments with different seeds
    and evaluating the results using rliable.

    All other hyperparameters are fixed in the body of this function.

    :param task: the MuJoCo task to train on.
    :param persistence_base_dir: the base directory for logging and saving experiment data,
        the task name will be appended to it.
    :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
    :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`.
        You can use "joblib" for parallel execution of whole experiments.
    :param max_epochs: the maximum number of training epochs.
    :param epoch_num_steps: the number of environment steps per epoch.
    """
    # append the task name so that runs of different tasks are stored separately
    persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task))
    experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False)

    training_config = OffPolicyTrainingConfig(
        max_epochs=max_epochs,
        epoch_num_steps=epoch_num_steps,
        num_training_envs=1,
        num_test_envs=10,
        buffer_size=1000000,
        batch_size=256,
        collection_step_num_env_steps=1,
        update_step_num_gradient_steps_per_sample=1,
        start_timesteps=25000,
        start_timesteps_random=True,
    )

    env_factory = MujocoEnvFactory(task, obs_norm=False)

    hidden_sizes = (256, 256)

    experiment_builder = (
        DDPGExperimentBuilder(env_factory, experiment_config, training_config)
        .with_ddpg_params(
            DDPGParams(
                actor_lr=1e-3,
                critic_lr=1e-3,
                gamma=0.99,
                tau=0.005,
                # noise level relative to the max action, mirroring mujoco_ddpg.py
                exploration_noise=MaxActionScaledGaussian(0.1),
                n_step_return_horizon=1,
            ),
        )
        .with_actor_factory_default(hidden_sizes)
        .with_critic_factory_default(hidden_sizes)
    )
    experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher)


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/mujoco/mujoco_env.py
================================================
import logging
import pickle

from gymnasium import Env

from tianshou.env import BaseVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import (
    ContinuousEnvironments,
    EnvFactoryRegistered,
    EnvMode,
    EnvPoolFactory,
    VectorEnvType,
)
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World
# EnvPool is an optional dependency; its availability decides whether MujocoEnvFactory
# hands an EnvPoolFactory to the base factory.
envpool_is_available = True
try:
    import envpool
except ImportError:
    envpool_is_available = False
    envpool = None

log = logging.getLogger(__name__)


def make_mujoco_env(
    task: str,
    seed: int,
    num_training_envs: int,
    num_test_envs: int,
    obs_norm: bool,
) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]:
    """Wrapper function for Mujoco env.

    Creates the environments via `MujocoEnvFactory`. If EnvPool is installed, an
    `EnvPoolFactory` is passed to the base factory so that EnvPool's Mujoco env can be used.

    :param task: the registered MuJoCo task id.
    :param seed: the seed forwarded to `MujocoEnvFactory.create_envs`.
    :param num_training_envs: the number of training environments.
    :param num_test_envs: the number of test environments.
    :param obs_norm: whether to wrap the vector envs in `VectorEnvNormObs`.
    :return: a tuple of (single env, training envs, test envs).
    """
    envs = MujocoEnvFactory(task, obs_norm=obs_norm).create_envs(
        num_training_envs,
        num_test_envs,
        seed=seed,
    )
    return envs.env, envs.training_envs, envs.test_envs


class MujocoEnvObsRmsPersistence(Persistence):
    """Persist/restore the observation-normalization statistics (``obs_rms``) of the envs.

    On ``PERSIST_POLICY`` the training envs' ``obs_rms`` is pickled to ``FILENAME`` (path from
    ``world.persist_path``); on ``RESTORE_POLICY`` it is loaded (path from ``world.restore_path``)
    and applied to the training, test and (if present) watch envs. `MujocoEnvFactory` registers
    it only when observation normalization is enabled.
    """

    # file name of the pickled obs_rms inside the persistence directory
    FILENAME = "env_obs_rms.pkl"

    def persist(self, event: PersistEvent, world: World) -> None:
        """Save the training envs' ``obs_rms`` when the policy is persisted; ignore other events."""
        if event != PersistEvent.PERSIST_POLICY:
            return  # type: ignore[unreachable] # since PersistEvent has only one member, mypy infers that line is unreachable
        obs_rms = world.envs.training_envs.get_obs_rms()
        path = world.persist_path(self.FILENAME)
        log.info(f"Saving environment obs_rms value to {path}")
        with open(path, "wb") as f:
            pickle.dump(obs_rms, f)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Load ``obs_rms`` when the policy is restored and apply it to all env groups."""
        if event != RestoreEvent.RESTORE_POLICY:
            return  # type: ignore[unreachable]
        path = world.restore_path(self.FILENAME)
        log.info(f"Restoring environment obs_rms value from {path}")
        # NOTE(review): pickle.load can execute arbitrary code from a crafted file; only
        # restore from trusted run directories.
        with open(path, "rb") as f:
            obs_rms = pickle.load(f)
        world.envs.training_envs.set_obs_rms(obs_rms)
        world.envs.test_envs.set_obs_rms(obs_rms)
        if world.envs.watch_env is not None:
            world.envs.watch_env.set_obs_rms(obs_rms)


class MujocoEnvFactory(EnvFactoryRegistered):
    """Factory for registered MuJoCo tasks with optional observation normalization.

    When ``obs_norm`` is True, every created vector env is wrapped in `VectorEnvNormObs`,
    test/watch envs receive the training envs' statistics, and `MujocoEnvObsRmsPersistence`
    is registered so the statistics are saved/restored together with the policy.
    """

    def __init__(
        self,
        task: str,
        obs_norm: bool = True,
        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
    ) -> None:
        """
        :param task: the registered MuJoCo task id.
        :param obs_norm: whether to normalize observations with `VectorEnvNormObs`.
        :param venv_type: the vector env implementation passed to the base factory.
        """
        super().__init__(
            task=task,
            venv_type=venv_type,
            envpool_factory=EnvPoolFactory() if envpool_is_available else None,
        )
        self.obs_norm = obs_norm

    def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv:
        """Create a vector env via the base factory, adding the obs-normalization wrapper if enabled."""
        env = super().create_venv(num_envs, mode, seed=seed)
        # obs norm wrapper
        # only training envs update the running statistics; other modes just apply them
        if self.obs_norm:
            env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAINING)
        return env

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        create_watch_env: bool = False,
        seed: int | None = None,
    ) -> ContinuousEnvironments:
        """Create training/test (and optionally watch) envs and wire up obs normalization.

        :raises AssertionError: if the base factory does not return `ContinuousEnvironments`.
        """
        envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env, seed=seed)
        assert isinstance(envs, ContinuousEnvironments)

        if self.obs_norm:
            # evaluate with the training envs' statistics. Whether later updates propagate
            # depends on get_obs_rms/set_obs_rms sharing the object — confirm in VectorEnvNormObs.
            envs.test_envs.set_obs_rms(envs.training_envs.get_obs_rms())
            if envs.watch_env is not None:
                envs.watch_env.set_obs_rms(envs.training_envs.get_obs_rms())
            envs.set_persistence(MujocoEnvObsRmsPersistence())
        return envs

================================================
FILE: examples/mujoco/mujoco_npg.py
================================================
#!/usr/bin/env python3

import datetime
import os
import pprint
from typing import Literal

import numpy as np
import torch
from mujoco_env import make_mujoco_env
from sensai.util import logging
from torch import nn
from torch.distributions import Distribution, Independent, Normal

from tianshou.algorithm import NPG
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear
from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OnPolicyTrainerParams
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic

log = logging.getLogger(__name__)


def main(
    task: str = "Ant-v4",
    persistence_base_dir: str = "log",
    seed: int = 0,
    hidden_sizes: list | None = None,
    lr: float = 1e-3,
    gamma: float = 0.99,
    epoch: int = 100,
    epoch_num_steps: int = 30000,
    collection_step_num_env_steps: int = 1024,
    batch_size: int | None = None,
    buffer_size: int = 4096,
    update_step_num_repetitions: int = 1,
    num_training_envs: int = 16,
    num_test_envs: int = 10,
    return_scaling: bool = True,
    gae_lambda: float = 0.95,
    bound_action_method: Literal["clip", "tanh"] | None = "clip",
    lr_decay: bool = True,
    render: float = 0.0,
    advantage_normalization: bool = True,
    optim_critic_iters: int = 20,
    device: str | None = None,
    resume_path: str | None = None,
    trust_region_size: float = 0.1,
    resume_id: str | None = None,
    logger_type: str = "tensorboard",
    wandb_project: str = "mujoco.benchmark",
    watch: bool = False,
) -> None:
    """
    Train an NPG agent on a MuJoCo task with the low-level API, then evaluate it.

    Observations are normalized; the checkpoint ``policy.pth`` in
    ``<persistence_base_dir>/<task>/npg/<seed>/<timestamp>`` stores both the algorithm state
    (key "model") and the training envs' normalization statistics (key "obs_rms").

    :param task: the MuJoCo task id passed to `make_mujoco_env`.
    :param persistence_base_dir: the base directory for logs and checkpoints.
    :param seed: seed for numpy, torch, the environment creation and the final evaluation.
    :param hidden_sizes: hidden layer sizes of actor and critic networks; defaults to ``[64, 64]``.
    :param lr: Adam learning rate.
    :param gamma: discount factor, forwarded to NPG.
    :param epoch: the maximum number of training epochs (also used by the LR schedule).
    :param epoch_num_steps: the number of environment steps per epoch.
    :param collection_step_num_env_steps: environment steps collected per collection step.
    :param batch_size: minibatch size forwarded to the trainer (None is passed through as-is).
    :param buffer_size: the buffer capacity.
    :param update_step_num_repetitions: forwarded to the trainer as ``update_step_num_repetitions``.
    :param num_training_envs: number of training environments; more than one selects a `VectorReplayBuffer`.
    :param num_test_envs: number of test environments; also the number of test/evaluation episodes.
    :param return_scaling: forwarded to NPG as ``return_scaling``.
    :param gae_lambda: GAE lambda, forwarded to NPG.
    :param bound_action_method: how actions are bounded ("clip", "tanh" or None), forwarded to the policy.
    :param lr_decay: if True, attach a linear LR scheduler to the optimizer.
    :param render: forwarded as ``render`` to the final evaluation's ``collect`` call.
    :param advantage_normalization: forwarded to NPG as ``advantage_normalization``.
    :param optim_critic_iters: forwarded to NPG as ``optim_critic_iters``.
    :param device: torch device; defaults to "cuda" if available, else "cpu".
    :param resume_path: optional checkpoint (as written by ``save_best_fn``) to load before training.
    :param trust_region_size: forwarded to NPG as ``trust_region_size``.
    :param resume_id: run id forwarded to the logger factory.
    :param logger_type: "wandb" selects the wandb logger; any other value selects tensorboard.
    :param wandb_project: the wandb project name (only used when ``logger_type == "wandb"``).
    :param watch: if True, skip training and only run the final evaluation.
    """
    # Set defaults for mutable arguments
    if hidden_sizes is None:
        hidden_sizes = [64, 64]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Get all local variables as config.
    # Taken before any other local is assigned, so it contains exactly the function arguments.
    params_log_info = locals()
    log.info(f"Starting training with config:\n{params_log_info}")

    env, training_envs, test_envs = make_mujoco_env(
        task,
        seed,
        num_training_envs,
        num_test_envs,
        obs_norm=True,
    )
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    # NOTE(review): max_action is not used below in this script.
    max_action = env.action_space.high[0]
    log.info(f"Observations shape: {state_shape}")
    log.info(f"Actions shape: {action_shape}")
    log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}")
    # seed
    # The model construction and initialization below consume the torch RNG, so their order
    # determines the initial weights for a given seed.
    np.random.seed(seed)
    torch.manual_seed(seed)
    # model
    net_a = Net(
        state_shape=state_shape,
        hidden_sizes=hidden_sizes,
        activation=nn.Tanh,
    )
    actor = ContinuousActorProbabilistic(
        preprocess_net=net_a,
        action_shape=action_shape,
        unbounded=True,
    ).to(device)
    net_c = Net(
        state_shape=state_shape,
        hidden_sizes=hidden_sizes,
        activation=nn.Tanh,
    )
    critic = ContinuousCritic(preprocess_net=net_c).to(device)
    # initialize the actor's sigma parameter to -0.5 (presumably a log-std — confirm in
    # ContinuousActorProbabilistic)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization (gain sqrt(2)) with zero bias for every linear layer
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # Last policy layer scaling: zero the bias and shrink the weights of the mean head
    # (actor.mu) by 100x, so the initial action means are close to 0. This is reported to
    # boost performance, see https://arxiv.org/abs/2006.05990, Fig.24 for details.
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = AdamOptimizerFactory(lr=lr)
    if lr_decay:
        # linear LR schedule spanning the configured training length
        optim.with_lr_scheduler_factory(
            LRSchedulerFactoryLinear(
                max_epochs=epoch,
                epoch_num_steps=epoch_num_steps,
                collection_step_num_env_steps=collection_step_num_env_steps,
            )
        )

    def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
        """Build a Gaussian from the actor's (loc, scale) output.

        ``Independent(..., 1)`` treats the last dimension as the event dimension, so
        log-probabilities are summed over it.
        """
        loc, scale = loc_scale
        return Independent(Normal(loc, scale), 1)

    policy = ProbabilisticActorPolicy(
        actor=actor,
        dist_fn=dist,
        action_scaling=True,
        action_bound_method=bound_action_method,
        action_space=env.action_space,
    )
    algorithm: NPG = NPG(
        policy=policy,
        critic=critic,
        optim=optim,
        gamma=gamma,
        gae_lambda=gae_lambda,
        return_scaling=return_scaling,
        advantage_normalization=advantage_normalization,
        optim_critic_iters=optim_critic_iters,
        trust_region_size=trust_region_size,
    )

    # load a previous policy
    if resume_path:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        # On torch>=2.6 the default weights_only=True may reject the pickled "obs_rms"
        # object stored by save_best_fn — confirm against the pinned torch version.
        ckpt = torch.load(resume_path, map_location=device)
        algorithm.load_state_dict(ckpt["model"])
        training_envs.set_obs_rms(ckpt["obs_rms"])
        test_envs.set_obs_rms(ckpt["obs_rms"])
        log.info(f"Loaded agent from: {resume_path}")

    # collector
    buffer: VectorReplayBuffer | ReplayBuffer
    if num_training_envs > 1:
        buffer = VectorReplayBuffer(buffer_size, len(training_envs))
    else:
        buffer = ReplayBuffer(buffer_size)
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    algo_name = "npg"
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(persistence_base_dir, log_name)

    # logger
    logger_factory = LoggerFactoryDefault()
    if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=resume_id,
        config_dict=params_log_info,
    )

    def save_best_fn(policy: Algorithm) -> None:
        """Checkpoint callback: save the algorithm state and the training envs' obs_rms.

        The parameter is named ``policy`` but receives the `Algorithm`. The saved dict uses the
        keys "model" and "obs_rms" that the resume code above reads.
        """
        state = {"model": policy.state_dict(), "obs_rms": training_envs.get_obs_rms()}
        torch.save(state, os.path.join(log_path, "policy.pth"))

    if not watch:
        # train
        result = algorithm.run_training(
            OnPolicyTrainerParams(
                training_collector=training_collector,
                test_collector=test_collector,
                max_epochs=epoch,
                epoch_num_steps=epoch_num_steps,
                update_step_num_repetitions=update_step_num_repetitions,
                test_step_num_episodes=num_test_envs,
                batch_size=batch_size,
                collection_step_num_env_steps=collection_step_num_env_steps,
                save_best_fn=save_best_fn,
                logger=logger,
                test_in_training=False,
            )
        )
        pprint.pprint(result)

    # Let's watch its performance!
    test_envs.seed(seed)
    test_collector.reset()
    collector_stats = test_collector.collect(n_episode=num_test_envs, render=render)
    log.info(f"Collector stats: {collector_stats}")


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/mujoco/mujoco_npg_hl.py
================================================
#!/usr/bin/env python3

import os
from typing import Literal

import torch
from sensai.util import logging

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import OnPolicyTrainingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    NPGExperimentBuilder,
)
from tianshou.highlevel.params.algorithm_params import NPGParams
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear


def main(
    task: str = "Ant-v4",
    persistence_base_dir: str = "log",
    num_experiments: int = 1,
    experiment_launcher: Literal["sequential", "joblib"] = "sequential",
    max_epochs: int = 100,
    epoch_num_steps: int = 30000,
) -> None:
    """
    Train an agent using NPG on a specified MuJoCo task, potentially running multiple experiments with different seeds
    and evaluating the results using rliable.

    All other hyperparameters are fixed in the body of this function.

    :param task: the MuJoCo task to train on.
    :param persistence_base_dir: the base directory for logging and saving experiment data,
        the task name will be appended to it.
    :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
    :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`.
        You can use "joblib" for parallel execution of whole experiments.
    :param max_epochs: the maximum number of training epochs.
    :param epoch_num_steps: the number of environment steps per epoch.
    """
    persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task))
    experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False)

    # NOTE(review): mujoco_npg.py defaults to 16 training envs, this config uses 64 —
    # confirm the difference is intended.
    training_config = OnPolicyTrainingConfig(
        max_epochs=max_epochs,
        epoch_num_steps=epoch_num_steps,
        batch_size=None,
        num_training_envs=64,
        num_test_envs=10,
        buffer_size=4096,
        collection_step_num_env_steps=1024,
        update_step_num_repetitions=1,
    )

    env_factory = MujocoEnvFactory(task, obs_norm=True)

    hidden_sizes = (64, 64)

    experiment_builder = (
        NPGExperimentBuilder(env_factory, experiment_config, training_config)
        .with_npg_params(
            NPGParams(
                gamma=0.99,
                gae_lambda=0.95,
                action_bound_method="clip",
                return_scaling=True,
                advantage_normalization=True,
                optim_critic_iters=20,
                trust_region_size=0.1,
                lr=1e-3,
                lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config),
            ),
        )
        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
        .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
    )
    experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher)


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE:
examples/mujoco/mujoco_ppo.py
================================================
#!/usr/bin/env python3

import datetime
import os
import pprint
from typing import Literal

import numpy as np
import torch
from mujoco_env import make_mujoco_env
from sensai.util import logging
from torch import nn
from torch.distributions import Distribution, Independent, Normal

from tianshou.algorithm import PPO
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear
from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OnPolicyTrainerParams
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic

log = logging.getLogger(__name__)


def main(
    task: str = "Ant-v4",
    persistence_base_dir: str = "log",
    seed: int = 0,
    buffer_size: int = 4096,
    hidden_sizes: list | None = None,
    lr: float = 3e-4,
    gamma: float = 0.99,
    epoch: int = 100,
    epoch_num_steps: int = 30000,
    collection_step_num_env_steps: int = 2048,
    update_step_num_repetitions: int = 10,
    batch_size: int = 64,
    num_training_envs: int = 8,
    num_test_envs: int = 10,
    return_scaling: bool = True,
    vf_coef: float = 0.25,
    ent_coef: float = 0.0,
    gae_lambda: float = 0.95,
    bound_action_method: Literal["clip", "tanh"] | None = "clip",
    lr_decay: bool = True,
    max_grad_norm: float = 0.5,
    eps_clip: float = 0.2,
    dual_clip: float | None = None,
    value_clip: bool = True,
    advantage_normalization: bool = False,
    recompute_adv: bool = True,
    render: float = 0.0,
    device: str | None = None,
    resume_path: str | None = None,
    resume_id: str | None = None,
    logger_type: str = "tensorboard",
    wandb_project: str = "mujoco.benchmark",
    watch: bool = False,
) -> None:
    """
    Train a PPO agent on a MuJoCo task with the low-level API, then evaluate it.

    Observations are normalized; the checkpoint ``policy.pth`` in
    ``<persistence_base_dir>/<task>/ppo/<seed>/<timestamp>`` stores both the algorithm state
    (key "model") and the training envs' normalization statistics (key "obs_rms").

    :param task: the MuJoCo task id passed to `make_mujoco_env`.
    :param persistence_base_dir: the base directory for logs and checkpoints.
    :param seed: seed for numpy, torch, the environment creation and the final evaluation.
    :param buffer_size: the buffer capacity.
    :param hidden_sizes: hidden layer sizes of actor and critic networks; defaults to ``[64, 64]``.
    :param lr: Adam learning rate.
    :param gamma: discount factor, forwarded to PPO.
    :param epoch: the maximum number of training epochs (also used by the LR schedule).
    :param epoch_num_steps: the number of environment steps per epoch.
    :param collection_step_num_env_steps: environment steps collected per collection step.
    :param update_step_num_repetitions: forwarded to the trainer as ``update_step_num_repetitions``.
    :param batch_size: the minibatch size forwarded to the trainer.
    :param num_training_envs: number of training environments; more than one selects a `VectorReplayBuffer`.
    :param num_test_envs: number of test environments; also the number of test/evaluation episodes.
    :param return_scaling: forwarded to PPO as ``return_scaling``.
    :param vf_coef: value-function loss coefficient, forwarded to PPO.
    :param ent_coef: entropy bonus coefficient, forwarded to PPO.
    :param gae_lambda: GAE lambda, forwarded to PPO.
    :param bound_action_method: how actions are bounded ("clip", "tanh" or None), forwarded to the policy.
    :param lr_decay: if True, attach a linear LR scheduler to the optimizer.
    :param max_grad_norm: gradient-norm clipping threshold, forwarded to PPO.
    :param eps_clip: PPO ratio-clipping epsilon.
    :param dual_clip: forwarded to PPO as ``dual_clip`` (None disables it there).
    :param value_clip: forwarded to PPO as ``value_clip``.
    :param advantage_normalization: forwarded to PPO as ``advantage_normalization``.
    :param recompute_adv: forwarded to PPO as ``recompute_advantage``.
    :param render: forwarded as ``render`` to the final evaluation's ``collect`` call.
    :param device: torch device; defaults to "cuda" if available, else "cpu".
    :param resume_path: optional checkpoint (as written by ``save_best_fn``) to load before training.
    :param resume_id: run id forwarded to the logger factory.
    :param logger_type: "wandb" selects the wandb logger; any other value selects tensorboard.
    :param wandb_project: the wandb project name (only used when ``logger_type == "wandb"``).
    :param watch: if True, skip training and only run the final evaluation.
    """
    # Set defaults for mutable arguments
    if hidden_sizes is None:
        hidden_sizes = [64, 64]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Get all local variables as config.
    # Taken before any other local is assigned, so it contains exactly the function arguments.
    params_log_info = locals()
    log.info(f"Starting training with config:\n{params_log_info}")

    env, training_envs, test_envs = make_mujoco_env(
        task,
        seed,
        num_training_envs,
        num_test_envs,
        obs_norm=True,
    )
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    # NOTE(review): max_action is not used below in this script.
    max_action = env.action_space.high[0]
    log.info(f"Observations shape: {state_shape}")
    log.info(f"Actions shape: {action_shape}")
    log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}")
    # seed
    # The model construction and initialization below consume the torch RNG, so their order
    # determines the initial weights for a given seed.
    np.random.seed(seed)
    torch.manual_seed(seed)
    # model
    net_a = Net(
        state_shape=state_shape,
        hidden_sizes=hidden_sizes,
        activation=nn.Tanh,
    )
    actor = ContinuousActorProbabilistic(
        preprocess_net=net_a,
        action_shape=action_shape,
        unbounded=True,
    ).to(device)
    net_c = Net(
        state_shape=state_shape,
        hidden_sizes=hidden_sizes,
        activation=nn.Tanh,
    )
    critic = ContinuousCritic(preprocess_net=net_c).to(device)
    # wrapper used here only to iterate over the modules of actor and critic together
    actor_critic = ActorCritic(actor, critic)

    # initialize the actor's sigma parameter to -0.5 (presumably a log-std — confirm in
    # ContinuousActorProbabilistic)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization (gain sqrt(2)) with zero bias for every linear layer
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # Last policy layer scaling: zero the bias and shrink the weights of the mean head
    # (actor.mu) by 100x, so the initial action means are close to 0. This is reported to
    # boost performance, see https://arxiv.org/abs/2006.05990, Fig.24 for details.
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = AdamOptimizerFactory(lr=lr)
    if lr_decay:
        # linear LR schedule spanning the configured training length
        optim.with_lr_scheduler_factory(
            LRSchedulerFactoryLinear(
                max_epochs=epoch,
                epoch_num_steps=epoch_num_steps,
                collection_step_num_env_steps=collection_step_num_env_steps,
            )
        )

    def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
        """Build a Gaussian from the actor's (loc, scale) output.

        ``Independent(..., 1)`` treats the last dimension as the event dimension, so
        log-probabilities are summed over it.
        """
        loc, scale = loc_scale
        return Independent(Normal(loc, scale), 1)

    policy = ProbabilisticActorPolicy(
        actor=actor,
        dist_fn=dist,
        action_scaling=True,
        action_bound_method=bound_action_method,
        action_space=env.action_space,
    )
    algorithm: PPO = PPO(
        policy=policy,
        critic=critic,
        optim=optim,
        gamma=gamma,
        gae_lambda=gae_lambda,
        max_grad_norm=max_grad_norm,
        vf_coef=vf_coef,
        ent_coef=ent_coef,
        return_scaling=return_scaling,
        eps_clip=eps_clip,
        value_clip=value_clip,
        dual_clip=dual_clip,
        advantage_normalization=advantage_normalization,
        recompute_advantage=recompute_adv,
    )

    # load a previous policy
    if resume_path:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        # On torch>=2.6 the default weights_only=True may reject the pickled "obs_rms"
        # object stored by save_best_fn — confirm against the pinned torch version.
        ckpt = torch.load(resume_path, map_location=device)
        algorithm.load_state_dict(ckpt["model"])
        training_envs.set_obs_rms(ckpt["obs_rms"])
        test_envs.set_obs_rms(ckpt["obs_rms"])
        log.info(f"Loaded agent from: {resume_path}")

    # collector
    buffer: VectorReplayBuffer | ReplayBuffer
    if num_training_envs > 1:
        buffer = VectorReplayBuffer(buffer_size, len(training_envs))
    else:
        buffer = ReplayBuffer(buffer_size)
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    algo_name = "ppo"
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(persistence_base_dir, log_name)

    # logger
    logger_factory = LoggerFactoryDefault()
    if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=resume_id,
        config_dict=params_log_info,
    )

    def save_best_fn(policy: Algorithm) -> None:
        """Checkpoint callback: save the algorithm state and the training envs' obs_rms.

        The parameter is named ``policy`` but receives the `Algorithm`. The saved dict uses the
        keys "model" and "obs_rms" that the resume code above reads.
        """
        state = {"model": policy.state_dict(), "obs_rms": training_envs.get_obs_rms()}
        torch.save(state, os.path.join(log_path, "policy.pth"))

    if not watch:
        # train
        result = algorithm.run_training(
            OnPolicyTrainerParams(
                training_collector=training_collector,
                test_collector=test_collector,
                max_epochs=epoch,
                epoch_num_steps=epoch_num_steps,
                update_step_num_repetitions=update_step_num_repetitions,
                test_step_num_episodes=num_test_envs,
                batch_size=batch_size,
                collection_step_num_env_steps=collection_step_num_env_steps,
                save_best_fn=save_best_fn,
                logger=logger,
                test_in_training=False,
            )
        )
        pprint.pprint(result)

    # Let's watch its performance!
    test_envs.seed(seed)
    test_collector.reset()
    collector_stats = test_collector.collect(n_episode=num_test_envs, render=render)
    log.info(f"Collector stats: {collector_stats}")


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/mujoco/mujoco_ppo_hl.py
================================================
#!/usr/bin/env python3

import os
from typing import Literal

import torch
from sensai.util import logging

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import OnPolicyTrainingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    PPOExperimentBuilder,
)
from tianshou.highlevel.params.algorithm_params import PPOParams
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear


def main(
    task: str = "Ant-v4",
    persistence_base_dir: str = "log",
    num_experiments: int = 1,
    experiment_launcher: Literal["sequential", "joblib"] = "sequential",
    max_epochs: int = 100,
    epoch_num_steps: int = 30000,
) -> None:
    """
    Train an agent using PPO on a specified MuJoCo task, potentially running multiple experiments with different seeds
    and evaluating the results using rliable.

    All other hyperparameters are fixed in the body of this function.

    :param task: the MuJoCo task to train on.
    :param persistence_base_dir: the base directory for logging and saving experiment data,
        the task name will be appended to it.
    :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
    :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`.
        You can use "joblib" for parallel execution of whole experiments.
    :param max_epochs: the maximum number of training epochs.
    :param epoch_num_steps: the number of environment steps per epoch.
    """
    persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task))
    experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False)

    # NOTE(review): these values differ from the mujoco_ppo.py defaults (8 training envs,
    # update_step_num_repetitions=10, value_clip=True) — confirm the divergence is intended.
    training_config = OnPolicyTrainingConfig(
        max_epochs=max_epochs,
        epoch_num_steps=epoch_num_steps,
        batch_size=64,
        num_training_envs=64,
        num_test_envs=10,
        buffer_size=4096,
        collection_step_num_env_steps=2048,
        update_step_num_repetitions=1,
    )

    env_factory = MujocoEnvFactory(task, obs_norm=True)

    hidden_sizes = (64, 64)

    experiment_builder = (
        PPOExperimentBuilder(env_factory, experiment_config, training_config)
        .with_ppo_params(
            PPOParams(
                gamma=0.99,
                gae_lambda=0.95,
                action_bound_method="clip",
                return_scaling=True,
                ent_coef=0.0,
                vf_coef=0.25,
                max_grad_norm=0.5,
                value_clip=False,
                advantage_normalization=False,
                eps_clip=0.2,
                dual_clip=None,
                recompute_advantage=True,
                lr=3e-4,
                lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config),
            ),
        )
        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
        .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
    )
    experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher)


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/mujoco/mujoco_redq.py
================================================
#!/usr/bin/env python3

import datetime
import os
import pprint
from typing import Literal

import numpy as np
import torch
from mujoco_env import make_mujoco_env
from sensai.util import logging

from tianshou.algorithm import REDQ
from tianshou.algorithm.algorithm_base import Algorithm
from
tianshou.algorithm.modelfree.redq import REDQPolicy
from tianshou.algorithm.modelfree.sac import AutoAlpha
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic

log = logging.getLogger(__name__)


def main(
    task: str = "Ant-v4",
    persistence_base_dir: str = "log",
    seed: int = 0,
    buffer_size: int = 1000000,
    hidden_sizes: list | None = None,
    ensemble_size: int = 10,
    subset_size: int = 2,
    actor_lr: float = 1e-3,
    critic_lr: float = 1e-3,
    gamma: float = 0.99,
    tau: float = 0.005,
    alpha: float = 0.2,
    auto_alpha: bool = False,
    alpha_lr: float = 3e-4,
    start_timesteps: int = 10000,
    epoch: int = 50,
    epoch_num_steps: int = 5000,
    collection_step_num_env_steps: int = 1,
    update_per_step: int = 20,
    n_step: int = 1,
    batch_size: int = 256,
    target_mode: Literal["min", "mean"] = "min",
    num_training_envs: int = 1,
    num_test_envs: int = 10,
    render: float = 0.0,
    device: str | None = None,
    resume_path: str | None = None,
    resume_id: str | None = None,
    logger_type: str = "tensorboard",
    wandb_project: str = "mujoco.benchmark",
    watch: bool = False,
) -> None:
    """
    Train a REDQ agent on a MuJoCo task with the low-level API, then evaluate it.

    The run directory is ``<persistence_base_dir>/<task>/redq/<seed>/<timestamp>``; the best
    algorithm state is saved there as ``policy.pth``. Observation normalization is disabled.

    :param task: the MuJoCo task id passed to `make_mujoco_env`.
    :param persistence_base_dir: the base directory for logs and checkpoints.
    :param seed: seed for numpy, torch, the environment creation and the final evaluation.
    :param buffer_size: the replay buffer capacity.
    :param hidden_sizes: hidden layer sizes of actor and critic networks; defaults to ``[256, 256]``.
    :param ensemble_size: number of critics in the ensemble (size of every `EnsembleLinear` layer).
    :param subset_size: forwarded to REDQ as ``subset_size``.
    :param actor_lr: Adam learning rate of the actor.
    :param critic_lr: Adam learning rate of the critic ensemble.
    :param gamma: discount factor, forwarded to REDQ.
    :param tau: soft target-update coefficient, forwarded to REDQ.
    :param alpha: fixed entropy coefficient; replaced by a learned `AutoAlpha` when ``auto_alpha`` is True.
    :param auto_alpha: if True, learn the entropy coefficient with target entropy ``-prod(action_shape)``.
    :param alpha_lr: Adam learning rate of the learned entropy coefficient.
    :param start_timesteps: number of environment steps collected with random actions before training.
    :param epoch: the maximum number of training epochs.
    :param epoch_num_steps: the number of environment steps per epoch.
    :param collection_step_num_env_steps: environment steps collected per collection step.
    :param update_per_step: gradient steps per collected sample; also used as REDQ's ``actor_delay``.
    :param n_step: the n-step return horizon.
    :param batch_size: the minibatch size for gradient updates.
    :param target_mode: "min" or "mean", forwarded to REDQ as ``target_mode``.
    :param num_training_envs: number of training environments; more than one selects a `VectorReplayBuffer`.
    :param num_test_envs: number of test environments; also the number of test/evaluation episodes.
    :param render: forwarded as ``render`` to the final evaluation's ``collect`` call.
    :param device: torch device; defaults to "cuda" if available, else "cpu".
    :param resume_path: optional path of a saved algorithm state dict to load before training.
    :param resume_id: run id forwarded to the logger factory.
    :param logger_type: "wandb" selects the wandb logger; any other value selects tensorboard.
    :param wandb_project: the wandb project name (only used when ``logger_type == "wandb"``).
    :param watch: if True, skip training and only run the final evaluation.
    """
    # Set defaults for mutable arguments
    if hidden_sizes is None:
        hidden_sizes = [256, 256]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Get all local variables as config.
    # Taken before any other local is assigned, so it contains exactly the function arguments.
    params_log_info = locals()
    log.info(f"Starting training with config:\n{params_log_info}")

    env, training_envs, test_envs = make_mujoco_env(
        task,
        seed,
        num_training_envs,
        num_test_envs,
        obs_norm=False,
    )
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    # NOTE(review): max_action is not used below in this script.
    max_action = env.action_space.high[0]
    log.info(f"Observations shape: {state_shape}")
    log.info(f"Actions shape: {action_shape}")
    log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}")
    # seed
    # The model construction below consumes the torch RNG, so its order determines the
    # initial weights for a given seed.
    np.random.seed(seed)
    torch.manual_seed(seed)
    # model
    net_a = Net(state_shape=state_shape, hidden_sizes=hidden_sizes)
    actor = ContinuousActorProbabilistic(
        preprocess_net=net_a,
        action_shape=action_shape,
        unbounded=True,
        conditioned_sigma=True,
    ).to(device)
    actor_optim = AdamOptimizerFactory(lr=actor_lr)

    def linear(x: int, y: int) -> EnsembleLinear:
        """Layer factory: an `EnsembleLinear` with ``ensemble_size`` members mapping x -> y features.

        Passed as ``linear_layer`` to both the critic's `Net` and `ContinuousCritic`, so every
        linear layer of the critic is an ensemble layer.
        """
        return EnsembleLinear(ensemble_size, x, y)

    net_c = Net(
        state_shape=state_shape,
        action_shape=action_shape,
        hidden_sizes=hidden_sizes,
        concat=True,
        linear_layer=linear,
    )
    critics = ContinuousCritic(
        preprocess_net=net_c,
        linear_layer=linear,
        flatten_input=False,
    ).to(device)
    critics_optim = AdamOptimizerFactory(lr=critic_lr)

    if auto_alpha:
        # target entropy is minus the number of action dimensions
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = 0.0
        alpha_optim = AdamOptimizerFactory(lr=alpha_lr)
        # rebinds the float-annotated `alpha` argument (presumably why the type: ignore is needed)
        alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(device)  # type: ignore

    policy = REDQPolicy(
        actor=actor,
        action_space=env.action_space,
    )
    algorithm: REDQ = REDQ(
        policy=policy,
        policy_optim=actor_optim,
        critic=critics,
        critic_optim=critics_optim,
        ensemble_size=ensemble_size,
        subset_size=subset_size,
        tau=tau,
        gamma=gamma,
        alpha=alpha,
        n_step_return_horizon=n_step,
        # NOTE(review): the actor delay equals the gradient steps per sample, presumably so
        # the actor is updated about once per collected step — confirm against REDQ docs.
        actor_delay=update_per_step,
        target_mode=target_mode,
    )

    # load a previous policy
    if resume_path:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        algorithm.load_state_dict(torch.load(resume_path, map_location=device))
        log.info(f"Loaded agent from: {resume_path}")

    # collector
    buffer: VectorReplayBuffer | ReplayBuffer
    if num_training_envs > 1:
        buffer = VectorReplayBuffer(buffer_size, len(training_envs))
    else:
        buffer = ReplayBuffer(buffer_size)
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs)
    training_collector.reset()
    # warm-up: pre-fill the replay buffer with random-action transitions
    training_collector.collect(n_step=start_timesteps, random=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    algo_name = "redq"
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(persistence_base_dir, log_name)

    # logger
    logger_factory = LoggerFactoryDefault()
    if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=resume_id,
        config_dict=params_log_info,
    )

    def save_best_fn(policy: Algorithm) -> None:
        """Checkpoint callback for the trainer: save the algorithm's state dict to ``policy.pth``.

        The parameter is named ``policy`` but receives the `Algorithm`; its state dict is the
        format that ``resume_path`` loading above expects.
        """
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    if not watch:
        # train
        result = algorithm.run_training(
            OffPolicyTrainerParams(
                training_collector=training_collector,
                test_collector=test_collector,
                max_epochs=epoch,
                epoch_num_steps=epoch_num_steps,
                collection_step_num_env_steps=collection_step_num_env_steps,
                test_step_num_episodes=num_test_envs,
                batch_size=batch_size,
                save_best_fn=save_best_fn,
                logger=logger,
                update_step_num_gradient_steps_per_sample=update_per_step,
                test_in_training=False,
            )
        )
        pprint.pprint(result)

    # Let's watch its performance!
    test_envs.seed(seed)
    test_collector.reset()
    collector_stats = test_collector.collect(n_episode=num_test_envs, render=render)
    log.info(f"Collector stats: {collector_stats}")


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/mujoco/mujoco_redq_hl.py
================================================
#!/usr/bin/env python3

import os
from typing import Literal

from sensai.util import logging

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import OffPolicyTrainingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    REDQExperimentBuilder,
)
from tianshou.highlevel.params.algorithm_params import REDQParams


def main(
    task: str = "Ant-v4",
    persistence_base_dir: str = "log",
    num_experiments: int = 1,
    experiment_launcher: Literal["sequential", "joblib"] = "joblib",
    max_epochs: int = 50,
    epoch_num_steps: int = 5000,
) -> None:
    """
    Train an agent using REDQ on a specified MuJoCo task, potentially running multiple experiments with different seeds
    and evaluating the results using rliable.

    All other hyperparameters are fixed in the body of this function.

    :param task: the MuJoCo task to train on.
    :param persistence_base_dir: the base directory for logging and saving experiment data,
        the task name will be appended to it.
    :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
    :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`.
        You can use "joblib" for parallel execution of whole experiments.
    :param max_epochs: the maximum number of training epochs.
    :param epoch_num_steps: the number of environment steps per epoch.
    """
    persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task))
    experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False)

    training_config = OffPolicyTrainingConfig(
        max_epochs=max_epochs,
        epoch_num_steps=epoch_num_steps,
        num_training_envs=1,
        num_test_envs=10,
        buffer_size=1000000,
        batch_size=256,
        collection_step_num_env_steps=1,
        update_step_num_gradient_steps_per_sample=20,
        start_timesteps=10000,
        start_timesteps_random=True,
    )

    env_factory = MujocoEnvFactory(task, obs_norm=False)

    hidden_sizes = (256, 256)

    experiment_builder = (
        REDQExperimentBuilder(env_factory, experiment_config, training_config)
        .with_redq_params(
            REDQParams(
                actor_lr=1e-3,
                critic_lr=1e-3,
                gamma=0.99,
                tau=0.005,
                alpha=0.2,
                n_step_return_horizon=1,
                target_mode="min",
                subset_size=2,
                ensemble_size=10,
            ),
        )
        .with_actor_factory_default(hidden_sizes)
        .with_critic_ensemble_factory_default(hidden_sizes)
    )
    experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher)


if __name__ == "__main__":
    result = logging.run_cli(main, level=logging.INFO)

================================================
FILE: examples/mujoco/mujoco_reinforce.py
================================================
#!/usr/bin/env python3

import datetime
import os
import pprint
from typing import Literal

import numpy as np
import torch
from mujoco_env import make_mujoco_env
from sensai.util import logging
from torch import nn
from torch.distributions import Distribution, Independent, Normal

from tianshou.algorithm import Reinforce
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear
from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.trainer import OnPolicyTrainerParams
from
tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic log = logging.getLogger(__name__) def main( task: str = "Ant-v4", persistence_base_dir: str = "log", seed: int = 0, hidden_sizes: list | None = None, lr: float = 1e-3, gamma: float = 0.99, epoch: int = 100, epoch_num_steps: int = 30000, collection_step_num_env_steps: int = 2048, update_step_num_repetitions: int = 1, batch_size: int | None = None, buffer_size: int = 4096, num_training_envs: int = 10, num_test_envs: int = 10, return_scaling: bool = True, action_bound_method: Literal["clip", "tanh"] | None = "tanh", lr_decay: bool = True, device: str | None = None, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = "tensorboard", wandb_project: str = "mujoco.benchmark", watch: bool = False, render: float = 0.0, ) -> None: # Set defaults for mutable arguments if hidden_sizes is None: hidden_sizes = [64, 64] if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") env, training_envs, test_envs = make_mujoco_env( task, seed, num_training_envs, num_test_envs, obs_norm=True, ) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n max_action = env.action_space.high[0] log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}") # seed np.random.seed(seed) torch.manual_seed(seed) # model net_a = Net( state_shape=state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, ) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=action_shape, unbounded=True, ).to(device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in actor.modules(): if isinstance(m, torch.nn.Linear): # orthogonal 
initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) # do last policy layer scaling, this will make initial actions have (close to) # 0 mean and std, and will help boost performances, # see https://arxiv.org/abs/2006.05990, Fig.24 for details for m in actor.mu.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) optim = AdamOptimizerFactory(lr=lr) if lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, ) ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, action_scaling=True, action_bound_method=action_bound_method, ) algorithm: Reinforce = Reinforce( policy=policy, optim=optim, gamma=gamma, return_standardization=return_scaling, ) # load a previous policy if resume_path: ckpt = torch.load(resume_path, map_location=device) algorithm.load_state_dict(ckpt["model"]) training_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) log.info(f"Loaded agent from: {resume_path}") # collector buffer: VectorReplayBuffer | ReplayBuffer if num_training_envs > 1: buffer = VectorReplayBuffer(buffer_size, len(training_envs)) else: buffer = ReplayBuffer(buffer_size) training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "reinforce" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" 
logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": training_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) if not watch: # train result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, update_step_num_repetitions=update_step_num_repetitions, test_step_num_episodes=num_test_envs, batch_size=batch_size, collection_step_num_env_steps=collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_training=False, ) ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=num_test_envs, render=render) log.info(f"Collector stats: {collector_stats}") if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_reinforce_hl.py ================================================ #!/usr/bin/env python3 import os from typing import Literal import torch from sensai.util import logging from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, ReinforceExperimentBuilder, ) from tianshou.highlevel.params.algorithm_params import ReinforceParams from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( task: str = "Ant-v4", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "sequential", max_epochs: int = 100, epoch_num_steps: int = 30000, ) -> None: """ 
Train an agent using REINFORCE on a specified MuJoCo task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the MuJoCo task to train on. :param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. :param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. """ persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OnPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, batch_size=None, num_training_envs=64, num_test_envs=10, buffer_size=4096, collection_step_num_env_steps=2048, update_step_num_repetitions=1, ) env_factory = MujocoEnvFactory(task, obs_norm=True) hidden_sizes = (64, 64) experiment_builder = ( ReinforceExperimentBuilder(env_factory, experiment_config, training_config) .with_reinforce_params( ReinforceParams( gamma=0.99, action_bound_method="tanh", return_standardization=True, lr=1e-3, lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config), ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_sac.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint import numpy as np 
import torch from mujoco_env import make_mujoco_env from sensai.util import logging from tianshou.algorithm import SAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic log = logging.getLogger(__name__) def main( task: str = "Ant-v4", persistence_base_dir: str = "log", seed: int = 0, buffer_size: int = 1000000, hidden_sizes: list | None = None, actor_lr: float = 1e-3, critic_lr: float = 1e-3, gamma: float = 0.99, tau: float = 0.005, alpha: float = 0.2, auto_alpha: bool = False, alpha_lr: float = 3e-4, start_timesteps: int = 10000, epoch: int = 50, epoch_num_steps: int = 5000, collection_step_num_env_steps: int = 1, update_per_step: int = 1, n_step: int = 1, batch_size: int = 256, num_training_envs: int = 1, num_test_envs: int = 10, render: float = 0.0, device: str | None = None, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = "tensorboard", wandb_project: str = "mujoco.benchmark", watch: bool = False, ) -> None: # Set defaults for mutable arguments if hidden_sizes is None: hidden_sizes = [256, 256] if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") env, training_envs, test_envs = make_mujoco_env( task, seed, num_training_envs, num_test_envs, obs_norm=False, ) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n max_action = env.action_space.high[0] 
log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}") # seed np.random.seed(seed) torch.manual_seed(seed) # model net_a = Net(state_shape=state_shape, hidden_sizes=hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=action_shape, unbounded=True, conditioned_sigma=True, ).to(device) actor_optim = AdamOptimizerFactory(lr=actor_lr) net_c1 = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=hidden_sizes, concat=True, ) net_c2 = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(device) critic1_optim = AdamOptimizerFactory(lr=critic_lr) critic2 = ContinuousCritic(preprocess_net=net_c2).to(device) critic2_optim = AdamOptimizerFactory(lr=critic_lr) if auto_alpha: target_entropy = -np.prod(env.action_space.shape) log_alpha = 0.0 alpha_optim = AdamOptimizerFactory(lr=alpha_lr) alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(device) # type: ignore policy = SACPolicy( actor=actor, action_space=env.action_space, ) algorithm: SAC = SAC( policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=tau, gamma=gamma, alpha=alpha, n_step_return_horizon=n_step, ) # load a previous policy if resume_path: algorithm.load_state_dict(torch.load(resume_path, map_location=device)) log.info(f"Loaded agent from: {resume_path}") # collector buffer: VectorReplayBuffer | ReplayBuffer if num_training_envs > 1: buffer = VectorReplayBuffer(buffer_size, len(training_envs)) else: buffer = ReplayBuffer(buffer_size) training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs) training_collector.reset() 
training_collector.collect(n_step=start_timesteps, random=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "sac" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not watch: # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=num_test_envs, batch_size=batch_size, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=update_per_step, test_in_training=False, ) ) pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=num_test_envs, render=render) log.info(collector_stats) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_sac_hl.py ================================================ #!/usr/bin/env python3 import os from typing import Literal from sensai.util import logging from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, SACExperimentBuilder, ) from tianshou.highlevel.params.algorithm_params import SACParams def main( task: str = "Ant-v4", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "joblib", max_epochs: int = 50, epoch_num_steps: int = 5000, ) -> None: """ Train an agent using SAC on a specified MuJoCo task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the MuJoCo task to train on. :param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. :param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. 
""" persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OffPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, num_training_envs=1, num_test_envs=10, buffer_size=1000000, batch_size=256, collection_step_num_env_steps=1, update_step_num_gradient_steps_per_sample=1, start_timesteps=10000, start_timesteps_random=True, ) env_factory = MujocoEnvFactory(task, obs_norm=False) hidden_sizes = (256, 256) experiment_builder = ( SACExperimentBuilder(env_factory, experiment_config, training_config) .with_sac_params( SACParams( tau=0.005, gamma=0.99, alpha=0.2, n_step_return_horizon=1, actor_lr=1e-3, critic1_lr=1e-3, critic2_lr=1e-3, ), ) .with_actor_factory_default( hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True, ) .with_common_critic_factory_default(hidden_sizes) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_td3.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint import numpy as np import torch from mujoco_env import make_mujoco_env from sensai.util import logging from tianshou.algorithm import TD3 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import 
ContinuousActorDeterministic, ContinuousCritic log = logging.getLogger(__name__) def main( task: str = "Ant-v4", persistence_base_dir: str = "log", seed: int = 0, hidden_sizes: list | None = None, actor_lr: float = 3e-4, critic_lr: float = 3e-4, gamma: float = 0.99, tau: float = 0.005, exploration_noise: float = 0.1, policy_noise: float = 0.2, noise_clip: float = 0.5, update_actor_freq: int = 2, start_timesteps: int = 25000, epoch: int = 50, epoch_num_steps: int = 5000, collection_step_num_env_steps: int = 1, update_per_step: int = 1, n_step: int = 1, batch_size: int = 256, buffer_size: int = 1000000, num_training_envs: int = 1, num_test_envs: int = 10, device: str | None = None, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = "tensorboard", wandb_project: str = "mujoco.benchmark", watch: bool = False, render: float = 0.0, ) -> None: # Set defaults for mutable arguments if hidden_sizes is None: hidden_sizes = [256, 256] if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") env, training_envs, test_envs = make_mujoco_env( task, seed, num_training_envs, num_test_envs, obs_norm=False, ) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n max_action = env.action_space.high[0] exploration_noise = exploration_noise * max_action policy_noise = policy_noise * max_action log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}") # seed np.random.seed(seed) torch.manual_seed(seed) # model net_a = Net(state_shape=state_shape, hidden_sizes=hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net_a, action_shape=action_shape, max_action=max_action ).to(device) actor_optim = 
AdamOptimizerFactory(lr=actor_lr) net_c1 = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=hidden_sizes, concat=True, ) net_c2 = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(device) critic1_optim = AdamOptimizerFactory(lr=critic_lr) critic2 = ContinuousCritic(preprocess_net=net_c2).to(device) critic2_optim = AdamOptimizerFactory(lr=critic_lr) policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=exploration_noise), action_space=env.action_space, ) algorithm: TD3 = TD3( policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=tau, gamma=gamma, policy_noise=policy_noise, update_actor_freq=update_actor_freq, noise_clip=noise_clip, n_step_return_horizon=n_step, ) # load a previous policy if resume_path: log.info(f"Loaded agent from: {resume_path}") # collector buffer: VectorReplayBuffer | ReplayBuffer if num_training_envs > 1: buffer = VectorReplayBuffer(buffer_size, len(training_envs)) else: buffer = ReplayBuffer(buffer_size) training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs) training_collector.reset() training_collector.collect(n_step=start_timesteps, random=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "td3" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: 
Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not watch: # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=num_test_envs, batch_size=batch_size, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=update_per_step, test_in_training=False, ) ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=num_test_envs, render=render) log.info(collector_stats) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_td3_hl.py ================================================ #!/usr/bin/env python3 import os from typing import Literal import torch from sensai.util import logging from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, TD3ExperimentBuilder, ) from tianshou.highlevel.params.algorithm_params import TD3Params from tianshou.highlevel.params.env_param import MaxActionScaled from tianshou.highlevel.params.noise import ( MaxActionScaledGaussian, ) def main( task: str = "Ant-v4", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "joblib", max_epochs: int = 50, epoch_num_steps: int = 5000, ) -> None: """ Train an agent using TD3 on a specified MuJoCo task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the MuJoCo task to train on. 
:param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. :param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. """ persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OffPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, num_training_envs=1, num_test_envs=10, buffer_size=1000000, batch_size=256, collection_step_num_env_steps=1, update_step_num_gradient_steps_per_sample=1, start_timesteps=25000, start_timesteps_random=True, ) env_factory = MujocoEnvFactory(task, obs_norm=False) hidden_sizes = (256, 256) experiment_builder = ( TD3ExperimentBuilder(env_factory, experiment_config, training_config) .with_td3_params( TD3Params( tau=0.005, gamma=0.99, n_step_return_horizon=1, update_actor_freq=2, noise_clip=MaxActionScaled(0.5), policy_noise=MaxActionScaled(0.2), exploration_noise=MaxActionScaledGaussian(0.1), actor_lr=3e-4, critic1_lr=3e-4, critic2_lr=3e-4, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh) .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_trpo.py ================================================ #!/usr/bin/env python3 import datetime import os import pprint from typing import Literal import numpy as 
np import torch from mujoco_env import make_mujoco_env from sensai.util import logging from torch import nn from torch.distributions import Distribution, Independent, Normal from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic log = logging.getLogger(__name__) def main( task: str = "Ant-v4", persistence_base_dir: str = "log", seed: int = 0, hidden_sizes: list | None = None, lr: float = 1e-3, max_backtracks: int = 10, buffer_size: int = 4096, update_step_num_repetitions: int = 1, gamma: float = 0.99, epoch: int = 100, epoch_num_steps: int = 30000, collection_step_num_env_steps: int = 1024, batch_size: int | None = None, num_training_envs: int = 16, num_test_envs: int = 10, return_scaling: bool = True, gae_lambda: float = 0.95, action_bound_method: Literal["clip", "tanh"] | None = "clip", lr_decay: bool = True, render: float = 0.0, advantage_normalization: bool = True, optim_critic_iters: int = 20, max_kl: float = 0.01, backtrack_coeff: float = 0.8, device: str | None = None, resume_path: str | None = None, resume_id: str | None = None, logger_type: str = "tensorboard", wandb_project: str = "mujoco.benchmark", watch: bool = False, ) -> None: # Set defaults for mutable arguments if hidden_sizes is None: hidden_sizes = [64, 64] if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Get all local variables as config params_log_info = locals() log.info(f"Starting training with config:\n{params_log_info}") env, training_envs, test_envs 
= make_mujoco_env( task, seed, num_training_envs, num_test_envs, obs_norm=True, ) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n log.info(f"Observations shape: {state_shape}") log.info(f"Actions shape: {action_shape}") log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}") # seed np.random.seed(seed) torch.manual_seed(seed) # model net_a = Net( state_shape=state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, ) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=action_shape, unbounded=True, ).to(device) net_c = Net( state_shape=state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, ) critic = ContinuousCritic(preprocess_net=net_c).to(device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) # do last policy layer scaling, this will make initial actions have (close to) # 0 mean and std, and will help boost performances, # see https://arxiv.org/abs/2006.05990, Fig.24 for details for m in actor.mu.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) optim = AdamOptimizerFactory(lr=lr) if lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, ) ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=True, action_bound_method=action_bound_method, action_space=env.action_space, ) algorithm: TRPO = TRPO( policy=policy, critic=critic, optim=optim, gamma=gamma, gae_lambda=gae_lambda, 
return_scaling=return_scaling, advantage_normalization=advantage_normalization, optim_critic_iters=optim_critic_iters, max_kl=max_kl, backtrack_coeff=backtrack_coeff, max_backtracks=max_backtracks, ) # load a previous policy if resume_path: ckpt = torch.load(resume_path, map_location=device) algorithm.load_state_dict(ckpt["model"]) training_envs.set_obs_rms(ckpt["obs_rms"]) log.info(f"Loaded agent from: {resume_path}") # collector buffer: VectorReplayBuffer | ReplayBuffer if num_training_envs > 1: buffer = VectorReplayBuffer(buffer_size, len(training_envs)) else: buffer = ReplayBuffer(buffer_size) training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") algo_name = "trpo" log_name = os.path.join(task, algo_name, str(seed), now) log_path = os.path.join(persistence_base_dir, log_name) # logger logger_factory = LoggerFactoryDefault() if logger_type == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=resume_id, config_dict=params_log_info, ) def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": training_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) if not watch: # trainer result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=epoch, epoch_num_steps=epoch_num_steps, update_step_num_repetitions=update_step_num_repetitions, test_step_num_episodes=num_test_envs, batch_size=batch_size, collection_step_num_env_steps=collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_training=False, ) ) pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=num_test_envs, render=render) log.info(collector_stats) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/mujoco_trpo_hl.py ================================================ #!/usr/bin/env python3 import os from typing import Literal import torch from sensai.util import logging from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, TRPOExperimentBuilder, ) from tianshou.highlevel.params.algorithm_params import TRPOParams from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( task: str = "Ant-v4", persistence_base_dir: str = "log", num_experiments: int = 1, experiment_launcher: Literal["sequential", "joblib"] = "sequential", max_epochs: int = 100, epoch_num_steps: int = 30000, ) -> None: """ Train an agent using TRPO on a specified MuJoCo task, potentially running multiple experiments with different seeds and evaluating the results using rliable. :param task: the MuJoCo task to train on. :param persistence_base_dir: the base directory for logging and saving experiment data, the task name will be appended to it. :param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param experiment_launcher: the type of experiment launcher to use, only has an effect if `num_experiments>1`. You can use "joblib" for parallel execution of whole experiments. :param max_epochs: the maximum number of training epochs. :param epoch_num_steps: the number of environment steps per epoch. 
""" persistence_base_dir = os.path.abspath(os.path.join(persistence_base_dir, task)) experiment_config = ExperimentConfig(persistence_base_dir=persistence_base_dir, watch=False) training_config = OnPolicyTrainingConfig( max_epochs=max_epochs, epoch_num_steps=epoch_num_steps, batch_size=None, num_training_envs=16, num_test_envs=10, buffer_size=4096, collection_step_num_env_steps=1024, update_step_num_repetitions=1, ) env_factory = MujocoEnvFactory(task, obs_norm=True) hidden_sizes = (64, 64) experiment_builder = ( TRPOExperimentBuilder(env_factory, experiment_config, training_config) .with_trpo_params( TRPOParams( gamma=0.99, gae_lambda=0.95, action_bound_method="clip", return_standardization=True, advantage_normalization=True, optim_critic_iters=20, max_kl=0.01, backtrack_coeff=0.8, max_backtracks=10, lr=1e-3, lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config), ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) ) experiment_builder.build_and_run(num_experiments=num_experiments, launcher=experiment_launcher) if __name__ == "__main__": result = logging.run_cli(main, level=logging.INFO) ================================================ FILE: examples/mujoco/plotter.py ================================================ #!/usr/bin/env python3 import argparse import os import re from typing import Any, Literal import matplotlib.pyplot as plt import matplotlib.ticker as mticker import numpy as np from tools import csv2numpy, find_all_files, group_files def smooth( y: np.ndarray, radius: int, mode: Literal["two_sided", "causal"] = "two_sided", valid_only: bool = False, ) -> np.ndarray: """Smooth signal y, where radius is determines the size of the window. 
mode='twosided': average over the window [max(index - radius, 0), min(index + radius, len(y)-1)] mode='causal': average over the window [max(index - radius, 0), index] valid_only: put nan in entries where the full-sized window is not available """ if len(y) < 2 * radius + 1: return np.ones_like(y) * y.mean() if mode == "two_sided": convkernel = np.ones(2 * radius + 1) out = np.convolve(y, convkernel, mode="same") / np.convolve( np.ones_like(y), convkernel, mode="same", ) if valid_only: out[:radius] = out[-radius:] = np.nan elif mode == "causal": convkernel = np.ones(radius) out = np.convolve(y, convkernel, mode="full") / np.convolve( np.ones_like(y), convkernel, mode="full", ) out = out[: -radius + 1] if valid_only: out[:radius] = np.nan return out COLORS = [ # deepmind style "#0072B2", "#009E73", "#D55E00", "#CC79A7", # '#F0E442', "#d73027", # RED # built-in color "blue", "red", "pink", "cyan", "magenta", "yellow", "black", "purple", "brown", "orange", "teal", "lightblue", "lime", "lavender", "turquoise", "darkgreen", "tan", "salmon", "gold", "darkred", "darkblue", "green", # personal color "#313695", # DARK BLUE "#74add1", # LIGHT BLUE "#f46d43", # ORANGE "#4daf4a", # GREEN "#984ea3", # PURPLE "#f781bf", # PINK "#ffc832", # YELLOW "#000000", # BLACK ] def plot_ax( ax: plt.Axes, file_lists: list[str], legend_pattern: str = ".*", xlabel: str | None = None, ylabel: str | None = None, title: str = "", xlim: float | None = None, xkey: str = "env_step", ykey: str = "reward", smooth_radius: int = 0, shaded_std: bool = True, legend_outside: bool = False, ) -> None: def legend_fn(x: str) -> str: # return os.path.split(os.path.join( # args.root_dir, x))[0].replace('/', '_') + " (10)" match = re.search(legend_pattern, x) assert match is not None # for mypy return match.group(0) legneds = map(legend_fn, file_lists) # sort filelist according to legends file_lists = [f for _, f in sorted(zip(legneds, file_lists, strict=True))] legneds = list(map(legend_fn, file_lists)) for 
index, csv_file in enumerate(file_lists): csv_dict = csv2numpy(csv_file) x, y = csv_dict[xkey], csv_dict[ykey] y = smooth(y, radius=smooth_radius) color = COLORS[index % len(COLORS)] ax.plot(x, y, color=color) if shaded_std and ykey + ":shaded" in csv_dict: y_shaded = smooth(csv_dict[ykey + ":shaded"], radius=smooth_radius) ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=0.2) ax.legend( legneds, loc=2 if legend_outside else None, bbox_to_anchor=(1, 1) if legend_outside else None, ) ax.xaxis.set_major_formatter(mticker.EngFormatter()) if xlim is not None: ax.set_xlim(xmin=0, xmax=xlim) # add title ax.set_title(title) # add labels if xlabel is not None: ax.set_xlabel(xlabel) if ylabel is not None: ax.set_ylabel(ylabel) def plot_figure( file_lists: list[str], group_pattern: str | None = None, fig_length: int = 6, fig_width: int = 6, sharex: bool = False, sharey: bool = False, title: str = "", **kwargs: Any, ) -> None: if not group_pattern: fig, ax = plt.subplots(figsize=(fig_length, fig_width)) plot_ax(ax, file_lists, title=title, **kwargs) else: res = group_files(file_lists, group_pattern) row_n = int(np.ceil(len(res) / 3)) col_n = min(len(res), 3) fig, axes = plt.subplots( row_n, col_n, sharex=sharex, sharey=sharey, figsize=(fig_length * col_n, fig_width * row_n), squeeze=False, ) axes = axes.flatten() for i, (k, v) in enumerate(res.items()): plot_ax(axes[i], v, title=k, **kwargs) if title: # add title fig.suptitle(title, fontsize=20) if __name__ == "__main__": parser = argparse.ArgumentParser(description="plotter") parser.add_argument( "--fig_length", type=int, default=6, help="matplotlib figure length (default: 6)", ) parser.add_argument( "--fig_width", type=int, default=6, help="matplotlib figure width (default: 6)", ) parser.add_argument( "--style", default="seaborn", help="matplotlib figure style (default: seaborn)", ) parser.add_argument("--title", default=None, help="matplotlib figure title (default: None)") parser.add_argument( "--xkey", 
default="env_step", help="x-axis key in csv file (default: env_step)", ) parser.add_argument("--ykey", default="rew", help="y-axis key in csv file (default: rew)") parser.add_argument( "--smooth", type=int, default=0, help="smooth radius of y axis (default: 0)", ) parser.add_argument("--xlabel", default="Timesteps", help="matplotlib figure xlabel") parser.add_argument("--ylabel", default="Episode Reward", help="matplotlib figure ylabel") parser.add_argument( "--shaded_std", action="store_true", help="shaded region corresponding to standard deviation of the group", ) parser.add_argument( "--sharex", action="store_true", help="whether to share x axis within multiple sub-figures", ) parser.add_argument( "--sharey", action="store_true", help="whether to share y axis within multiple sub-figures", ) parser.add_argument( "--legend_outside", action="store_true", help="place the legend outside of the figure", ) parser.add_argument("--xlim", type=int, default=None, help="x-axis limitation (default: None)") parser.add_argument("--root_dir", default="./", help="root dir (default: ./)") parser.add_argument( "--file_pattern", type=str, default=r".*/test_rew_\d+seeds.csv$", help="regular expression to determine whether or not to include target csv " "file, default to including all test_rew_{num}seeds.csv file under rootdir", ) parser.add_argument( "--group_pattern", type=str, default=r"(/|^)\w*?\-v(\d|$)", help="regular expression to group files in sub-figure, default to grouping " 'according to env_name dir, "" means no grouping', ) parser.add_argument( "--legend_pattern", type=str, default=r".*", help="regular expression to extract legend from csv file path, default to " "using file path as legend name.", ) parser.add_argument("--show", action="store_true", help="show figure") parser.add_argument("--output_path", type=str, help="figure save path", default="./figure.png") parser.add_argument("--dpi", type=int, default=200, help="figure dpi (default: 200)") args = 
parser.parse_args() file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern)) file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists] if args.style: plt.style.use(args.style) os.chdir(args.root_dir) plot_figure( file_lists, group_pattern=args.group_pattern, legend_pattern=args.legend_pattern, fig_length=args.fig_length, fig_width=args.fig_width, title=args.title, xlabel=args.xlabel, ylabel=args.ylabel, xkey=args.xkey, ykey=args.ykey, xlim=args.xlim, sharex=args.sharex, sharey=args.sharey, smooth_radius=args.smooth, shaded_std=args.shaded_std, legend_outside=args.legend_outside, ) if args.output_path: plt.savefig(args.output_path, dpi=args.dpi, bbox_inches="tight") if args.show: plt.show() ================================================ FILE: examples/mujoco/tools.py ================================================ #!/usr/bin/env python3 import argparse import csv import os import re from collections import defaultdict from os import PathLike from re import Pattern from typing import Any import numpy as np import tqdm from tensorboard.backend.event_processing import event_accumulator def find_all_files(root_dir: str | PathLike[str], pattern: str | Pattern[str]) -> list: """Find all files under root_dir according to relative pattern.""" file_list = [] for dirname, _, files in os.walk(root_dir): for f in files: absolute_path = os.path.join(dirname, f) if re.match(pattern, absolute_path): file_list.append(absolute_path) return file_list def group_files(file_list: list[str], pattern: str | Pattern[str]) -> dict[str, list]: res = defaultdict(list) for f in file_list: match = re.search(pattern, f) key = match.group() if match else "" res[key].append(f) return res def csv2numpy(csv_file: str) -> dict[Any, np.ndarray]: csv_dict = defaultdict(list) with open(csv_file) as f: for row in csv.DictReader(f): for k, v in row.items(): csv_dict[k].append(eval(v)) return {k: np.array(v) for k, v in csv_dict.items()} def convert_tfevents_to_csv( root_dir: 
str | PathLike[str], refresh: bool = False, ) -> dict[str, list]: """Recursively convert test/reward from all tfevent file under root_dir to csv. This function assumes that there is at most one tfevents file in each directory and will add suffix to that directory. :param bool refresh: re-create csv file under any condition. """ tfevent_files = find_all_files(root_dir, re.compile(r"^.*tfevents.*$")) print(f"Converting {len(tfevent_files)} tfevents files under {root_dir} ...") result = {} with tqdm.tqdm(tfevent_files) as t: for tfevent_file in t: t.set_postfix(file=tfevent_file) output_file = os.path.join(os.path.split(tfevent_file)[0], "test_reward.csv") if os.path.exists(output_file) and not refresh: with open(output_file) as f: content = list(csv.reader(f)) if content[0] == ["env_step", "reward", "time"]: for i in range(1, len(content)): content[i] = list(map(eval, content[i])) result[output_file] = content continue ea = event_accumulator.EventAccumulator(tfevent_file) ea.Reload() initial_time = ea._first_event_timestamp content = [["env_step", "reward", "time"]] for test_reward in ea.scalars.Items("test/reward"): content.append( [ round(test_reward.step, 4), round(test_reward.value, 4), round(test_reward.wall_time - initial_time, 4), ], ) with open(output_file, "w") as f: csv.writer(f).writerows(content) result[output_file] = content return result def merge_csv( csv_files: dict[str, list], root_dir: str | PathLike[str], remove_zero: bool = False, ) -> None: """Merge result in csv_files into a single csv file.""" assert len(csv_files) > 0 if remove_zero: for v in csv_files.values(): if v[1][0] == 0: v.pop(1) sorted_keys = sorted(csv_files.keys()) sorted_values = [csv_files[k][1:] for k in sorted_keys] content = [ [ "env_step", "reward", "reward:shaded", *["reward:" + os.path.relpath(f, root_dir) for f in sorted_keys], ], ] for rows in zip(*sorted_values, strict=True): array = np.array(rows) assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0]) line = 
[rows[0][0], round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)] line += array[:, 1].tolist() content.append(line) output_path = os.path.join(root_dir, f"test_reward_{len(csv_files)}seeds.csv") print(f"Output merged csv file to {output_path} with {len(content[1:])} lines.") with open(output_path, "w") as f: csv.writer(f).writerows(content) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--refresh", action="store_true", help="Re-generate all csv files instead of using existing one.", ) parser.add_argument( "--remove_zero", action="store_true", help="Remove the data point of env_step == 0.", ) parser.add_argument("--root_dir", type=str) args = parser.parse_args() csv_files = convert_tfevents_to_csv(args.root_dir, args.refresh) merge_csv(csv_files, args.root_dir, args.remove_zero) ================================================ FILE: examples/offline/README.md ================================================ # Offline In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore. ## Continuous control Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets. We provide implementation of BCQ and CQL algorithm for continuous control. ### Train Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `d4rl_bcq.py` is an example of offline RL using the d4rl dataset. 
## Results ### IL (Imitation Learning, a.k.a. Behavior Cloning) | Environment | Dataset | IL | Parameters | |----------------|-----------------------|----------|-------------------------------------------------------------------------------------| | HalfCheetah-v2 | halfcheetah-expert-v2 | 11355.31 | `python3 d4rl_il.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` | | HalfCheetah-v2 | halfcheetah-medium-v2 | 5098.16 | `python3 d4rl_il.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` | ### BCQ | Environment | Dataset | BCQ | Parameters | |----------------|-----------------------|----------|--------------------------------------------------------------------------------------| | HalfCheetah-v2 | halfcheetah-expert-v2 | 11509.95 | `python3 d4rl_bcq.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` | | HalfCheetah-v2 | halfcheetah-medium-v2 | 5147.43 | `python3 d4rl_bcq.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` | ### CQL | Environment | Dataset | CQL | Parameters | |----------------|-----------------------|---------|--------------------------------------------------------------------------------------| | HalfCheetah-v2 | halfcheetah-expert-v2 | 2864.37 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` | | HalfCheetah-v2 | halfcheetah-medium-v2 | 6505.41 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` | ### TD3+BC | Environment | Dataset | TD3+BC | Parameters | |----------------|-----------------------|----------|-----------------------------------------------------------------------------------------| | HalfCheetah-v2 | halfcheetah-expert-v2 | 11788.25 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` | | HalfCheetah-v2 | halfcheetah-medium-v2 | 5741.13 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` | #### Observation normalization
Following the original paper, we use observation normalization by default. You can turn it off by setting `--norm-obs 0`. The differences are small but consistent. | Dataset | w/ norm-obs | w/o norm-obs | |:---------------------|:------------|:-------------| | halfcheetah-medium-v2 | 5741.13 | 5724.41 | | halfcheetah-expert-v2 | 11788.25 | 11665.77 | | walker2d-medium-v2 | 4051.76 | 3985.59 | | walker2d-expert-v2 | 5068.15 | 5027.75 | ## Discrete control For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. ### Gather Data To run the offline algorithms (BCQ, CQL, CRR) on Atari, you need to do the following: - Train an expert, by using the command listed in the QRDQN section of Atari examples: `python3 atari_qrdqn.py --task {your_task}` - Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M-transition Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error); - Train offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5`.
### IL We test our IL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient steps): | Task | Online QRDQN | Behavioral | IL | parameters | |------------------------|--------------|------------|-----------------------------------|--------------------------------------------------------------------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.5 | 6.8 | 20.0 (epoch 5) | `python3 atari_il.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | | BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 121.9 (epoch 12, could be higher) | `python3 atari_il.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` | ### BCQ We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient steps): | Task | Online QRDQN | Behavioral | BCQ | parameters | |------------------------|--------------|------------|----------------------------------|---------------------------------------------------------------------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task PongNoFrameskip-v4 --load_buffer_name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | | BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load_buffer_name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` | ### CQL We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient steps): | Task | Online QRDQN | Behavioral | CQL | parameters |
|------------------------|--------------|------------|------------------|---------------------------------------------------------------------------------------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task PongNoFrameskip-v4 --load_buffer_name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | | BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4 --load_buffer_name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min_q_weight 50` | We reduce the size of the offline data to 10% and 1% of the above and get: Buffer size 100000: | Task | Online QRDQN | Behavioral | CQL | parameters | |------------------------|--------------|------------|-----------------|------------------------------------------------------------------------------------------------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task PongNoFrameskip-v4 --load_buffer_name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` | | BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4 --load_buffer_name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min_q_weight 20` | Buffer size 10000: | Task | Online QRDQN | Behavioral | CQL | parameters | |------------------------|--------------|------------|-----------------|------------------------------------------------------------------------------------------------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task PongNoFrameskip-v4 --load_buffer_name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min_q_weight 1` | | BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4
--load_buffer_name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min_q_weight 10` | ### CRR We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient steps): | Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters | |------------------------|--------------|------------|-----------------|-----------------|---------------------------------------------------------------------------------------------------------------------------------------------------| | PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | | BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps. ### RL Unplugged Data We provide a script to convert the Atari datasets of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged) to Tianshou ReplayBuffer. For example, the following command will download the first shard of the first run of the Breakout game to `~/.rl_unplugged/datasets/Breakout/run_1-00001-of-00100` then convert it to a `tianshou.data.ReplayBuffer` and save it to `~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5` (use `--dataset-dir` and `--buffer-dir` to change the default directories): ```bash python3 convert_rl_unplugged_atari.py --task Breakout --run-id 1 --shard-id 1 ``` Then you can use the converted buffer to train an agent by: ```bash python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load_buffer_name ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5 --buffer_from_rl_unplugged --epoch 12 ``` Note: - Each shard contains about 500k transitions.
- This conversion script depends on Tensorflow. - It takes about 1 hour to process one shard on my machine. YMMV. ================================================ FILE: examples/offline/atari_bcq.py ================================================ #!/usr/bin/env python3 import argparse import datetime import os import pickle import pprint import sys import numpy as np import torch from gymnasium.spaces import Discrete from examples.offline.utils import load_buffer from tianshou.algorithm import DiscreteBCQ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OfflineTrainerParams from tianshou.utils.net.discrete import DiscreteActor def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps_test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=6.25e-5) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--target_update_freq", type=int, default=8000) parser.add_argument("--unlikely_action_threshold", type=float, default=0.3) parser.add_argument("--imitation_logits_penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update_per_epoch", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) parser.add_argument("--num_test_envs", type=int, 
default=10) parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, args.seed, 1, args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) assert isinstance(env.action_space, Discrete) args.state_shape = env.observation_space.shape args.action_shape = int(env.action_space.n) # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) # model assert args.state_shape is not None assert len(args.state_shape) == 3 c, h, w = args.state_shape feature_net = DQNet( c=c, h=h, w=w, action_shape=args.action_shape, features_only=True, ).to(args.device) policy_net = DiscreteActor( preprocess_net=feature_net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) imitation_net = DiscreteActor( 
preprocess_net=feature_net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm policy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, action_space=env.action_space, unlikely_action_threshold=args.unlikely_action_threshold, eps_inference=args.eps_test, ) algorithm: DiscreteBCQ = DiscreteBCQ( policy=policy, optim=optim, gamma=args.gamma, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, imitation_logits_penalty=args.imitation_logits_penalty, ) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: buffer = load_buffer(args.load_buffer_name) else: assert os.path.exists( args.load_buffer_name, ), "Please run atari_dqn.py first to get expert's data buffer." if args.load_buffer_name.endswith(".pkl"): with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) elif args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: print(f"Unknown buffer format: {args.load_buffer_name}") sys.exit(0) print("Replay buffer size:", len(buffer), flush=True) # collector test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "bcq" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=args.resume_id, config_dict=vars(args), ) def 
save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance def watch() -> None: print("Setup test envs ...") policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.update_per_epoch, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, ) ) pprint.pprint(result) watch() if __name__ == "__main__": main(get_args()) ================================================ FILE: examples/offline/atari_cql.py ================================================ #!/usr/bin/env python3 import argparse import datetime import os import pickle import pprint import sys from collections.abc import Sequence import numpy as np import torch from gymnasium.spaces import Discrete from examples.offline.utils import load_buffer from tianshou.algorithm import DiscreteCQL from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import QRDQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, 
default=1626) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--num_quantiles", type=int, default=200) parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--min_q_weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update_per_epoch", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, args.seed, 1, args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) assert isinstance(env.action_space, Discrete) space_info = 
SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape assert isinstance(args.state_shape, Sequence) assert len(args.state_shape) == 3, "state shape must have only 3 dimensions." c, h, w = args.state_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) # model net = QRDQNet( c=c, h=h, w=w, action_shape=args.action_shape, num_quantiles=args.num_quantiles, ) optim = AdamOptimizerFactory(lr=args.lr) # define policy policy = QRDQNPolicy( model=net, action_space=env.action_space, ) algorithm: DiscreteCQL = DiscreteCQL( policy=policy, optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, min_q_weight=args.min_q_weight, ).to(args.device) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: buffer = load_buffer(args.load_buffer_name) else: assert os.path.exists( args.load_buffer_name, ), "Please run atari_dqn.py first to get expert's data buffer." 
if args.load_buffer_name.endswith(".pkl"): with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) elif args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: print(f"Unknown buffer format: {args.load_buffer_name}") sys.exit(0) print("Replay buffer size:", len(buffer), flush=True) # collector test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "cql" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=args.resume_id, config_dict=vars(args), ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance def watch() -> None: print("Setup test envs ...") test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.update_per_epoch, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, ) ) pprint.pprint(result) watch() if __name__ == "__main__": main(get_args()) ================================================ FILE: examples/offline/atari_crr.py ================================================ #!/usr/bin/env python3 import 
argparse import datetime import os import pickle import pprint import sys import numpy as np import torch from gymnasium.spaces import Discrete from examples.offline.utils import load_buffer from tianshou.algorithm import DiscreteCRR from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OfflineTrainerParams from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--policy_improvement_mode", type=str, default="exp") parser.add_argument("--ratio_upper_bound", type=float, default=20.0) parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--min_q_weight", type=float, default=10.0) parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update_per_epoch", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, 
default=0.0) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, args.seed, 1, args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) assert isinstance(env.action_space, Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = env.observation_space.shape args.action_shape = space_info.action_info.action_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) # model assert args.state_shape is not None assert len(args.state_shape) == 3 c, h, w = args.state_shape feature_net = DQNet( c=c, h=h, w=w, action_shape=args.action_shape, features_only=True, ).to(args.device) actor = DiscreteActor( preprocess_net=feature_net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) critic = DiscreteCritic( preprocess_net=feature_net, hidden_sizes=args.hidden_sizes, last_size=int(np.prod(args.action_shape)), ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm policy = 
DiscreteActorPolicy( actor=actor, action_space=env.action_space, ) algorithm: DiscreteCRR = DiscreteCRR( policy=policy, critic=critic, optim=optim, gamma=args.gamma, policy_improvement_mode=args.policy_improvement_mode, ratio_upper_bound=args.ratio_upper_bound, beta=args.beta, min_q_weight=args.min_q_weight, target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: buffer = load_buffer(args.load_buffer_name) else: assert os.path.exists( args.load_buffer_name, ), "Please run atari_dqn.py first to get expert's data buffer." if args.load_buffer_name.endswith(".pkl"): with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) elif args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: print(f"Unknown buffer format: {args.load_buffer_name}") sys.exit(0) print("Replay buffer size:", len(buffer), flush=True) # collector test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "crr" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=args.resume_id, config_dict=vars(args), ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance def watch() -> None: print("Setup test envs ...") 
test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.update_per_epoch, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, ) ) pprint.pprint(result) watch() if __name__ == "__main__": main(get_args()) ================================================ FILE: examples/offline/atari_il.py ================================================ #!/usr/bin/env python3 import argparse import datetime import os import pickle import pprint import sys import numpy as np import torch from examples.offline.utils import load_buffer from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.imitation_base import ( ImitationPolicy, OfflineImitationLearning, ) from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update_per_epoch", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--frames_stack", type=int, 
default=4) parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_il(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, args.seed, 1, args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape assert isinstance(args.state_shape, tuple | list) assert len(args.state_shape) == 3 c, h, w = args.state_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) # model net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) # define policy policy = ImitationPolicy(actor=net, action_space=env.action_space) algorithm: OfflineImitationLearning = OfflineImitationLearning( policy=policy, optim=optim, ) # load a previous policy if args.resume_path: 
algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: buffer = load_buffer(args.load_buffer_name) else: assert os.path.exists( args.load_buffer_name, ), "Please run atari_dqn.py first to get expert's data buffer." if args.load_buffer_name.endswith(".pkl"): with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) elif args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: print(f"Unknown buffer format: {args.load_buffer_name}") sys.exit(0) print("Replay buffer size:", len(buffer), flush=True) # collector test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "il" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=args.resume_id, config_dict=vars(args), ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance def watch() -> None: print("Setup test envs ...") test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.update_per_epoch, test_step_num_episodes=args.num_test_envs, 
batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, ) ) pprint.pprint(result) watch() if __name__ == "__main__": test_il(get_args()) ================================================ FILE: examples/offline/convert_rl_unplugged_atari.py ================================================ #!/usr/bin/env python3 # # Adapted from # https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py # """Convert Atari RL Unplugged datasets to HDF5 format. Examples in the dataset represent SARSA transitions stored during a DQN training run as described in https://arxiv.org/pdf/1907.04543. For every training run we have recorded all 50 million transitions corresponding to 200 million environment steps (4x factor because of frame skipping). There are 5 separate datasets for each of the 45 games. Every transition in the dataset is a tuple containing the following features: * o_t: Observation at time t. Observations have been processed using the canonical Atari frame processing, including 4x frame stacking. The shape of a single observation is [84, 84, 4]. * a_t: Action taken at time t. * r_t: Reward after a_t. * d_t: Discount after a_t. * o_tp1: Observation at time t+1. * a_tp1: Action at time t+1. * extras: * episode_id: Episode identifier. * episode_return: Total episode return computed using per-step [-1, 1] clipping. """ import os from argparse import ArgumentParser, Namespace import h5py import numpy as np import numpy.typing as npt import requests import tensorflow as tf from tqdm import tqdm from tianshou.data import Batch tf.config.set_visible_devices([], "GPU") # 9 tuning games. TUNING_SUITE = [ "BeamRider", "DemonAttack", "DoubleDunk", "IceHockey", "MsPacman", "Pooyan", "RoadRunner", "Robotank", "Zaxxon", ] # 36 testing games. 
TESTING_SUITE = [ "Alien", "Amidar", "Assault", "Asterix", "Atlantis", "BankHeist", "BattleZone", "Boxing", "Breakout", "Carnival", "Centipede", "ChopperCommand", "CrazyClimber", "Enduro", "FishingDerby", "Freeway", "Frostbite", "Gopher", "Gravitar", "Hero", "Jamesbond", "Kangaroo", "Krull", "KungFuMaster", "NameThisGame", "Phoenix", "Pong", "Qbert", "Riverraid", "Seaquest", "SpaceInvaders", "StarGunner", "TimePilot", "UpNDown", "VideoPinball", "WizardOfWor", "YarsRevenge", ] # Total of 45 games. ALL_GAMES = TUNING_SUITE + TESTING_SUITE URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari" def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str: return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}" def _decode_frames(pngs: tf.Tensor) -> tf.Tensor: """Decode PNGs. :param pngs: String Tensor of size (4,) containing PNG encoded images. :returns: Tensor of size (4, 84, 84) containing decoded images. """ # Statically unroll png decoding frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)] # NOTE: to match tianshou's convention for framestacking frames = tf.squeeze(tf.stack(frames, axis=0)) frames.set_shape((4, 84, 84)) return frames def _make_tianshou_batch( o_t: tf.Tensor, a_t: tf.Tensor, r_t: tf.Tensor, d_t: tf.Tensor, o_tp1: tf.Tensor, a_tp1: tf.Tensor, ) -> Batch: """Create Tianshou batch with offline data. :param o_t: Observation at time t. :param a_t: Action at time t. :param r_t: Reward at time t. :param d_t: Discount at time t. :param o_tp1: Observation at time t+1. :param a_tp1: Action at time t+1. :returns: A tianshou.data.Batch object. """ return Batch( obs=o_t.numpy(), act=a_t.numpy(), rew=r_t.numpy(), done=1 - d_t.numpy(), obs_next=o_tp1.numpy(), ) def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch: """Create a tianshou Batch replay sample from a TF example.""" # Parse tf.Example. 
feature_description = { "o_t": tf.io.FixedLenFeature([4], tf.string), "o_tp1": tf.io.FixedLenFeature([4], tf.string), "a_t": tf.io.FixedLenFeature([], tf.int64), "a_tp1": tf.io.FixedLenFeature([], tf.int64), "r_t": tf.io.FixedLenFeature([], tf.float32), "d_t": tf.io.FixedLenFeature([], tf.float32), "episode_id": tf.io.FixedLenFeature([], tf.int64), "episode_return": tf.io.FixedLenFeature([], tf.float32), } data = tf.io.parse_single_example(tf_example, feature_description) # Process data. o_t = _decode_frames(data["o_t"]) o_tp1 = _decode_frames(data["o_tp1"]) a_t = tf.cast(data["a_t"], tf.int32) a_tp1 = tf.cast(data["a_tp1"], tf.int32) # Build tianshou Batch replay sample. return _make_tianshou_batch(o_t, a_t, data["r_t"], data["d_t"], o_tp1, a_tp1) # Adapted From https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 def download(url: str, fname: str, chunk_size: int | None = 1024) -> None: resp = requests.get(url, stream=True) total = int(resp.headers.get("content-length", 0)) if os.path.exists(fname): print(f"Found cached file at {fname}.") return with ( open(fname, "wb") as ofile, tqdm( desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024, ) as bar, ): for data in resp.iter_content(chunk_size=chunk_size): size = ofile.write(data) bar.update(size) def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None: download(url, fname) obs: npt.NDArray[np.uint8] = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") act: npt.NDArray[np.int64] = np.ndarray((maxsize,), dtype="int64") rew: npt.NDArray[np.float32] = np.ndarray((maxsize,), dtype="float32") done: npt.NDArray[np.bool_] = np.ndarray((maxsize,), dtype="bool") obs_next: npt.NDArray[np.uint8] = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") i = 0 file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP") for example in file_ds: if i >= maxsize: break batch = _tf_example_to_tianshou_batch(example) obs[i], act[i], rew[i], done[i], obs_next[i] = ( batch.obs, 
batch.act, batch.rew, batch.done, batch.obs_next, ) i += 1 if i % 1000 == 0: print(f"...{i}", end="", flush=True) print("\nDataset size:", i) # Following D4RL dataset naming conventions with h5py.File(ofname, "w") as f: f.create_dataset("observations", data=obs, compression="gzip") f.create_dataset("actions", data=act, compression="gzip") f.create_dataset("rewards", data=rew, compression="gzip") f.create_dataset("terminals", data=done, compression="gzip") f.create_dataset("next_observations", data=obs_next, compression="gzip") def process_dataset( task: str, download_path: str, dst_path: str, run_id: int = 1, shard_id: int = 0, total_num_shards: int = 100, ) -> None: fn = f"{task}/{_filename(run_id, shard_id, total_num_shards=total_num_shards)}" url = f"{URL_PREFIX}/{fn}" filepath = f"{download_path}/{fn}" ofname = f"{dst_path}/{fn}.hdf5" process_shard(url, filepath, ofname) def main(args: Namespace) -> None: if args.task not in ALL_GAMES: raise KeyError(f"`{args.task}` is not in the list of games.") fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards) dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5") if os.path.exists(dataset_path): raise OSError(f"Found existing dataset at {dataset_path}. Will not overwrite.") args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir) args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir) cache_path = os.path.join(args.cache_dir, args.task) os.makedirs(cache_path, exist_ok=True) dst_path = os.path.join(args.dataset_dir, args.task) os.makedirs(dst_path, exist_ok=True) process_dataset( args.task, args.cache_dir, args.dataset_dir, run_id=args.run_id, shard_id=args.shard_id, total_num_shards=args.total_num_shards, ) if __name__ == "__main__": parser = ArgumentParser(usage=__doc__) parser.add_argument("--task", required=True, help="Name of the Atari game.") parser.add_argument( "--run_id", type=int, default=1, help="Run id to download and convert. 
Value in [1..5].", ) parser.add_argument( "--shard_id", type=int, default=0, help="Shard id to download and convert. Value in [0..99].", ) parser.add_argument("--total_num_shards", type=int, default=100, help="Total number of shards.") parser.add_argument( "--dataset_dir", default=os.path.expanduser("~/.rl_unplugged/datasets"), help="Directory for converted hdf5 files.", ) parser.add_argument( "--cache_dir", default=os.path.expanduser("~/.rl_unplugged/cache"), help="Directory for downloaded original datasets.", ) args = parser.parse_args() main(args) ================================================ FILE: examples/offline/d4rl_bcq.py ================================================ #!/usr/bin/env python3 import argparse import datetime import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl from tianshou.algorithm import BCQ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.bcq import BCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") parser.add_argument("--buffer_size", type=int, default=1000000) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--actor_lr", type=float, default=1e-3) 
parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument("--vae_hidden_sizes", type=int, nargs="*", default=[512, 512]) # default to 2 * action_dim parser.add_argument("--latent_dim", type=int) parser.add_argument("--gamma", default=0.99) parser.add_argument("--tau", default=0.005) # Weighting for Clipped Double Q-learning in BCQ parser.add_argument("--lmbda", default=0.75) # Max perturbation hyper-parameter for BCQ parser.add_argument("--phi", default=0.05) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) return parser.parse_args() def test_bcq() -> None: args = get_args() env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) print("Action range:", args.min_action, args.max_action) args.state_dim = 
space_info.observation_info.obs_dim args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model # perturbation network net_a = MLP( input_dim=args.state_dim + args.action_dim, output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, ) actor = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae # output_dim = 0, so the last Module in the encoder is ReLU vae_encoder = MLP( input_dim=args.state_dim + args.action_dim, hidden_sizes=args.vae_hidden_sizes, ) if not args.latent_dim: args.latent_dim = args.action_dim * 2 vae_decoder = MLP( input_dim=args.state_dim + args.latent_dim, output_dim=args.action_dim, hidden_sizes=args.vae_hidden_sizes, ) vae = VAE( encoder=vae_encoder, decoder=vae_decoder, hidden_dim=args.vae_hidden_sizes[-1], latent_dim=args.latent_dim, max_action=args.max_action, ).to(args.device) vae_optim = AdamOptimizerFactory() policy = BCQPolicy( actor_perturbation=actor, action_space=env.action_space, critic=critic1, vae=vae, ) algorithm: BCQ = BCQ( policy=policy, actor_perturbation_optim=actor_optim, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, vae_optim=vae_optim, gamma=args.gamma, 
tau=args.tau, lmbda=args.lmbda, ).to(args.device) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "bcq" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger: WandbLogger | TensorboardLogger if args.logger == "tensorboard": logger = TensorboardLogger(writer) else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), run_id=args.resume_id, config=args, project=args.wandb_project, ) logger.load(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) # train result = algorithm.run_training( OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, ) ) pprint.pprint(result) else: watch() # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_bcq() ================================================ FILE: examples/offline/d4rl_cql.py ================================================ #!/usr/bin/env python3 import argparse import datetime import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl from tianshou.algorithm import CQL from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument( "--task", type=str, default="Hopper-v2", help="The name of the OpenAI Gym environment to train on.", ) parser.add_argument( "--seed", type=int, default=0, help="The random seed to use.", ) parser.add_argument( "--expert_data_task", type=str, default="hopper-expert-v2", help="The name of the OpenAI Gym environment to use for expert data collection.", ) parser.add_argument( "--buffer_size", type=int, default=1000000, help="The size of the replay buffer.", ) parser.add_argument( "--hidden_sizes", type=int, nargs="*", default=[256, 256], help="The list of hidden sizes for the neural networks.", ) parser.add_argument( "--actor_lr", type=float, default=1e-4, help="The learning rate for the actor network.", ) parser.add_argument( 
"--critic_lr", type=float, default=3e-4, help="The learning rate for the critic network.", ) parser.add_argument( "--alpha", type=float, default=0.2, help="The weight of the entropy term in the loss function.", ) parser.add_argument( "--auto_alpha", default=True, action="store_true", help="Whether to use automatic entropy tuning.", ) parser.add_argument( "--alpha_lr", type=float, default=1e-4, help="The learning rate for the entropy tuning.", ) parser.add_argument( "--cql_alpha_lr", type=float, default=3e-4, help="The learning rate for the CQL entropy tuning.", ) parser.add_argument( "--start_timesteps", type=int, default=10000, help="The number of timesteps before starting to train.", ) parser.add_argument( "--epoch", type=int, default=200, help="The number of epochs to train for.", ) parser.add_argument( "--epoch_num_steps", type=int, default=5000, help="The number of steps per epoch.", ) parser.add_argument( "--n_step", type=int, default=3, help="The number of steps to use for N-step TD learning.", ) parser.add_argument( "--batch_size", type=int, default=256, help="The batch size for training.", ) parser.add_argument( "--tau", type=float, default=0.005, help="The soft target update coefficient.", ) parser.add_argument( "--temperature", type=float, default=1.0, help="The temperature for the Boltzmann policy.", ) parser.add_argument( "--cql_weight", type=float, default=1.0, help="The weight of the CQL loss term.", ) parser.add_argument( "--with_lagrange", type=bool, default=True, help="Whether to use the Lagrange multiplier for CQL.", ) parser.add_argument( "--calibrated", type=bool, default=True, help="Whether to use calibration for CQL.", ) parser.add_argument( "--lagrange_threshold", type=float, default=10.0, help="The Lagrange multiplier threshold for CQL.", ) parser.add_argument("--gamma", type=float, default=0.99, help="The discount factor") parser.add_argument( "--eval_freq", type=int, default=1, help="The frequency of evaluation.", ) parser.add_argument( 
"--num_test_envs", type=int, default=10, help="The number of episodes to evaluate for.", ) parser.add_argument( "--logdir", type=str, default="log", help="The directory to save logs to.", ) parser.add_argument( "--render", type=float, default=1 / 35, help="The frequency of rendering the environment.", ) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="The device to train on (cpu or cuda).", ) parser.add_argument( "--resume_path", type=str, default=None, help="The path to the checkpoint to resume from.", ) parser.add_argument( "--resume_id", type=str, default=None, help="The ID of the checkpoint to resume from.", ) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) return parser.parse_args() def test_cql() -> None: args = get_args() env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action args.min_action = space_info.action_info.min_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) print("Action range:", args.min_action, args.max_action) args.state_dim = space_info.observation_info.obs_dim args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model # actor network net_a = Net( state_shape=args.state_shape, 
action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = -args.action_dim log_alpha = 0.0 alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = SACPolicy( actor=actor, action_space=env.action_space, ) algorithm: CQL = CQL( policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, critic2=critic2, critic2_optim=critic2_optim, calibrated=args.calibrated, cql_alpha_lr=args.cql_alpha_lr, cql_weight=args.cql_weight, tau=args.tau, gamma=args.gamma, alpha=args.alpha, temperature=args.temperature, with_lagrange=args.with_lagrange, lagrange_threshold=args.lagrange_threshold, min_action=args.min_action, max_action=args.max_action, ).to(args.device) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "cql" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger writer = 
SummaryWriter(log_path) writer.add_text("args", str(args)) logger: WandbLogger | TensorboardLogger if args.logger == "tensorboard": logger = TensorboardLogger(writer) else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), run_id=args.resume_id, config=args, project=args.wandb_project, ) logger.load(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) # train result = algorithm.run_training( OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, ) ) pprint.pprint(result) else: watch() # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_cql() ================================================ FILE: examples/offline/d4rl_il.py ================================================ #!/usr/bin/env python3 import argparse import datetime import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.imitation_base import ( ImitationPolicy, OfflineImitationLearning, ) from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument("--gamma", default=0.99) parser.add_argument( "--device", type=str, 
default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) return parser.parse_args() def test_il() -> None: args = get_args() env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) print("Action range:", args.min_action, args.max_action) args.state_dim = space_info.observation_info.obs_dim args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ) actor = ContinuousActorDeterministic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = ImitationPolicy( actor=actor, action_space=env.action_space, action_scaling=True, action_bound_method="clip", ) algorithm: OfflineImitationLearning = OfflineImitationLearning( policy=policy, optim=optim, ) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector test_collector = 
Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "cql" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger: WandbLogger | TensorboardLogger if args.logger == "tensorboard": logger = TensorboardLogger(writer) else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), run_id=args.resume_id, config=args, project=args.wandb_project, ) logger.load(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) # train result = algorithm.run_training( OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, ) ) pprint.pprint(result) else: watch() # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_il() ================================================ FILE: examples/offline/d4rl_td3_bc.py ================================================ #!/usr/bin/env python3 import argparse import datetime import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer from tianshou.algorithm import TD3BC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs from tianshou.exploration import GaussianNoise from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") parser.add_argument("--buffer_size", type=int, default=1000000) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--actor_lr", type=float, default=3e-4) parser.add_argument("--critic_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--n_step", type=int, 
default=3) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--alpha", type=float, default=2.5) parser.add_argument("--exploration_noise", type=float, default=0.1) parser.add_argument("--policy_noise", type=float, default=0.2) parser.add_argument("--noise_clip", type=float, default=0.5) parser.add_argument("--update_actor_freq", type=int, default=2) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--norm_obs", type=int, default=1) parser.add_argument("--eval_freq", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) return parser.parse_args() def test_td3_bc() -> None: args = get_args() env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action args.min_action = space_info.action_info.min_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) print("Action range:", args.min_action, args.max_action) args.state_dim = space_info.observation_info.obs_dim args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) test_envs: 
BaseVectorEnv test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) if args.norm_obs: test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model # actor network net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, ) actor = ContinuousActorDeterministic( preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, ) algorithm: TD3BC = TD3BC( policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, alpha=args.alpha, n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "td3_bc" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = 
os.path.join(args.logdir, log_name) # logger writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger: WandbLogger | TensorboardLogger if args.logger == "tensorboard": logger = TensorboardLogger(writer) else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), run_id=args.resume_id, config=args, project=args.wandb_project, ) logger.load(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) if args.norm_obs: replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer) test_envs.set_obs_rms(obs_rms) # train result = algorithm.run_training( OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, ) ) pprint.pprint(result) else: watch() # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": test_td3_bc() ================================================ FILE: examples/offline/utils.py ================================================ import d4rl import gymnasium as gym import h5py import numpy as np from tianshou.data import ReplayBuffer from tianshou.utils import RunningMeanStd def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer: dataset = d4rl.qlearning_dataset(gym.make(expert_data_task)) return ReplayBuffer.from_data( obs=dataset["observations"], act=dataset["actions"], rew=dataset["rewards"], done=dataset["terminals"], obs_next=dataset["next_observations"], terminated=dataset["terminals"], truncated=np.zeros(len(dataset["terminals"])), ) def load_buffer(buffer_path: str) -> ReplayBuffer: with h5py.File(buffer_path, "r") as dataset: return ReplayBuffer.from_data( obs=dataset["observations"], act=dataset["actions"], rew=dataset["rewards"], done=dataset["terminals"], obs_next=dataset["next_observations"], terminated=dataset["terminals"], truncated=np.zeros(len(dataset["terminals"])), ) def normalize_all_obs_in_replay_buffer( replay_buffer: ReplayBuffer, ) -> tuple[ReplayBuffer, RunningMeanStd]: # compute obs mean and var obs_rms = RunningMeanStd() obs_rms.update(replay_buffer.obs) _eps = np.finfo(np.float32).eps.item() # normalize obs replay_buffer._meta["obs"] = (replay_buffer.obs - obs_rms.mean) / np.sqrt(obs_rms.var + _eps) replay_buffer._meta["obs_next"] = (replay_buffer.obs_next - obs_rms.mean) / np.sqrt( obs_rms.var + _eps, ) return replay_buffer, obs_rms ================================================ FILE: examples/vizdoom/.gitignore ================================================ _vizdoom.ini ================================================ FILE: examples/vizdoom/README.md ================================================ # ViZDoom 
[ViZDoom](https://github.com/mwydmuch/ViZDoom) is a popular RL environment based on Doom, the famous first-person shooter. Here we provide some results and insights for this scenario.

## EnvPool

We highly recommend using EnvPool to run the following experiments. To install it on a Linux machine, run:

```bash
pip install envpool
```

After that, `make_vizdoom_env` will automatically switch to EnvPool's ViZDoom environment. EnvPool's implementation is much faster than the Python vectorized environment implementation: about 2\~3x faster in pure execution speed, and about 1.5x faster for the overall RL training pipeline. For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/) and [Docs](https://envpool.readthedocs.io/en/latest/api/vizdoom.html).

## Train

To train an agent:

```bash
python3 vizdoom_c51.py --task {D1_basic|D2_navigation|D3_battle|D4_battle2}
```

Expected results:

- D1 (health gathering) should finish training (no deaths) in fewer than 500k environment steps (5 epochs).
- D3 can reach a reward of 1600+ (75+ kills in 5 minutes).
- D4 can reach a reward of 700+.

Here are the results:

(episode length; the maximum length is 2625 because we use frameskip=4, i.e. 10500/4=2625)

![](results/c51/length.png)

(episode reward)

![](results/c51/reward.png)

To evaluate an agent's performance:

```bash
python3 vizdoom_c51.py --num_test_envs 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2}
```

To save `.lmp` files for recording:

```bash
python3 vizdoom_c51.py --save-lmp --num_test_envs 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2}
```

This stores the `.lmp` files in the `lmps/` directory.
To watch these `.lmp` files (for example, a D3 recording):

```bash
python3 replay.py maps/D3_battle.cfg episode_8_25.lmp
```

We provide two `.lmp` files (the best D3 and D4 runs) under `results/c51`. You can replay them with the following commands:

```bash
python3 replay.py maps/D3_battle.cfg results/c51/d3.lmp
python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
```

## Maps

See [maps/README.md](maps/README.md).

## Reward

1. A living reward hurts performance.
2. Combo actions (several buttons pressed at once) are really important.
3. Negative rewards for losing health and ammo2 are really helpful for D3/D4.
4. Rewarding only health gains (positive health reward only) is really helpful for D1.
5. Removing MOVE_BACKWARD may converge faster, but the final performance may be lower.

## Algorithms

The setting is exactly the same as for Atari, so you can also try the other algorithms listed in the Atari examples.

### C51 (single run)

| task          | best reward | reward curve                           | parameters                                      |
|---------------|-------------|----------------------------------------|-------------------------------------------------|
| D2_navigation | 747.52      | ![](results/c51/D2_navigation_rew.png) | `python3 vizdoom_c51.py --task "D2_navigation"` |
| D3_battle     | 1855.29     | ![](results/c51/D3_battle_rew.png)     | `python3 vizdoom_c51.py --task "D3_battle"`     |

### PPO (single run)

| task          | best reward | reward curve                           | parameters                                      |
|---------------|-------------|----------------------------------------|-------------------------------------------------|
| D2_navigation | 770.75      | ![](results/ppo/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation"` |
| D3_battle     | 320.59      | ![](results/ppo/D3_battle_rew.png)     | `python3 vizdoom_ppo.py --task "D3_battle"`     |

### PPO with ICM (single run)

| task          | best reward | reward curve                               | parameters                                                        |
|---------------|-------------|--------------------------------------------|-------------------------------------------------------------------|
| D2_navigation | 844.99      | ![](results/ppo_icm/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation"
--icm-lr-scale 10` | | D3_battle | 547.08 | ![](results/ppo_icm/D3_battle_rew.png) | `python3 vizdoom_ppo.py --task "D3_battle" --icm-lr-scale 10` | ================================================ FILE: examples/vizdoom/env.py ================================================ import os from collections.abc import Sequence from typing import Any import cv2 import gymnasium as gym import numpy as np import vizdoom as vzd from numpy.typing import NDArray from tianshou.env import ShmemVectorEnv try: import envpool except ImportError: envpool = None def normal_button_comb() -> list: actions = [] m_forward = [[0.0], [1.0]] t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] for i in m_forward: for j in t_left_right: actions.append(i + j) return actions def battle_button_comb() -> list: actions = [] m_forward_backward = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] m_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] attack = [[0.0], [1.0]] speed = [[0.0], [1.0]] for m in attack: for n in speed: for j in m_left_right: for i in m_forward_backward: for k in t_left_right: actions.append(i + j + k + m + n) return actions class Env(gym.Env): def __init__( self, cfg_path: str, frameskip: int = 4, res: Sequence[int] = (4, 40, 60), save_lmp: bool = False, ) -> None: super().__init__() self.save_lmp = save_lmp self.health_setting = "battle" in cfg_path if save_lmp: os.makedirs("lmps", exist_ok=True) self.res = res self.skip = frameskip self.observation_space = gym.spaces.Box(low=0, high=255, shape=res, dtype=np.float32) self.game = vzd.DoomGame() self.game.load_config(cfg_path) self.game.init() if "battle" in cfg_path: self.available_actions = battle_button_comb() else: self.available_actions = normal_button_comb() self.action_num = len(self.available_actions) self.action_space = gym.spaces.Discrete(self.action_num) self.spec = gym.envs.registration.EnvSpec("vizdoom-v0") self.count = 0 def get_obs(self) -> None: state = 
self.game.get_state() if state is None: return obs = state.screen_buffer self.obs_buffer[:-1] = self.obs_buffer[1:] self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2])) def reset( self, seed: int | None = None, options: dict[str, Any] | None = None, ) -> tuple[NDArray[np.uint8], dict[str, Any]]: if self.save_lmp: self.game.new_episode(f"lmps/episode_{self.count}.lmp") else: self.game.new_episode() self.count += 1 self.obs_buffer = np.zeros(self.res, dtype=np.uint8) self.get_obs() self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH) self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) return self.obs_buffer, {"TimeLimit.truncated": False} def step(self, action: int) -> tuple[NDArray[np.uint8], float, bool, bool, dict[str, Any]]: self.game.make_action(self.available_actions[action], self.skip) reward = 0.0 self.get_obs() health = self.game.get_game_variable(vzd.GameVariable.HEALTH) if self.health_setting or health > self.health: # positive health reward only for d1/d2 reward += health - self.health self.health = health killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) reward += 20 * (killcount - self.killcount) self.killcount = killcount ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) # if ammo2 > self.ammo2: reward += ammo2 - self.ammo2 self.ammo2 = ammo2 done = False info = {} if self.game.is_player_dead() or self.game.get_state() is None: done = True elif self.game.is_episode_finished(): done = True info["TimeLimit.truncated"] = True return ( self.obs_buffer, reward, done, info.get("TimeLimit.truncated", False), info, ) def render(self) -> None: pass def close(self) -> None: self.game.close() def make_vizdoom_env( task: str, frame_skip: int, res: tuple[int], save_lmp: bool = False, seed: int | None = None, num_training_envs: int = 10, num_test_envs: int = 10, ) -> tuple[Env, ShmemVectorEnv, ShmemVectorEnv]: cpu_count = 
os.cpu_count() if cpu_count is not None: num_test_envs = min(cpu_count - 1, num_test_envs) if envpool is not None: task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1" lmp_save_dir = "lmps/" if save_lmp else "" reward_config = { "KILLCOUNT": [20.0, -20.0], "HEALTH": [1.0, 0.0], "AMMO2": [1.0, -1.0], } if "battle" in task: reward_config["HEALTH"] = [1.0, -1.0] env = training_envs = envpool.make_gymnasium( task_id, frame_skip=frame_skip, stack_num=res[0], seed=seed, num_envs=num_training_envs, reward_config=reward_config, use_combined_action=True, max_episode_steps=2625, use_inter_area_resize=False, ) test_envs = envpool.make_gymnasium( task_id, frame_skip=frame_skip, stack_num=res[0], lmp_save_dir=lmp_save_dir, seed=seed, num_envs=num_test_envs, reward_config=reward_config, use_combined_action=True, max_episode_steps=2625, use_inter_area_resize=False, ) else: cfg_path = f"maps/{task}.cfg" env = Env(cfg_path, frame_skip, res) training_envs = ShmemVectorEnv( [lambda: Env(cfg_path, frame_skip, res) for _ in range(num_training_envs)], ) test_envs = ShmemVectorEnv( [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(num_test_envs)], ) training_envs.seed(seed) test_envs.seed(seed) return env, training_envs, test_envs if __name__ == "__main__": # env = Env("maps/D1_basic.cfg", 4, (4, 84, 84)) env = Env("maps/D3_battle.cfg", 4, (4, 84, 84)) print(env.available_actions) assert isinstance(env.action_space, gym.spaces.Discrete) action_num = env.action_space.n obs, _ = env.reset() if env.spec: print(env.spec.reward_threshold) print(obs.shape, action_num) for _ in range(4000): obs, rew, terminated, truncated, info = env.step(0) if terminated or truncated: env.reset() print(obs.shape, rew, terminated, truncated) cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3]) ================================================ FILE: examples/vizdoom/maps/D1_basic.cfg ================================================ # Lines starting with # are treated as comments 
(or with whitespaces+#). # It doesn't matter if you use capital letters or not. # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. doom_scenario_path = D1_basic.wad doom_map = map01 # Rewards # No per-step (living) reward. living_reward = 0 # No penalty for dying either. death_penalty = 0 # Rendering options screen_resolution = RES_160X120 screen_format = GRAY8 render_hud = false render_crosshair = false render_weapon = false render_decals = false render_particles = false window_visible = false # make episodes finish after 10500 actions (tics) episode_timeout = 10500 # Available buttons available_buttons = { MOVE_FORWARD TURN_LEFT TURN_RIGHT } # Game variables that will be in the state available_game_variables = { HEALTH } mode = PLAYER ================================================ FILE: examples/vizdoom/maps/D2_navigation.cfg ================================================ # Lines starting with # are treated as comments (or with whitespaces+#). # It doesn't matter if you use capital letters or not. # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. doom_scenario_path = D2_navigation.wad doom_map = map01 # Rewards # No per-step (living) reward. living_reward = 0 # No penalty for dying either.
death_penalty = 0 # Rendering options screen_resolution = RES_160X120 screen_format = GRAY8 render_hud = false render_crosshair = false render_weapon = false render_decals = false render_particles = false window_visible = false # make episodes finish after 10500 actions (tics) episode_timeout = 10500 # Available buttons available_buttons = { MOVE_FORWARD TURN_LEFT TURN_RIGHT } # Game variables that will be in the state available_game_variables = { HEALTH } mode = PLAYER ================================================ FILE: examples/vizdoom/maps/D3_battle.cfg ================================================ # Lines starting with # are treated as comments (or with whitespaces+#). # It doesn't matter if you use capital letters or not. # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. doom_scenario_path = D3_battle.wad doom_map = map01 # Rewards living_reward = 0 death_penalty = 100 # Rendering options screen_resolution = RES_160X120 screen_format = GRAY8 render_hud = false render_crosshair = true render_weapon = true render_decals = false render_particles = false window_visible = false # make episodes finish after 10500 actions (tics) episode_timeout = 10500 # Available buttons available_buttons = { MOVE_FORWARD MOVE_BACKWARD MOVE_LEFT MOVE_RIGHT TURN_LEFT TURN_RIGHT ATTACK SPEED } # Game variables that will be in the state available_game_variables = { KILLCOUNT AMMO2 HEALTH } mode = PLAYER doom_skill = 2 ================================================ FILE: examples/vizdoom/maps/D4_battle2.cfg ================================================ # Lines starting with # are treated as comments (or with whitespaces+#). # It doesn't matter if you use capital letters or not. # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
doom_scenario_path = D4_battle2.wad doom_map = map01 # Rewards living_reward = 0 death_penalty = 100 # Rendering options screen_resolution = RES_160X120 screen_format = GRAY8 render_hud = false render_crosshair = true render_weapon = true render_decals = false render_particles = false window_visible = false # make episodes finish after 10500 actions (tics) episode_timeout = 10500 # Available buttons available_buttons = { MOVE_FORWARD MOVE_BACKWARD MOVE_LEFT MOVE_RIGHT TURN_LEFT TURN_RIGHT ATTACK SPEED } # Game variables that will be in the state available_game_variables = { KILLCOUNT AMMO2 HEALTH } mode = PLAYER doom_skill = 2 ================================================ FILE: examples/vizdoom/maps/README.md ================================================ D1-D4 maps are from https://github.com/intel-isl/DirectFuturePrediction/ More maps and cfgs: https://github.com/mwydmuch/ViZDoom/tree/master/scenarios ================================================ FILE: examples/vizdoom/maps/spectator.py ================================================ #!/usr/bin/env python3 ##################################################################### # This script presents SPECTATOR mode. In SPECTATOR mode you play and # your agent can learn from it. # Configuration is loaded from "../../scenarios/.cfg" file. # # To see the scenario description go to "../../scenarios/README.md" ##################################################################### from argparse import ArgumentParser from time import sleep import vizdoom as vzd # import cv2 if __name__ == "__main__": parser = ArgumentParser("ViZDoom example showing how to use SPECTATOR mode.") parser.add_argument("-c", type=str, dest="config", default="D3_battle.cfg") parser.add_argument("-w", type=str, dest="wad_file", default="D3_battle.wad") args = parser.parse_args() game = vzd.DoomGame() # Choose scenario config file you wish to watch. # Don't load two configs cause the second will overrite the first one. 
# Multiple config files are ok but combining these ones doesn't make much sense. game.load_config(args.config) game.set_doom_scenario_path(args.wad_file) # Enables freelook in engine game.add_game_args("+freelook 1") game.set_screen_resolution(vzd.ScreenResolution.RES_640X480) # Enables spectator mode, so you can play. # Sounds strange but it is the agent who is supposed to watch not you. game.set_window_visible(True) game.set_mode(vzd.Mode.SPECTATOR) game.init() episodes = 1 for i in range(episodes): print("Episode #" + str(i + 1)) game.new_episode() while not game.is_episode_finished(): state = game.get_state() print(state.screen_buffer.dtype, state.screen_buffer.shape) # cv2.imwrite(f'imgs/{state.number}.png', state.screen_buffer) # game.make_action([0, 0, 0]) game.advance_action() last_action = game.get_last_action() reward = game.get_last_reward() print("State #" + str(state.number)) print("Game variables: ", state.game_variables) print("Action:", last_action) print("Reward:", reward) print("=====================") print("Episode finished!") print("Total reward:", game.get_total_reward()) print("************************") sleep(2.0) game.close() ================================================ FILE: examples/vizdoom/replay.py ================================================ # import cv2 import os import sys import time import tqdm import vizdoom as vzd def main( cfg_path: str = os.path.join("maps", "D3_battle.cfg"), lmp_path: str = os.path.join("test.lmp"), ) -> None: game = vzd.DoomGame() game.load_config(cfg_path) game.set_screen_format(vzd.ScreenFormat.CRCGCB) game.set_screen_resolution(vzd.ScreenResolution.RES_1024X576) game.set_window_visible(True) game.set_render_hud(True) game.init() game.replay_episode(lmp_path) killcount = 0 with tqdm.trange(10500) as tq: while not game.is_episode_finished(): game.advance_action() state = game.get_state() if state is None: break killcount = game.get_game_variable(vzd.GameVariable.KILLCOUNT) time.sleep(1 / 35) # 
cv2.imwrite(f"imgs/{tq.n}.png", # state.screen_buffer.transpose(1, 2, 0)[..., ::-1]) tq.update(1) game.close() print("killcount:", killcount) if __name__ == "__main__": main(*sys.argv[-2:]) ================================================ FILE: examples/vizdoom/vizdoom_c51.py ================================================ import argparse import datetime import os import pprint import sys import numpy as np import torch from env import make_vizdoom_env from tianshou.algorithm import C51 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import C51Net from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="D1_basic") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--eps_test", type=float, default=0.005) parser.add_argument("--eps_train", type=float, default=1.0) parser.add_argument("--eps_train_final", type=float, default=0.05) parser.add_argument("--buffer_size", type=int, default=2000000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--num_atoms", type=int, default=51) parser.add_argument("--v_min", type=float, default=-10.0) parser.add_argument("--v_max", type=float, default=10.0) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=300) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) 
parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--skip_num", type=int, default=4) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="vizdoom.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument( "--save_lmp", default=False, action="store_true", help="save lmp file for replay whole episode", ) parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() def test_c51(args: argparse.Namespace = get_args()) -> None: # make environments env, training_envs, test_envs = make_vizdoom_env( args.task, args.skip_num, (args.frames_stack, 84, 84), args.save_lmp, args.seed, args.num_training_envs, args.num_test_envs, ) args.state_shape = env.observation_space.shape args.action_shape = env.action_space.shape or env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) # define model c, h, w = args.state_shape net = C51Net(c=c, h=h, w=w, action_shape=args.action_shape, num_atoms=args.num_atoms) optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm policy = C51Policy( model=net, action_space=env.action_space, num_atoms=args.num_atoms, 
v_min=args.v_min, v_max=args.v_max, eps_training=args.eps_train, eps_inference=args.eps_test, ) algorithm: C51 = C51( policy=policy, optim=optim, gamma=args.gamma, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( args.buffer_size, buffer_num=len(training_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "c51" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=args.resume_id, config_dict=vars(args), ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold return False def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final policy.set_eps_training(eps) if env_step % 1000 == 0: 
logger.write("train/env_step", env_step, {"train/eps": eps}) # watch agent's performance def watch() -> None: print("Setup test envs ...") test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( args.buffer_size, buffer_num=len(test_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack, ) collector = Collector[CollectStats]( algorithm, test_envs, buffer, exploration_noise=True ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) # test training_collector and start filling replay buffer training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, training_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_training=False, ) ) pprint.pprint(result) watch() if __name__ == "__main__": test_c51(get_args()) ================================================ FILE: examples/vizdoom/vizdoom_ppo.py ================================================ import argparse import datetime import os import pprint import sys import numpy as np import torch from env import make_vizdoom_env from torch.distributions import Categorical from tianshou.algorithm 
import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import ( DiscreteActor, DiscreteCritic, IntrinsicCuriosityModule, ) def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="D1_basic") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.00002) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=300) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1000) parser.add_argument("--update_step_num_repetitions", type=int, default=4) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--hidden_size", type=int, default=512) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--return_scaling", type=int, default=False) parser.add_argument("--vf_coef", type=float, default=0.5) parser.add_argument("--ent_coef", type=float, default=0.01) parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--lr_decay", type=int, default=True) parser.add_argument("--max_grad_norm", type=float, default=0.5) parser.add_argument("--eps_clip", type=float, default=0.2) parser.add_argument("--dual_clip", type=float, default=None) 
parser.add_argument("--value_clip", type=int, default=0) parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--skip_num", type=int, default=4) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb_project", type=str, default="vizdoom.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument( "--save_lmp", default=False, action="store_true", help="save lmp file for replay whole episode", ) parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( "--icm_lr_scale", type=float, default=0.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( "--icm_reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( "--icm_forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", ) return parser.parse_args() def test_ppo(args: argparse.Namespace = get_args()) -> None: # make environments env, training_envs, test_envs = make_vizdoom_env( args.task, args.skip_num, (args.frames_stack, 84, 84), args.save_lmp, args.seed, args.num_training_envs, args.num_test_envs, ) args.state_shape = env.observation_space.shape args.action_shape = env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # seed 
np.random.seed(args.seed) torch.manual_seed(args.seed) # define model c, h, w = args.state_shape net = DQNet( c=c, h=h, w=w, action_shape=args.action_shape, features_only=True, output_dim_added_layer=args.hidden_size, ) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) critic = DiscreteCritic(preprocess_net=net) optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) def dist(logits: torch.Tensor) -> Categorical: return Categorical(logits=logits) # define policy and algorithm policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=False, action_space=env.action_space, ) algorithm: PPO | ICMOnPolicyWrapper algorithm = PPO( policy=policy, critic=critic, optim=optim, gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: c, h, w = args.state_shape feature_net = DQNet( c=c, h=h, w=w, action_shape=args.action_shape, features_only=True, output_dim_added_layer=args.hidden_size, ) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( feature_net=feature_net.net, feature_dim=feature_dim, action_dim=action_dim, ) icm_optim = AdamOptimizerFactory(lr=args.lr) algorithm = ICMOnPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, ).to(args.device) # load a previous policy if args.resume_path: 
algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( args.buffer_size, buffer_num=len(training_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project else: logger_factory.logger_type = "tensorboard" logger = logger_factory.create_logger( log_dir=log_path, experiment_name=log_name, run_id=args.resume_id, config_dict=vars(args), ) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold return False # watch agent's performance def watch() -> None: print("Setup test envs ...") test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( args.buffer_size, buffer_num=len(test_envs), ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack, ) collector = Collector[CollectStats]( algorithm, test_envs, buffer, exploration_noise=True ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # 
Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) # test training_collector and start filling replay buffer training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # train result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=False, ) ) pprint.pprint(result) watch() if __name__ == "__main__": test_ppo(get_args()) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] name = "tianshou" version = "2.0.0" description = "A Library for Deep Reinforcement Learning" authors = ["TSAIL "] license = "MIT" readme = "README.md" homepage = "https://github.com/thu-ml/tianshou" classifiers = [ # 3 - Alpha # 4 - Beta # 5 - Production/Stable "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.11", ] exclude = ["test/*", "examples/*", "docs/*"] [tool.poetry.dependencies] python = "^3.11" deepdiff = "^7.0.1" gymnasium = ">=0.28.0" h5py = "^3.9.0" matplotlib = ">=3.0.0" numba = ">=0.60.0" numpy = ">=1.24.4" 
overrides = "^7.4.0" packaging = "*" pandas = ">=2.0.0" pettingzoo = "^1.22" sensai-utils = ">=1.6.0" tensorboard = "^2.5.0" # Torch 2.0.1 causes problems, see https://github.com/pytorch/pytorch/issues/100974 torch = "^2.0.0, !=2.0.1, !=2.1.0" tqdm = "*" virtualenv = [ # special sauce b/c of a flaky bug in poetry on windows # see https://github.com/python-poetry/poetry/issues/7611#issuecomment-1466478926 { version = "^20.4.3,!=20.4.5,!=20.4.6" }, { version = "<20.16.4", markers = "sys_platform == 'win32'" }, ] # Atari, box2d, classic-control, and mujoco environments are all optional dependencies of gymnasium. # Unfortunately, we cannot have extras relying on (multiple) optionals of other packages due to a poetry issue (see e.g. # https://github.com/python-poetry/poetry/issues/7911) and therefore have to maintain our own list of dependencies. # This requires attention and monitoring of gymnasium's dependencies and their version numbers! ale-py = { version = "~=0.8.1", optional = true } # We have to pin arch explicitly, the lowest version pinned by rliable can't be installed on python 3.11 # and poetry doesn't seem to be able to resolve this properly. 
arch = { version = ">=5.4.0", optional = true } autorom = { version = "~=0.4.2", extras = ["accept-rom-license"], optional = true } box2d_py = { version = "2.3.5", optional = true } cython = { version = ">=0.27.2", optional = true } docstring-parser = { version = "^0.15", optional = true } envpool = { version = "^0.8.2", optional = true, markers = "sys_platform != 'darwin'"} gymnasium-robotics = { version = "*", optional = true } imageio = { version = ">=2.14.1", optional = true } joblib = { version = "*", optional = true } jsonargparse = {version = "^4.24.1", optional = true} # we need <3 b/c of https://github.com/Farama-Foundation/Gymnasium/issues/749 mujoco = { version = ">=2.1.5, <3", optional = true } opencv_python = { version = "*", optional = true } pybullet = { version = "*", optional = true } pygame = { version = ">=2.1.3", optional = true } rliable = {optional = true, version="1.2.0"} scipy = { version = "*", optional = true } shimmy = { version = ">=0.1.0,<1.0", optional = true } swig = { version = "4.*", optional = true } vizdoom = { version = "*", optional = true } [tool.poetry.extras] argparse = ["docstring-parser", "jsonargparse"] atari = ["ale-py", "autorom", "opencv-python", "shimmy"] box2d = ["box2d-py", "pygame", "swig"] classic_control = ["pygame"] mujoco = ["mujoco", "imageio"] pybullet = ["pybullet"] envpool = ["envpool"] robotics = ["gymnasium-robotics"] vizdoom = ["vizdoom"] eval = ["rliable", "arch", "joblib", "scipy", "jsonargparse", "docstring-parser"] [tool.poetry.group.dev] optional = true [tool.poetry.group.dev.dependencies] docutils = "0.20.1" jinja2 = "*" jupyter = "^1.0.0" jupyter-book = "^1.0.0" mypy = "^1.4.1" nbqa = "^1.7.1" nbstripout = "^0.6.1" # networkx is used in a test networkx = "*" poethepoet = "^0.20.0" pre-commit = "^3.3.3" pygame = "^2.1.0" pymunk = "^6.2.1" pytest = "*" pytest-cov = "*" # Ray currently causes issues when installed on windows server 2022 in CI # If users want to use ray, they should install it 
manually. ray = { version = ">=2.10, <3", markers = "sys_platform != 'win32'" } ruff = "0.14.1" scipy = "*" sphinx = "^7" sphinx-book-theme = "^1.0.1" sphinx-comments = "^0.0.3" sphinx-copybutton = "^0.5.2" sphinx-jupyterbook-latex = "^1.0.0" sphinx-togglebutton = "^0.3.2" sphinx-toolbox = "^3.5.0" sphinxcontrib-bibtex = "*" sphinxcontrib-spelling = "^8.0.0" sphinxcontrib-mermaid = "^1.0.0" types-requests = "^2.31.0.20240311" types-tabulate = "^0.9.0.20240106" # this is needed for wandb only (undisclosed dependency) typing-extensions = ">=4.10" wandb = ">=0.16.0" [tool.mypy] allow_redefinition = true check_untyped_defs = true disallow_incomplete_defs = true disallow_untyped_defs = true ignore_missing_imports = true no_implicit_optional = true pretty = true show_error_codes = true show_error_context = true show_traceback = true strict_equality = true strict_optional = true warn_no_return = true warn_redundant_casts = true warn_unreachable = true warn_unused_configs = true warn_unused_ignores = true exclude = "^build/|^docs/" [tool.doc8] max-line-length = 1000 [tool.nbqa.exclude] ruff = "\\.jupyter_cache|jupyter_execute" mypy = "\\.jupyter_cache|jupyter_execute" [tool.ruff] target-version = "py311" line-length = 100 [tool.ruff.lint] select = [ "ASYNC", "B", "C4", "C90", "COM", "D", "DTZ", "E", "F", "FLY", "G", "I", "ISC", "PIE", "PLC", "PLE", "PLW", "RET", "RUF", "RSE", "SIM", "TID", "UP", "W", "YTT", ] ignore = [ "RUF003", # custom (greek) letters "SIM118", # Needed b/c iter(batch) != iter(batch.keys()). See https://github.com/thu-ml/tianshou/issues/922 "E501", # line too long. ruff does a good enough job "E741", # variable names like "l". this isn't a huge problem "B008", # do not perform function calls in argument defaults. we do this sometimes "B011", # assert false. 
we don't use python -O "B028", # we don't need explicit stacklevel for warnings "D100", "D101", "D102", "D104", "D105", "D107", "D203", "D213", "D401", "D402", # docstring stuff "DTZ005", # we don't need that # remaining rules from https://github.com/psf/black/blob/main/.flake8 (except W503) # this is a simplified version of config, making vscode plugin happy "E402", "E501", "E701", "E731", "C408", "E203", # Logging statement uses f-string warning "G004", # Unnecessary `elif` after `return` statement "RET505", "D106", # undocumented public nested class "D205", # blank line after summary (prevents summary-only docstrings, which makes no sense) "D212", # no blank line after """. This clashes with sphinx for multiline descriptions of :param: that start directly after """ "PLW2901", # overwrite vars in loop "B027", # empty and non-abstract method in abstract class "D404", # It's fine to start with "This" in docstrings "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx "COM812", # missing trailing comma: With this enabled, re-application of "poe format" chain can cause additional commas and subsequent reformatting "B023", # forbids function using loop variable without explicit binding "RUF059", # unused name after unpacking "RUF005", # concatenation "PLC0415", # local imports "SIM108", # if else is fine instead of ternary "PLW1641", # weird thing requiring __hash__ for Protocol "PLC0206", # extracting value from dictionary without calling `.items()` "SIM103", # forces returning of conditions instead of booleans "E721", # forbids use of equality for type checks "RET504", # sometimes we want to assign before return (e.g. for debugging) ] unfixable = [ "F841", # unused variable. 
ruff keeps the call, but mostly we want to get rid of it all "F601", # automatic fix might obscure issue "F602", # automatic fix might obscure issue "B018", # automatic fix might obscure issue ] extend-fixable = [ "F401", # unused import "B905", # bugbear ] [tool.ruff.lint.mccabe] max-complexity = 20 [tool.ruff.lint.per-file-ignores] "test/**" = ["D103"] "docs/**" = ["D103"] "examples/**" = ["D103"] "__init__.py" = ["F401"] # do not remove "unused" imports (F401) from __init__.py files [tool.poetry-sort] move-optionals-to-bottom = true [tool.poe.env] PYDEVD_DISABLE_FILE_VALIDATION="1" # keep relevant parts in sync with pre-commit [tool.poe.tasks] # https://github.com/nat-n/poethepoet test = "pytest test" test-nocov = "pytest -p no:cov test" test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes" _ruff_fix = "ruff check --fix ." _ruff_fix_check = "ruff check ." _ruff_format = "ruff format ." _ruff_format_check = "ruff format --check ." lint = ["_ruff_format_check", "_ruff_fix_check"] clean-nbs = "python docs/nbstripout.py" format = ["_ruff_fix", "_ruff_format"] _autogen_rst = "python docs/autogen_rst.py" _sphinx_build = "sphinx-build -b html docs docs/_build -W --keep-going" _jb_generate_toc = "python docs/create_toc.py" _jb_generate_config = "jupyter-book config sphinx docs/" doc-clean = "rm -rf docs/_build docs/03_api" doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] doc-build = ["doc-clean", "doc-generate-files", "_sphinx_build"] _mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" type-check = ["_mypy"] ================================================ FILE: test/__init__.py ================================================ ================================================ FILE: test/base/__init__.py ================================================ ================================================ FILE: test/base/env.py ================================================ import random 
import time from copy import deepcopy from typing import Any, Literal import gymnasium as gym import networkx as nx import numpy as np from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple class MoveToRightEnv(gym.Env): """A task for "going right". The task is to go right ``size`` steps. The observation is the current index, and the action is to go left or right. Action 0 is to go left, and action 1 is to go right. Taking action 0 at index 0 will keep the index at 0. Arriving at index ``size`` means the task is done. In the current implementation, stepping after the task is done is possible, which will lead the index to be larger than ``size``. Index 0 is the starting point. If reset is called with default options, the index will be reset to 0. """ def __init__( self, size: int, sleep: float = 0.0, dict_state: bool = False, recurse_state: bool = False, ma_rew: int = 0, multidiscrete_action: bool = False, random_sleep: bool = False, array_state: bool = False, ) -> None: assert dict_state + recurse_state + array_state <= 1, ( "dict_state / recurse_state / array_state can be only one true" ) self.size = size self.sleep = sleep self.random_sleep = random_sleep self.dict_state = dict_state self.recurse_state = recurse_state self.array_state = array_state self.ma_rew = ma_rew self._md_action = multidiscrete_action # how many steps this env has stepped self.steps = 0 if dict_state: self.observation_space = Dict( { "index": Box(shape=(1,), low=0, high=size - 1), "rand": Box(shape=(1,), low=0, high=1, dtype=np.float64), }, ) elif recurse_state: self.observation_space = Dict( { "index": Box(shape=(1,), low=0, high=size - 1), "dict": Dict( { "tuple": Tuple( ( Discrete(2), Box(shape=(2,), low=0, high=1, dtype=np.float64), ), ), "rand": Box(shape=(1, 2), low=0, high=1, dtype=np.float64), }, ), }, ) elif array_state: self.observation_space = Box(shape=(4, 84, 84), low=0, high=255) else: self.observation_space = Box(shape=(1,), low=0, high=size - 1) if 
multidiscrete_action: self.action_space = MultiDiscrete([2, 2]) else: self.action_space = Discrete(2) self.terminated = False self.index = 0 def reset( self, seed: int | None = None, # TODO: passing a dict here doesn't make any sense options: dict[str, Any] | None = None, ) -> tuple[dict[str, Any] | np.ndarray, dict]: """:param seed: :param options: the start index is provided in options["state"] :return: """ if options is None: options = {"state": 0} super().reset(seed=seed) self.terminated = False self.do_sleep() self.index = options["state"] return self._get_state(), {"key": 1, "env": self} def _get_reward(self) -> list[int] | int: """Generate a non-scalar reward if ma_rew is True.""" end_flag = int(self.terminated) if self.ma_rew > 0: return [end_flag] * self.ma_rew return end_flag def _get_state(self) -> dict[str, Any] | np.ndarray: """Generate state(observation) of MyTestEnv.""" if self.dict_state: return { "index": np.array([self.index], dtype=np.float32), "rand": self.np_random.random(1), } if self.recurse_state: return { "index": np.array([self.index], dtype=np.float32), "dict": { "tuple": (np.array([1], dtype=int), self.np_random.random(2)), "rand": self.np_random.random((1, 2)), }, } if self.array_state: img = np.zeros([4, 84, 84], int) img[3, np.arange(84), np.arange(84)] = self.index img[2, np.arange(84)] = self.index img[1, :, np.arange(84)] = self.index img[0] = self.index return img return np.array([self.index], dtype=np.float32) def do_sleep(self) -> None: if self.sleep > 0: sleep_time = random.random() if self.random_sleep else 1 sleep_time *= self.sleep time.sleep(sleep_time) def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. 
issue #1080 self.steps += 1 if self._md_action and isinstance(action, np.ndarray): action = action[0] if self.terminated: raise ValueError("step after done !!!") self.do_sleep() if self.index == self.size: self.terminated = True return self._get_state(), self._get_reward(), self.terminated, False, {} info_dict = {"key": 1, "env": self} if action == 0: self.index = max(self.index - 1, 0) return ( self._get_state(), self._get_reward(), self.terminated, False, info_dict, ) if action == 1: self.index += 1 self.terminated = self.index == self.size return ( self._get_state(), self._get_reward(), self.terminated, False, info_dict, ) raise ValueError(f"Invalid action {action}") class NXEnv(gym.Env): def __init__(self, size: int, obs_type: str, feat_dim: int = 32) -> None: self.size = size self.feat_dim = feat_dim self.graph = nx.Graph() self.graph.add_nodes_from(list(range(size))) assert obs_type in ["array", "object"] self.obs_type = obs_type def _encode_obs(self) -> np.ndarray | nx.Graph: if self.obs_type == "array": return np.stack([v["data"] for v in self.graph._node.values()]) return deepcopy(self.graph) def reset( self, seed: int | None = None, options: dict[str, Any] | None = None, ) -> tuple[np.ndarray | nx.Graph, dict]: super().reset(seed=seed) graph_state = np.random.rand(self.size, self.feat_dim) for i in range(self.size): self.graph.nodes[i]["data"] = graph_state[i] return self._encode_obs(), {} def step( self, action: Space, ) -> tuple[np.ndarray | nx.Graph, float, Literal[False], Literal[False], dict]: next_graph_state = np.random.rand(self.size, self.feat_dim) for i in range(self.size): self.graph.nodes[i]["data"] = next_graph_state[i] return self._encode_obs(), 1.0, False, False, {} class MyGoalEnv(MoveToRightEnv): def __init__(self, *args: Any, **kwargs: Any) -> None: assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, ( "dict_state / recurse_state not supported" ) super().__init__(*args, **kwargs) super().reset(options={"state": 0}) 
# will result in obs=1, I guess, so the goal is to reach the max size by moving right obs, _, _, _, _ = super().step(1) self._goal = obs * self.size super_obsv = self.observation_space self.observation_space = gym.spaces.Dict( { "observation": super_obsv, "achieved_goal": super_obsv, "desired_goal": super_obsv, }, ) def reset(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], dict]: obs, info = super().reset(*args, **kwargs) new_obs = {"observation": obs, "achieved_goal": obs, "desired_goal": self._goal} return new_obs, info def step(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], float, bool, bool, dict]: obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs) new_obs_next = { "observation": obs_next, "achieved_goal": obs_next, "desired_goal": self._goal, } return new_obs_next, rew, terminated, truncated, info def compute_reward_fn( self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: dict, ) -> np.ndarray: axis: tuple[int, ...] = (-3, -2, -1) if self.array_state else (-1,) return (achieved_goal == desired_goal).all(axis=axis) ================================================ FILE: test/base/test_action_space_sampling.py ================================================ import gymnasium as gym from tianshou.env import DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv def test_gym_env_action_space() -> None: env = gym.make("Pendulum-v1") env.action_space.seed(0) action1 = env.action_space.sample() env.action_space.seed(0) action2 = env.action_space.sample() assert action1 == action2 def test_dummy_vec_env_action_space() -> None: num_envs = 4 envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)]) envs.seed(0) action1 = [ac_space.sample() for ac_space in envs.action_space] envs.seed(0) action2 = [ac_space.sample() for ac_space in envs.action_space] assert action1 == action2 def test_subproc_vec_env_action_space() -> None: num_envs = 4 envs = SubprocVectorEnv([lambda: gym.make("Pendulum-v1") for _ 
in range(num_envs)]) envs.seed(0) action1 = [ac_space.sample() for ac_space in envs.action_space] envs.seed(0) action2 = [ac_space.sample() for ac_space in envs.action_space] assert action1 == action2 def test_shmem_vec_env_action_space() -> None: num_envs = 4 envs = ShmemVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)]) envs.seed(0) action1 = [ac_space.sample() for ac_space in envs.action_space] envs.seed(0) action2 = [ac_space.sample() for ac_space in envs.action_space] assert action1 == action2 ================================================ FILE: test/base/test_batch.py ================================================ import copy import pickle from itertools import starmap from typing import Any, cast import networkx as nx import numpy as np import pytest import torch from deepdiff import DeepDiff from torch.distributions import Distribution, Independent, Normal from torch.distributions.categorical import Categorical from tianshou.data import Batch, to_numpy, to_torch from tianshou.data.batch import IndexType, dist_to_atleast_2d, get_sliced_dist def test_batch() -> None: assert list(Batch()) == [] assert len(Batch().get_keys()) == 0 assert len(Batch(b={"c": {}}).get_keys()) != 0 assert len(Batch(b={"c": {}})) == 0 assert len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) != 0 assert len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0 assert len(Batch(d=1).get_keys()) != 0 assert len(Batch(a=np.float64(1.0)).get_keys()) != 0 assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3 assert len(Batch(a=[1, 2, 3]).get_keys()) != 0 b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None]) assert b.c.dtype == object b = Batch(d=[None], e=[starmap], f=Batch) assert b.d.dtype == b.e.dtype == object assert b.f == Batch b = Batch() b.update() assert len(b.get_keys()) == 0 b.update(c=[3, 5]) assert np.allclose(b.c, [3, 5]) # mimic the behavior of dict.update, where kwargs can overwrite keys b.update({"a": 2}, a=3) assert "a" in b assert b.a == 3 assert b.pop("a") == 
3 assert "a" not in b with pytest.raises(AssertionError): Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object with pytest.raises(TypeError): Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch.cat_(batch) assert torch.allclose(batch.a, torch.ones(4, 3)) Batch(a=[]) batch = Batch(obs=[0], np=np.zeros([3, 4])) assert batch.obs == batch["obs"] batch.obs = [1] assert batch.obs == [1] batch.cat_(batch) assert np.allclose(batch.obs, [1, 1]) assert batch.np.shape == (6, 4) assert np.allclose(batch[0].obs, batch[1].obs) batch.obs = np.arange(5) for i, b in enumerate(batch.split(1, shuffle=False)): if i != 5: assert b.obs == batch[i].obs else: with pytest.raises(AttributeError): batch[i].obs # noqa: B018 with pytest.raises(AttributeError): b.obs # noqa: B018 print(batch) batch = Batch(a=np.arange(10)) with pytest.raises(AssertionError): list(batch.split(0)) data = [ (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]), (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]), (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]), (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), ] for size, merge_last, result in data: bs = list(batch.split(size, shuffle=False, merge_last=merge_last)) assert [bs[i].a.tolist() for i in range(len(bs))] == result batch_dict = {"b": np.array([1.0]), "c": 2.0, "d": 
torch.Tensor([3.0])} batch_item = Batch({"a": [batch_dict]})[0] assert isinstance(batch_item.a.b, np.ndarray) assert batch_item.a.b == batch_dict["b"] assert isinstance(batch_item.a.c, float) assert batch_item.a.c == batch_dict["c"] assert isinstance(batch_item.a.d, torch.Tensor) assert batch_item.a.d == batch_dict["d"] batch2 = Batch(a=[{"b": np.float64(1.0), "c": np.zeros(1), "d": Batch(e=np.array(3.0))}]) assert len(batch2) == 1 assert Batch().shape == [] assert Batch(a=1).shape == [] assert Batch(a={1, 2}).shape == [] assert batch2.shape[0] == 1 assert "a" in batch2 assert all(i in batch2.a for i in "bcd") with pytest.raises(IndexError): batch2[-2] with pytest.raises(IndexError): batch2[1] assert batch2[0].shape == [] with pytest.raises(IndexError): batch2[0][0] with pytest.raises(TypeError): len(batch2[0]) assert isinstance(batch2[0].a.c, np.ndarray) assert isinstance(batch2[0].a.b, float) assert isinstance(batch2[0].a.d.e, float) batch2_from_list = Batch(list(batch2)) batch2_from_comp = Batch(list(batch2)) assert batch2_from_list.a.b == batch2.a.b assert batch2_from_list.a.c == batch2.a.c assert batch2_from_list.a.d.e == batch2.a.d.e assert batch2_from_comp.a.b == batch2.a.b assert batch2_from_comp.a.c == batch2.a.c assert batch2_from_comp.a.d.e == batch2.a.d.e for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]: assert batch_slice.a.b == batch2.a.b assert batch_slice.a.c == batch2.a.c assert batch_slice.a.d.e == batch2.a.d.e batch2.a.d.f = {} batch2_sum = (batch2 + 1.0) * 2 # type: ignore # __add__ supports Number as input type assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 assert len(batch2_sum.a.d.f.get_keys()) == 0 with pytest.raises(TypeError): batch2 += [1] # type: ignore # error is raised explicitly batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))}) batch3.a.d[0] = {"e": 4.0} assert batch3.a.d.e[0] == 4.0 
batch3.a.d[0] = Batch(f=5.0) assert batch3.a.d.f[0] == 5.0 with pytest.raises(ValueError): batch3.a.d[0] = Batch(f=5.0, g=0.0) with pytest.raises(ValueError): batch3[0] = Batch(a={"c": 2, "e": 1}) # auto convert batch4 = Batch(a=np.array(["a", "b"])) assert batch4.a.dtype == object # auto convert to object batch4.update(a=np.array(["c", "d"])) assert list(batch4.a) == ["c", "d"] assert batch4.a.dtype == object # auto convert to object batch5 = Batch(a=np.array([{"index": 0}])) assert isinstance(batch5.a, Batch) assert np.allclose(batch5.a.index, [0]) # We use setattr b/c the setattr of Batch will actually change the type of the field that is being set! # However, mypy would not understand this, and rightly expect that batch.b = some_array would lead to # batch.b being an array (which it is not, it's turned into a Batch instead) batch5.b = np.array([{"index": 1}]) batch5.b = cast(Batch, batch5.b) assert isinstance(batch5.b, Batch) assert np.allclose(batch5.b.index, [1]) # None is a valid object and can be stored in Batch a = Batch.stack([Batch(a=None), Batch(b=None)]) assert a.a[0] is None assert a.a[1] is None assert a.b[0] is None assert a.b[1] is None # nx.Graph corner case assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object g1 = nx.Graph() g1.add_nodes_from(list(range(10))) g2 = nx.Graph() g2.add_nodes_from(list(range(20))) assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object def test_batch_over_batch() -> None: batch = Batch(a=[3, 4, 5], b=[4, 5, 6]) batch2 = Batch({"c": [6, 7, 8], "b": batch}) batch2.b.b[-1] = 0 print(batch2) for k, v in batch2.items(): assert np.all(batch2[k] == v) assert batch2[-1].b.b == 0 batch2.cat_(Batch(c=[6, 7, 8], b=batch)) assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) batch2.update(batch2.b, six=[6, 6, 6]) assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.a, [3, 
4, 5, 3, 4, 5]) assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0]) assert np.allclose(batch2.six, [6, 6, 6]) d = {"a": [3, 4, 5], "b": [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) batch3.cat_(Batch(c=[6, 7, 8], b=d)) assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6]) batch4 = Batch(({"a": {"b": np.array([1.0])}},)) assert batch4.a.b.ndim == 2 assert batch4.a.b[0, 0] == 1.0 # advanced slicing batch5 = Batch(a=[[1, 2]], b={"c": np.zeros([3, 2, 1])}) assert batch5.shape == [1, 2] with pytest.raises(IndexError): batch5[2] with pytest.raises(IndexError): batch5[:, 3] with pytest.raises(IndexError): batch5[:, :, -1] batch5[:, -1] += np.int_(1) assert np.allclose(batch5.a, [1, 3]) assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) with pytest.raises(ValueError): batch5[:, -1] = 1 batch5[:, 0] = {"a": -1} assert np.allclose(batch5.a, [-1, 3]) assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) def test_batch_cat_and_stack() -> None: # test cat with compatible keys b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}]) b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}]) b12_cat_out = Batch.cat([b1, b2]) b12_cat_in = copy.deepcopy(b1) b12_cat_in.cat_(b2) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert isinstance(b12_cat_in.a.d.e, np.ndarray) assert b12_cat_in.a.d.e.ndim == 1 a = Batch(a=Batch(a=np.random.randn(3, 4))) a_empty = Batch(a=Batch(a=Batch())) assert np.allclose( np.concatenate([a.a.a, a.a.a]), Batch.cat([a, a_empty, a]).a.a, ) # test cat with lens infer a = Batch(a=Batch(a=np.random.randn(3, 4), t=Batch()), b=np.random.randn(3, 4)) b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4)) ans = Batch.cat([a, b, a]) assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) assert 
len(ans.a.t.get_keys()) == 0 b1.stack_([b2]) assert isinstance(b1.a.d.e, np.ndarray) assert b1.a.d.e.ndim == 2 # test cat with all reserved keys (values are Batch()) b1 = Batch(a=Batch(), b=torch.zeros(3, 3), common=Batch(c=np.random.rand(3, 5))) b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) ans = Batch( a=Batch(), b=torch.cat([torch.zeros(3, 3), b2.b]), common=Batch(c=np.concatenate([b1.common.c, b2.common.c])), ) assert len(ans.a.get_keys()) == 0 assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) # test stack with compatible keys b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]])) b34_stack = Batch.stack((b3, b4), axis=1) assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1)) assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d, strict=True)))) b5_dict = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}]) b5 = Batch(b5_dict) assert b5.a[0] == np.array(False) assert b5.a[1] == np.array(True) assert np.all(b5.b.c == np.stack([e["b"]["c"] for e in b5_dict], axis=0)) assert b5.b.d[0] == b5_dict[0]["b"]["d"] assert b5.b.d[1] == 0.0 # test stack with incompatible keys a = Batch(a=1, b=2, c=3) b = Batch(a=4, b=5, d=6) c = Batch(c=7, b=6, d=9) d = Batch.stack([a, b, c]) assert np.allclose(d.a, [1, 4, 0]) assert np.allclose(d.b, [2, 5, 6]) assert np.allclose(d.c, [3, 0, 7]) assert np.allclose(d.d, [0, 6, 9]) # test stack with empty Batch() assert len(Batch.stack([Batch(), Batch(), Batch()]).get_keys()) == 0 a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch()) b = Batch(a=4, b=5, d=6, e=Batch()) c = Batch(c=7, b=6, d=9, e=Batch()) d = Batch.stack([a, b, c]) assert np.allclose(d.a, [1, 4, 0]) assert np.allclose(d.b, [2, 5, 6]) assert np.allclose(d.c, [3, 0, 7]) assert np.allclose(d.d, [0, 6, 9]) assert len(d.e.get_keys()) == 0 b1 = 
Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5))) b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5))) test = Batch.stack([b1, b2], axis=-1) assert len(test.a.get_keys()) == 0 assert len(test.b.get_keys()) == 0 assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1)) b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5))) b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5))) test = Batch.stack([b1, b2]) ans = Batch( a=np.stack([b1.a, np.zeros((4, 4))]), b=torch.stack([torch.zeros(4, 6), b2.b]), common=Batch(c=np.stack([b1.common.c, b2.common.c])), ) assert np.allclose(test.a, ans.a) assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) # test with illegal input format with pytest.raises(ValueError): Batch.cat([[Batch(a=1)], [Batch(a=1)]]) # type: ignore # cat() tested with invalid inp with pytest.raises(ValueError): Batch.stack([[Batch(a=1)], [Batch(a=1)]]) # type: ignore # stack() tested with invalid inp # exceptions batch_cat: Batch = Batch.cat([]) assert len(batch_cat.get_keys()) == 0 batch_stack: Batch = Batch.stack([]) assert len(batch_stack.get_keys()) == 0 b1 = Batch(e=[4, 5], d=6) b2 = Batch(e=[4, 6]) with pytest.raises(ValueError): Batch.cat([b1, b2]) with pytest.raises(ValueError): Batch.stack([b1, b2], axis=1) def test_utils_to_torch_numpy() -> None: batch = Batch( a=np.float64(1.0), b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), ) a_torch_float = to_torch(batch.a, dtype=torch.float32) assert a_torch_float.dtype == torch.float32 a_torch_double = to_torch(batch.a, dtype=torch.float64) assert a_torch_double.dtype == torch.float64 batch_torch_float = to_torch(batch, dtype=torch.float32) assert batch_torch_float.a.dtype == torch.float64 assert batch_torch_float.b.c.dtype == torch.float32 assert batch_torch_float.b.d.dtype == torch.float32 data_list = [float("nan"), 1] data_list_torch = to_torch(data_list) assert 
data_list_torch.dtype == torch.float64 data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)] data_list_2_torch = to_torch(data_list_2) assert data_list_2_torch.shape == (2, 3, 3) assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2) data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))] data_list_3_torch = [torch.zeros((3, 2)), torch.zeros((3, 3))] with pytest.raises(TypeError): to_torch(data_list_3) with pytest.raises(TypeError): to_numpy(data_list_3_torch) data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))] data_list_4_torch = [torch.zeros((2, 3)), torch.zeros((3, 3))] with pytest.raises(TypeError): to_torch(data_list_4) with pytest.raises(TypeError): to_numpy(data_list_4_torch) data_list_5 = [np.zeros(2), np.zeros((3, 3))] data_list_5_torch = [torch.zeros(2), torch.zeros((3, 3))] with pytest.raises(TypeError): to_torch(data_list_5) with pytest.raises(TypeError): to_numpy(data_list_5_torch) data_array = np.random.rand(3, 2, 2) data_empty_tensor = to_torch(data_array[[]]) assert isinstance(data_empty_tensor, torch.Tensor) assert data_empty_tensor.shape == (0, 2, 2) data_empty_array = to_numpy(data_empty_tensor) assert isinstance(data_empty_array, np.ndarray) assert data_empty_array.shape == (0, 2, 2) assert np.allclose(to_numpy(to_torch(data_array)), data_array) # additional test for to_numpy, for code-coverage assert isinstance(to_numpy(1), np.ndarray) assert isinstance(to_numpy(1.0), np.ndarray) assert isinstance(to_numpy({"a": torch.tensor(1)})["a"], np.ndarray) assert isinstance(to_numpy(Batch(a=torch.tensor(1))).a, np.ndarray) assert to_numpy(None).item() is None assert to_numpy(to_numpy).item() == to_numpy # additional test for to_torch, for code-coverage assert isinstance(to_torch(1), torch.Tensor) assert to_torch(1).dtype in (torch.int64, torch.int32) assert to_torch(1.0).dtype == torch.float64 assert isinstance(to_torch({"a": [1]})["a"], torch.Tensor) with pytest.raises(TypeError): to_torch(None) with pytest.raises(TypeError): 
to_torch(np.array([{}, "2"])) def test_batch_pickle() -> None: batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4])) batch_pk = pickle.loads(pickle.dumps(batch)) assert batch.obs.a == batch_pk.obs.a assert torch.all(batch.obs.c == batch_pk.obs.c) assert np.all(batch.np == batch_pk.np) def test_batch_copy() -> None: batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6])) batch2 = Batch({"c": np.array([6, 7, 8]), "b": batch}) orig_c_addr = batch2.c.__array_interface__["data"][0] orig_b_a_addr = batch2.b.a.__array_interface__["data"][0] orig_b_b_addr = batch2.b.b.__array_interface__["data"][0] # test with copy=False batch3 = Batch(copy=False, **batch2) curr_c_addr = batch3.c.__array_interface__["data"][0] curr_b_a_addr = batch3.b.a.__array_interface__["data"][0] curr_b_b_addr = batch3.b.b.__array_interface__["data"][0] assert batch2.c is batch3.c assert batch2.b is batch3.b assert batch2.b.a is batch3.b.a assert batch2.b.b is batch3.b.b assert orig_c_addr == curr_c_addr assert orig_b_a_addr == curr_b_a_addr assert orig_b_b_addr == curr_b_b_addr # test with copy=True batch3 = Batch(copy=True, **batch2) curr_c_addr = batch3.c.__array_interface__["data"][0] curr_b_a_addr = batch3.b.a.__array_interface__["data"][0] curr_b_b_addr = batch3.b.b.__array_interface__["data"][0] assert batch2.c is not batch3.c assert batch2.b is not batch3.b assert batch2.b.a is not batch3.b.a assert batch2.b.b is not batch3.b.b assert orig_c_addr != curr_c_addr assert orig_b_a_addr != curr_b_a_addr assert orig_b_b_addr != curr_b_b_addr def test_batch_empty() -> None: b5_dict = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}]) b5 = Batch(b5_dict) b5[1] = Batch.empty(b5[0]) assert np.allclose(b5.a, [False, False]) assert np.allclose(b5.b.c, [2, 0]) assert np.allclose(b5.b.d, [1, 0]) data = Batch( a=[False, True], b={ "c": np.array([2.0, "st"], dtype=object), "d": [1, None], "e": [2.0, float("nan")], }, c=np.array([1, 3, 4], 
dtype=int), t=torch.tensor([4, 5, 6, 7.0]), ) data[-1] = Batch.empty(data[1]) assert np.allclose(data.c, [1, 3, 0]) assert np.allclose(data.a, [False, False]) assert list(data.b.c) == [2.0, None] assert list(data.b.d) == [1, None] assert np.allclose(data.b.e, [2, 0]) assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.0])) data[0].empty_() # which will fail in a, b.c, b.d, b.e, c assert torch.allclose(data.t, torch.tensor([0.0, 5, 6, 0])) data.empty_(index=0) assert np.allclose(data.c, [0, 3, 0]) assert list(data.b.c) == [None, None] assert list(data.b.d) == [None, None] assert list(data.b.e) == [0, 0] b0 = Batch() b0.empty_() assert b0.shape == [] def test_batch_standard_compatibility() -> None: batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0])) batch_mean = np.mean(batch) assert isinstance(batch_mean, Batch) # type: ignore # mypy doesn't know but it works, cf. `batch.rst` assert sorted(batch_mean.get_keys()) == ["a", "b", "c"] # type: ignore with pytest.raises(TypeError): len(batch_mean) assert np.all(batch_mean.a == np.mean(batch.a, axis=0)) assert batch_mean.c == np.mean(batch.c, axis=0) with pytest.raises(IndexError): Batch()[0] class TestBatchEquality: @staticmethod def test_keys_different() -> None: batch1 = Batch(a=[1, 2], b=[100, 50]) batch2 = Batch(b=[1, 2], c=[100, 50]) assert batch1 != batch2 @staticmethod def test_keys_missing() -> None: batch1 = Batch(a=[1, 2], b=[2, 3, 4]) batch2 = Batch(a=[1, 2], b=[2, 3, 4]) batch2.pop("b") assert batch1 != batch2 @staticmethod def test_types_keys_different() -> None: batch1 = Batch(a=[1, 2, 3], b=[4, 5]) batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5])) assert batch1 != batch2 @staticmethod def test_array_types_different() -> None: batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5])) batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5])) assert batch1 != batch2 @staticmethod def test_nested_values_different() -> None: batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) batch2 = 
Batch(a=Batch(a=[1, 2, 4]), b=[4, 5]) assert batch1 != batch2 @staticmethod def test_nested_shapes_different() -> None: batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5]) assert batch1 != batch2 @staticmethod def test_array_scalars() -> None: batch1 = Batch(a={"b": 1}) batch2 = Batch(a={"b": 1}) assert batch1 == batch2 batch3 = Batch(a={"c": 2}) assert batch1 != batch3 batch4 = Batch(b={"b": 1}) assert batch1 != batch4 batch5 = Batch(a={"b": 10}) assert batch1 != batch5 batch6 = Batch(a={"b": [1]}) assert batch1 == batch6 batch7 = Batch(a=1, b=5) batch8 = Batch(a=1, b=5) assert batch7 == batch8 @staticmethod def test_slice_equal() -> None: batch1 = Batch(a=[1, 2, 3]) assert batch1[:2] == batch1[:2] @staticmethod def test_slice_ellipsis_equal() -> None: batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000]) assert batch1[..., 1:] == batch1[..., 1:] @staticmethod def test_empty_batches() -> None: assert Batch() == Batch() @staticmethod def test_different_order_keys() -> None: assert Batch(a=1, b=2) == Batch(b=2, a=1) @staticmethod def test_tuple_and_list_types() -> None: assert Batch(a=(1, 2)) == Batch(a=[1, 2]) @staticmethod def test_subbatch_dict_and_batch_types() -> None: assert Batch(a={"x": 1}) == Batch(a=Batch(x=1)) class TestBatchToDict: @staticmethod def test_to_dict_empty_batch_no_recurse() -> None: batch = Batch() expected: dict[Any, Any] = {} assert batch.to_dict() == expected @staticmethod def test_to_dict_with_simple_values_recurse() -> None: batch = Batch(a=1, b="two", c=np.array([3, 4])) expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])} assert not DeepDiff(batch.to_dict(recursive=True), expected) @staticmethod def test_to_dict_simple() -> None: batch = Batch(a=1, b="two") expected = {"a": np.asanyarray(1), "b": "two"} assert batch.to_dict() == expected @staticmethod def test_to_dict_nested_batch_no_recurse() -> None: nested_batch = Batch(c=3) batch = Batch(a=1, b=nested_batch) 
expected = {"a": np.asanyarray(1), "b": nested_batch} assert not DeepDiff(batch.to_dict(recursive=False), expected) @staticmethod def test_to_dict_nested_batch_recurse() -> None: nested_batch = Batch(c=3) batch = Batch(a=1, b=nested_batch) expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}} assert not DeepDiff(batch.to_dict(recursive=True), expected) @staticmethod def test_to_dict_multiple_nested_batch_recurse() -> None: nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300]) batch = Batch(a=1, b=nested_batch) expected = { "a": np.asanyarray(1), "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])}, } assert not DeepDiff(batch.to_dict(recursive=True), expected) @staticmethod def test_to_dict_array() -> None: batch = Batch(a=np.array([1, 2, 3])) expected = {"a": np.array([1, 2, 3])} assert not DeepDiff(batch.to_dict(), expected) @staticmethod def test_to_dict_nested_batch_with_array() -> None: nested_batch = Batch(c=np.array([4, 5])) batch = Batch(a=1, b=nested_batch) expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}} assert not DeepDiff(batch.to_dict(recursive=True), expected) @staticmethod def test_to_dict_torch_tensor() -> None: t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy() batch = Batch(a=t1) t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy() expected = {"a": t2} assert not DeepDiff(batch.to_dict(), expected) @staticmethod def test_to_dict_nested_batch_with_torch_tensor() -> None: nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy()) batch = Batch(a=1, b=nested_batch) expected = { "a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}, } assert not DeepDiff(batch.to_dict(recursive=True), expected) class TestBatchConversions: @staticmethod def test_to_numpy() -> None: batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) new_batch = batch.to_numpy() assert id(batch) != id(new_batch) assert isinstance(batch.b, torch.Tensor) assert isinstance(batch.c.d, 
torch.Tensor)
        assert isinstance(new_batch.b, np.ndarray)
        assert isinstance(new_batch.c.d, np.ndarray)

    @staticmethod
    def test_to_numpy_() -> None:
        """The trailing-underscore variant converts in place: same object, leaves become ndarrays."""
        batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        id_batch = id(batch)
        batch.to_numpy_()
        assert id_batch == id(batch)
        assert isinstance(batch.b, np.ndarray)
        assert isinstance(batch.c.d, np.ndarray)

    @staticmethod
    def test_to_torch() -> None:
        """``to_torch()`` returns a new Batch of tensors and leaves the original's arrays untouched."""
        batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
        new_batch = batch.to_torch()
        assert id(batch) != id(new_batch)
        assert isinstance(batch.b, np.ndarray)
        assert isinstance(batch.c.d, np.ndarray)
        assert isinstance(new_batch.b, torch.Tensor)
        assert isinstance(new_batch.c.d, torch.Tensor)

    @staticmethod
    def test_to_torch_() -> None:
        """In-place variant of ``to_torch``: same object, leaves become tensors."""
        batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
        id_batch = id(batch)
        batch.to_torch_()
        assert id_batch == id(batch)
        assert isinstance(batch.b, torch.Tensor)
        assert isinstance(batch.c.d, torch.Tensor)

    @staticmethod
    def test_apply_array_func() -> None:
        """``apply_values_transform`` maps a function over every leaf, including nested ones."""
        batch = Batch(a=1, b=np.arange(3), c={"d": np.array([1, 2, 3])})
        batch_with_max = batch.apply_values_transform(np.max)
        assert np.array_equal(batch_with_max.a, np.array(1))
        assert np.array_equal(batch_with_max.b, np.array(2))
        assert np.array_equal(batch_with_max.c.d, np.array(3))
        # the transform may change shapes: the scalar `a` broadcasts against a length-3 array
        batch_array_added = batch.apply_values_transform(lambda x: x + np.array([1, 2, 3]))
        assert np.array_equal(batch_array_added.a, np.array([2, 3, 4]))
        assert np.array_equal(batch_array_added.c.d, np.array([2, 4, 6]))

    @staticmethod
    def test_batch_to_numpy_without_copy() -> None:
        """Converting an all-numpy Batch in place must not reallocate its arrays."""
        batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
        # __array_interface__["data"][0] is the address of each array's data buffer
        a_mem_addr_orig = batch.a.__array_interface__["data"][0]
        c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
        batch.to_numpy_()
        a_mem_addr_new = batch.a.__array_interface__["data"][0]
        c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
        assert a_mem_addr_new == a_mem_addr_orig
        assert c_mem_addr_new == c_mem_addr_orig

    @staticmethod
    def test_batch_from_to_numpy_without_copy() -> None:
        """A numpy -> torch -> numpy round trip is expected to keep the original data buffers."""
        batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
        a_mem_addr_orig = batch.a.__array_interface__["data"][0]
        c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
        batch.to_torch_()
        batch.to_numpy_()
        a_mem_addr_new = batch.a.__array_interface__["data"][0]
        c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
        assert a_mem_addr_new == a_mem_addr_orig
        assert c_mem_addr_new == c_mem_addr_orig

    @staticmethod
    def test_batch_over_batch_to_torch() -> None:
        """``to_torch_`` keeps leaf dtypes by default and casts every leaf when ``dtype`` is given."""
        batch = Batch(
            a=np.float64(1.0),
            b=Batch(
                c=np.ones((1,), dtype=np.float32),
                d=torch.ones((1,), dtype=torch.float64),
            ),
        )
        batch.b.set_array_at_key(np.array([1]), "e")
        batch.to_torch_()
        assert isinstance(batch.a, torch.Tensor)
        assert isinstance(batch.b.c, torch.Tensor)
        assert isinstance(batch.b.d, torch.Tensor)
        assert isinstance(batch.b.e, torch.Tensor)
        assert batch.a.dtype == torch.float64
        assert batch.b.c.dtype == torch.float32
        assert batch.b.d.dtype == torch.float64
        # presumably two values are accepted because numpy's default integer dtype is platform dependent
        assert batch.b.e.dtype in (torch.int64, torch.int32)
        batch.to_torch_(dtype=torch.float32)
        assert batch.a.dtype == torch.float32
        assert batch.b.c.dtype == torch.float32
        assert batch.b.d.dtype == torch.float32
        assert batch.b.e.dtype == torch.float32

    @staticmethod
    @pytest.mark.parametrize(
        "dist, expected_batch_shape",
        [
            (Categorical(probs=torch.tensor([0.3, 0.7])), (1,)),
            (Categorical(probs=torch.tensor([[0.3, 0.7], [0.4, 0.6]])), (2,)),
            (Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), (1,)),
            (
                Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 2.0])),
                (2,),
            ),
            (
                Independent(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), 0),
                (1,),
            ),
            (
                Independent(
                    Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 2.0])),
                    0,
                ),
                (2,),
            ),
        ],
    )
    def test_dist_to_atleast_2d(dist: Distribution, expected_batch_shape: tuple[int]) -> None:
        """Scalar-batch distributions gain a batch dim of size 1; 1-d batch shapes are kept."""
        result = dist_to_atleast_2d(dist)
        assert result.batch_shape == expected_batch_shape

        # Additionally check that the parameters are correctly transformed
        if isinstance(dist, Categorical):
            assert isinstance(result, Categorical)
            assert result.probs.shape[:-1] == expected_batch_shape
        elif isinstance(dist, Normal):
            assert isinstance(result, Normal)
            assert result.loc.shape == expected_batch_shape
            assert result.scale.shape == expected_batch_shape
        elif isinstance(dist, Independent):
            assert isinstance(result, Independent)
            assert result.base_dist.batch_shape == expected_batch_shape

    @staticmethod
    @pytest.mark.parametrize(
        "dist",
        [
            Categorical(probs=torch.tensor([0.3, 0.7])),
            Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)),
            Independent(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), 0),
        ],
    )
    def test_dist_to_atleast_2d_idempotent(dist: Distribution) -> None:
        # NOTE(review): torch Distribution objects do not appear to define `__eq__`, so `==`
        # here is presumably an identity check; this passes only if the second call returns
        # its input object unchanged — confirm that is the intended contract.
        result1 = dist_to_atleast_2d(dist)
        result2 = dist_to_atleast_2d(result1)
        assert result1 == result2

    @staticmethod
    def test_batch_to_atleast_2d() -> None:
        """Plain scalar leaves become (1, 1) while a distribution leaf gains a single batch dim."""
        scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3)))
        assert scalar_batch.dist.batch_shape == ()
        assert scalar_batch.a.shape == scalar_batch.b.shape == ()
        scalar_batch_2d = scalar_batch.to_at_least_2d()
        assert scalar_batch_2d.dist.batch_shape == (1,)
        assert scalar_batch_2d.a.shape == scalar_batch_2d.b.shape == (1, 1)


class TestAssignment:
    """Tests for ``Batch.set_array_at_key`` and the null-handling helpers."""

    @staticmethod
    def test_assign_full_length_array() -> None:
        """Full-length assignment overwrites or creates a key; a length mismatch raises ValueError."""
        batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])})
        batch.set_array_at_key(np.array([1, 2, 3]), "a")
        batch.set_array_at_key(np.array([4, 5, 6]), "new_key")
        assert np.array_equal(batch.a, np.array([1, 2, 3]))
        assert np.array_equal(batch.new_key, np.array([4, 5, 6]))
        # other keys are not affected
        assert np.array_equal(batch.b, np.array([7, 8, 9]))
        assert np.array_equal(batch.c.d, np.array([1, 2, 3]))
        with pytest.raises(ValueError):
            # wrong length
            batch.set_array_at_key(np.array([1, 2]), "a")

    @staticmethod
    def test_assign_subarray_existing_key() -> None:
        """``index`` may be a list (in any order) or a slice; values land at exactly those positions."""
        batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])})
        batch.set_array_at_key(np.array([1, 2]), "a",
index=[0, 1])
        assert np.array_equal(batch.a, np.array([1, 2, 6]))
        batch.set_array_at_key(np.array([10, 12]), "a", index=slice(0, 2))
        assert np.array_equal(batch.a, np.array([10, 12, 6]))
        batch.set_array_at_key(np.array([1, 2]), "a", index=[0, 2])
        assert np.array_equal(batch.a, np.array([1, 12, 2]))
        # index order matters: values are written pairwise to the listed positions
        batch.set_array_at_key(np.array([1, 2]), "a", index=[2, 0])
        assert np.array_equal(batch.a, np.array([2, 12, 1]))
        batch.set_array_at_key(np.array([1, 2, 3]), "a", index=[2, 1, 0])
        assert np.array_equal(batch.a, np.array([3, 2, 1]))
        with pytest.raises(IndexError):
            # Index out of bounds
            batch.set_array_at_key(np.array([1, 2]), "a", index=[10, 11])
        # other keys are not affected
        assert np.array_equal(batch.b, np.array([7, 8, 9]))
        assert np.array_equal(batch.c.d, np.array([1, 2, 3]))

    @staticmethod
    def test_assign_subarray_new_key() -> None:
        """Positions of a new key not covered by ``index`` are filled with ``default_value``."""
        batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])})
        batch.set_array_at_key(np.array([1, 2]), "new_key", index=[0, 1], default_value=0)
        assert np.array_equal(batch.new_key, np.array([1, 2, 0]))
        # with float, None can be cast to NaN
        batch.set_array_at_key(np.array([1.0, 2.0]), "new_key2", index=[0, 1])
        assert np.array_equal(batch.new_key2, np.array([1.0, 2.0, np.nan]), equal_nan=True)

    @staticmethod
    def test_isnull() -> None:
        """``isnull()`` mirrors the Batch structure with boolean arrays marking None entries."""
        batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([1, None, 3])})
        batch_isnan = batch.isnull()
        assert not batch_isnan.a.any()
        assert batch_isnan.b[2]
        assert not batch_isnan.b[:2].any()
        assert np.array_equal(batch_isnan.c.d, np.array([False, True, False]))

    @staticmethod
    def test_hasnull() -> None:
        """``hasnull()`` is True if any leaf, nested or not, contains a None."""
        batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([1, 2, 3])})
        assert batch.hasnull()
        batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])})
        assert not batch.hasnull()
        batch = Batch(a=[4, 5, 6], c={"d": np.array([1, None, 3])})
        assert batch.hasnull()

    @staticmethod
    def test_dropnull() -> None:
        """``dropnull()`` removes every row where any leaf holds None."""
        batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([None, 2.1, 3.0])})
        assert batch.dropnull() == Batch(
            a=[5],
            b=[8],
            c={"d": np.array([2.1])},
        ).apply_values_transform(
            np.atleast_1d,
        )
        batch2 = Batch(a=[4, 5, 6, 7], b=[7, 8, None, 10], c={"d": np.array([None, 2, 3, 4])})
        assert batch2.dropnull() == Batch(a=[5, 7], b=[8, 10], c={"d": np.array([2, 4])})
        batch_no_nan = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])})
        assert batch_no_nan.dropnull() == batch_no_nan


class TestSlicing:
    """Indexing and slicing behaviour of Batch, including distribution-valued leaves."""

    # TODO: parametrize with other dists
    @staticmethod
    def test_slice_distribution() -> None:
        """Indexing a Batch also slices a distribution leaf along its batch dimension."""
        cat_probs = torch.randint(1, 10, (10, 3))
        dist = Categorical(probs=cat_probs)
        batch = Batch(dist=dist)
        selected_idx = [1, 3]
        sliced_batch = batch[selected_idx]
        sliced_probs = cat_probs[selected_idx]
        assert torch.allclose(sliced_batch.dist.probs, Categorical(probs=sliced_probs).probs)
        assert torch.allclose(
            Categorical(probs=sliced_probs).probs,
            get_sliced_dist(dist, selected_idx).probs,
        )
        # retrieving a single index
        assert torch.allclose(batch[0].dist.probs, dist.probs[0])

    @staticmethod
    def test_getitem_with_int_gives_scalars() -> None:
        batch = Batch(a=[1, 2], b=Batch(c=[3, 4]))
        batch_sliced = batch[0]
        assert batch_sliced.a == np.array(1)
        assert batch_sliced.b.c == np.array(3)

    @staticmethod
    @pytest.mark.parametrize("index", ([0, 1], np.array([0, 1]), torch.tensor([0, 1]), slice(0, 2)))
    def test_getitem_with_slice_gives_subslice(index: IndexType) -> None:
        """Each index type selects the same elements as indexing the leaf arrays directly."""
        batch = Batch(a=[1, 2, 3], b=Batch(c=torch.tensor([4, 5, 6])))
        batch_sliced = batch[index]
        assert (batch_sliced.a == batch.a[index]).all()
        assert (batch_sliced.b.c == batch.b.c[index]).all()

    @staticmethod
    def test_len_batch_with_dist() -> None:
        """A None-valued key survives slicing, and ``len`` follows the sliced length."""
        batch_with_dist = Batch(a=[1, 2, 3], dist=Categorical(torch.ones((3, 3))), b=None)
        batch_with_dist_sliced = batch_with_dist[:2]
        assert batch_with_dist_sliced.b is None
        assert len(batch_with_dist_sliced) == 2
        assert np.array_equal(batch_with_dist_sliced.a, np.array([1, 2]))
        assert torch.allclose(
            batch_with_dist_sliced.dist.probs,
            Categorical(torch.ones(2, 3)).probs,
        )
        with pytest.raises(TypeError):
            # scalar batches have no len
            len(batch_with_dist[0])

================================================
FILE: test/base/test_buffer.py
================================================
import os
import pickle
import tempfile
from typing import cast

import h5py
import numpy as np
import numpy.typing as npt
import pytest
import torch

from test.base.env import MoveToRightEnv, MyGoalEnv
from tianshou.data import (
    Batch,
    CachedReplayBuffer,
    HERReplayBuffer,
    HERVectorReplayBuffer,
    PrioritizedReplayBuffer,
    PrioritizedVectorReplayBuffer,
    ReplayBuffer,
    SegmentTree,
    VectorReplayBuffer,
)
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.utils.converter import to_hdf5


def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None:
    """Fill a ReplayBuffer from MoveToRightEnv and check storage, sampling and episode bookkeeping."""
    env = MoveToRightEnv(size)
    buf = ReplayBuffer(bufsize)
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + "()"
    obs, _ = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, act in enumerate(action_list):
        obs_next, rew, terminated, truncated, info = env.step(act)
        buf.add(
            Batch(
                obs=obs,
                act=[act],
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                obs_next=obs_next,
                info=info,
            ),
        )
        obs = obs_next
        # 25 steps into a buffer of 20: the length saturates at bufsize
        assert len(buf) == min(bufsize, i + 1)
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    data, indices = buf.sample(bufsize * 2)
    assert isinstance(data, Batch)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (data.done >= 0).all()
    assert (data.done <= 1).all()
    assert (data.terminated >= 0).all()
    assert (data.terminated <= 1).all()
    assert (data.truncated >= 0).all()
    assert (data.truncated <= 1).all()
    b = ReplayBuffer(size=10)
    # neg bsz should return empty index
    assert b.sample_indices(-1).tolist() == []
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(
            obs=1,
            act=1,
            rew=1,
            terminated=1,
            truncated=0,
            obs_next="str",
            info={"a": 3, "b": {"c": 5.0}},
        ),
    )
    # the first add allocates storage for every key; unused slots hold zero / None
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.terminated[0]
    assert not b.truncated[0]
    assert b.obs_next[0] == "str"
0]),
                        "mask2": np.array([i + 4, 0, 1, 0, 0]),
                        "mask": i,
                    },
                    act={"act_id": i, "position_id": i + 3},
                    rew=i,
                    terminated=i % 3 == 0,
                    truncated=False,
                    info={"if": i},
                ),
            ),
        )
    indices = np.arange(len(buf))
    orig = np.arange(len(buf))
    data = buf[indices]
    data2 = buf[indices]
    assert isinstance(data, Batch)
    assert isinstance(data2, Batch)
    # indexing the buffer must not mutate the index array passed in
    assert np.allclose(indices, orig)
    assert hasattr(data.obs_next, "mask") and hasattr(
        data2.obs_next,
        "mask",
    ), "Both `data.obs_next` and `data2.obs_next` must have attribute `mask`."
    assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
    # per the expected values, obs_next comes from the following index, except at the
    # terminated steps (i % 3 == 0) where it repeats the current obs
    assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
    buf.stack_num = 4
    data = buf[indices]
    data2 = buf[indices]
    assert hasattr(data.obs_next, "mask") and hasattr(
        data2.obs_next,
        "mask",
    ), "Both `data.obs_next` and `data2.obs_next` must have attribute `mask`."
    assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
    assert np.allclose(
        data.obs_next.mask,
        np.array(
            [
                [0, 0, 0, 0],
                [1, 1, 1, 2],
                [1, 1, 2, 3],
                [1, 1, 2, 3],
                [4, 4, 4, 5],
                [4, 4, 5, 6],
                [4, 4, 5, 6],
                [7, 7, 7, 8],
                [7, 7, 8, 9],
                [7, 7, 8, 9],
            ],
        ),
    )
    assert np.allclose(data["info"]["if"], data2["info"]["if"])
    assert np.allclose(
        data["info"]["if"],
        np.array(
            [
                [0, 0, 0, 0],
                [1, 1, 1, 1],
                [1, 1, 1, 2],
                [1, 1, 2, 3],
                [4, 4, 4, 4],
                [4, 4, 4, 5],
                [4, 4, 5, 6],
                [7, 7, 7, 7],
                [7, 7, 7, 8],
                [7, 7, 8, 9],
            ],
        ),
    )
    assert data.obs_next


def test_stack(size: int = 5, bufsize: int = 9, stack_num: int = 4, cached_num: int = 3) -> None:
    """Check frame stacking for plain, ``sample_avail`` and ``save_only_last_obs`` buffers."""
    # NOTE(review): `cached_num` is not used anywhere in this test — confirm before removing it.
    env = MoveToRightEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
    obs, info = env.reset(options={"state": 1})
    for _ in range(16):
        obs_next, rew, terminated, truncated, info = env.step(1)
        done = terminated or truncated
        buf.add(
            cast(
                RolloutBatchProtocol,
                Batch(
                    obs=obs,
                    act=1,
                    rew=rew,
                    terminated=terminated,
                    truncated=truncated,
                    info=info,
                ),
            ),
        )
        buf2.add(
            cast(
                RolloutBatchProtocol,
                Batch(
                    obs=obs,
                    act=1,
                    rew=rew,
                    terminated=terminated,
                    truncated=truncated,
                    info=info,
                ),
            ),
        )
        # buf3 receives lists of observations; the assertions below expect its stacked
        # obs and obs_next to match buf's stacked obs
        buf3.add(
            cast(
                RolloutBatchProtocol,
                Batch(
                    obs=[obs, obs, obs],
                    act=1,
                    rew=rew,
                    terminated=terminated,
                    truncated=truncated,
                    obs_next=[obs, obs],
                    info=info,
                ),
            ),
        )
        obs = obs_next
        if done:
            obs, info = env.reset(options={"state": 1})
    indices = np.arange(len(buf))
    assert np.allclose(
        buf.get(indices, "obs")[..., 0],
        [
            [1, 1, 1, 2],
            [1, 1, 2, 3],
            [1, 2, 3, 4],
            [1, 1, 1, 1],
            [1, 1, 1, 2],
            [1, 1, 2, 3],
            [1, 2, 3, 4],
            [4, 4, 4, 4],
            [1, 1, 1, 1],
        ],
    )
    assert np.allclose(buf.get(indices, "obs"), buf3.get(indices, "obs"))
    assert np.allclose(buf.get(indices, "obs"), buf3.get(indices, "obs_next"))
    # with sample_avail=True only the indices listed below are returned by sample(0)
    _, indices = buf2.sample(0)
    assert indices.tolist() == [2, 6]
    _, indices = buf2.sample(1)
    assert indices[0] in [2, 6]
    batch, indices = buf2.sample(-1)  # neg bsz -> no data
    assert indices.tolist() == []
    assert len(batch) == 0
    with pytest.raises(IndexError):
        buf[bufsize * 2]


def test_prioritized_replaybuffer(size: int = 32, bufsize: int = 15) -> None:
    """Check sampling sizes and weight updates of single and vectorised prioritized buffers."""
    env = MoveToRightEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
    obs, _ = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, act in enumerate(action_list):
        obs_next, rew, terminated, truncated, info = env.step(act)
        batch = cast(
            RolloutBatchProtocol,
            Batch(
                obs=obs,
                act=act,
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                obs_next=obs_next,
                info=info,
                policy=np.random.randn() - 0.5,
            ),
        )
        batch_stack: RolloutBatchProtocol = Batch.stack([batch, batch, batch])
        buf.add(Batch.stack([batch]), buffer_ids=[0])
        buf2.add(batch_stack, buffer_ids=[0, 1, 2])
        obs = obs_next
        data, indices = buf.sample(len(buf) // 2)
        # a requested size of 0 returns the whole buffer
        if len(buf) // 2 == 0:
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        assert len(buf) == min(bufsize, i + 1)
        assert len(buf2) == min(bufsize, 3 * (i + 1))
    # check single buffer's data
    assert buf.info.key.shape == (buf.maxsize,)
    assert buf.rew.dtype == float
    assert buf.done.dtype == bool
    assert buf.terminated.dtype == bool
    assert buf.truncated.dtype == bool
    data, indices = buf.sample(len(buf) // 2)
    buf.update_weight(indices, -data.weight / 2)
    # the stored weight is |new weight| ** alpha
    assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2) ** buf._alpha)
    # check multi buffer's data
    assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1)
    batch_sample, indices = buf2.sample(10)
    buf2.update_weight(indices, batch_sample.weight * 0)
    weight = buf2[np.arange(buf2.maxsize)].weight
    assert isinstance(weight, np.ndarray)
    mask = np.isin(np.arange(buf2.maxsize), indices)
    selected_weight = weight[mask]
    unselected_weight = weight[~mask]
    # NOTE(review): the zero-updated entries are expected to read back with a *larger*
    # weight than untouched ones; presumably `.weight` returned by indexing is an
    # importance-sampling weight rather than the stored priority — confirm.
    assert np.all(selected_weight == selected_weight[0])
    assert np.all(unselected_weight == unselected_weight[0])
    assert unselected_weight[0] < selected_weight[0]
    assert selected_weight[0] <= 1


def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4) -> None:
    """Check HER goal relabelling, cache restoration and handling of wrapped-around episodes."""
    env_size = size
    env = MyGoalEnv(env_size, array_state=True)

    def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray:
        # `env` is looked up when this is called, not when it is defined
        return env.compute_reward_fn(ag, g, {})

    buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8)
    buf2 = HERVectorReplayBuffer(
        bufsize,
        buffer_num=3,
        compute_reward_fn=compute_reward_fn,
        horizon=30,
        future_k=8,
    )
    # Apply her on every episodes sampled (Hacky but necessary for deterministic test)
    buf.future_p = 1
    for buf2_buf in buf2.buffers:
        buf2_buf.future_p = 1

    obs, _ = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, act in enumerate(action_list):
        obs_next, rew, terminated, truncated, info = env.step(act)
        batch = cast(
            RolloutBatchProtocol,
            Batch(
                obs=obs,
                act=[act],
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                obs_next=obs_next,
                info=info,
            ),
        )
        buf.add(batch)
        buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2])
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
        assert len(buf2) == min(bufsize, 3 * (i + 1))

    batch_sample, indices = buf.sample(sample_sz)

    # Check that goals are the same for the episode (only 1 ep in buffer)
    tmp_indices = indices.copy()
    for _ in range(2 * env_size):
        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
        rew_in_buf = buf[tmp_indices].rew
        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        assert np.all(g == g[0])
        assert np.all(g_next == g_next[0])
        # rewards must be recomputed for the relabelled goal
        assert np.all(rew_in_buf == (ag_next == g).astype(np.float32))
        tmp_indices = buf.next(tmp_indices)

    # Check that goals are correctly restored
    buf._restore_cache()
    tmp_indices = indices.copy()
    for _ in range(2 * env_size):
        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        assert np.all(g == env_size)
        assert np.all(g_next == g_next[0])
        assert np.all(g == g[0])
        tmp_indices = buf.next(tmp_indices)

    # Test vector buffer
    batch_sample, indices = buf2.sample(sample_sz)

    # Check that goals are the same for the episode (only 1 ep in buffer)
    tmp_indices = indices.copy()
    for _ in range(2 * env_size):
        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
        rew_buf = buf2[tmp_indices].rew
        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        assert np.all(g == g_next)
        assert np.all(rew_buf == (ag_next == g).astype(np.float32))
        tmp_indices = buf2.next(tmp_indices)

    # Check that goals are correctly restored
    buf2._restore_cache()
    tmp_indices = indices.copy()
    for _ in range(2 * env_size):
        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
        obs_next_buf = cast(Batch,
buf2[tmp_indices].obs_next)
        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        assert np.all(g == env_size)
        assert np.all(g_next == g_next[0])
        assert np.all(g == g[0])
        tmp_indices = buf2.next(tmp_indices)

    # Test handling cycled indices
    env_size = size
    bufsize = 15
    # NOTE: compute_reward_fn reads `env` at call time, so after this rebinding the
    # buffers created below compute rewards with this new env
    env = MyGoalEnv(size=env_size, array_state=False)
    buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8)
    buf.future_p = 1
    # 5 + 10 = 15 steps exactly fill the buffer of size 15
    for ep_len in [5, 10]:
        obs, _ = env.reset()
        for i in range(ep_len):
            act = 1
            obs_next, rew, terminated, truncated, info = env.step(act)
            batch = cast(
                RolloutBatchProtocol,
                Batch(
                    obs=obs,
                    act=[act],
                    rew=rew,
                    terminated=(i == ep_len - 1),
                    truncated=(i == ep_len - 1),
                    obs_next=obs_next,
                    info=info,
                ),
            )
            buf.add(batch)
            obs = obs_next
    batch_sample, indices = buf.sample(0)
    assert np.all(buf.obs.desired_goal[:5] == buf.obs.desired_goal[0])
    assert np.all(buf.obs.desired_goal[5:10] == buf.obs.desired_goal[5])
    assert np.all(buf.obs.desired_goal[5:] == buf.obs.desired_goal[14])  # (same ep)
    assert np.all(buf.obs.desired_goal[0] != buf.obs.desired_goal[5])  # (diff ep)

    # Another test case for cycled indices
    env_size = 99
    bufsize = 15
    env = MyGoalEnv(env_size, array_state=False)
    buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8)
    buf.future_p = 1
    for x, ep_len in enumerate([10, 20]):
        obs, _ = env.reset()
        for i in range(ep_len):
            act = 1
            obs_next, rew, terminated, truncated, info = env.step(act)
            batch = cast(
                RolloutBatchProtocol,
                Batch(
                    obs=obs,
                    act=[act],
                    rew=rew,
                    terminated=(i == ep_len - 1),
                    truncated=(i == ep_len - 1),
                    obs_next=obs_next,
                    info=info,
                ),
            )
            # skip the first steps of the second episode so only its tail is stored
            if x == 1 and obs["observation"] < 10:
                obs = obs_next
                continue
            buf.add(batch)
            obs = obs_next
    buf._restore_cache()
    sample_indices = np.array([10])  # Suppose the sampled indices is [10]
    buf.rewrite_transitions(sample_indices)
    assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19,
20] def test_update() -> None: buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): buf1.add( cast( RolloutBatchProtocol, Batch( obs=np.array([i]), act=float(i), rew=i * i, terminated=i % 2 == 0, truncated=False, info={"incident": "found"}, ), ), ) assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) assert (buf2.obs[0] == buf1.obs[1]).all() assert (buf2.obs[-1] == buf1.obs[0]).all() b = CachedReplayBuffer(ReplayBuffer(10), 4, 5) with pytest.raises(NotImplementedError): b.update(b) def test_segtree() -> None: realop = np.sum # small test actual_len = 8 tree = SegmentTree(actual_len) # 1-15. 8-15 are leaf nodes assert len(tree) == actual_len assert np.all([tree[i] == 0.0 for i in range(actual_len)]) with pytest.raises(IndexError): tree[actual_len] naive = np.zeros(actual_len) for _ in range(1000): # random choose a place to perform single update index: int | np.ndarray = np.random.randint(actual_len) value: float | np.ndarray = np.random.rand() naive[index] = value tree[index] = value for i in range(actual_len): for j in range(i + 1, actual_len): ref = realop(naive[i:j]) out = tree.reduce(i, j) assert np.allclose(ref, out), (ref, out) assert np.allclose(tree.reduce(start=1), realop(naive[1:])) assert np.allclose(tree.reduce(end=-1), realop(naive[:-1])) # batch setitem for _ in range(1000): index = np.random.choice(actual_len, size=4) value = np.random.rand(4) naive[index] = value tree[index] = value assert np.allclose(realop(naive), tree.reduce()) for _ in range(10): left = np.random.randint(actual_len) right = np.random.randint(left + 1, actual_len + 1) assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) # large test actual_len = 16384 tree = SegmentTree(actual_len) naive = np.zeros([actual_len]) for _ in range(1000): index = np.random.choice(actual_len, size=64) value = np.random.rand(64) naive[index] = value tree[index] = value assert np.allclose(realop(naive), tree.reduce()) for _ 
in range(10): left = np.random.randint(actual_len) right = np.random.randint(left + 1, actual_len + 1) assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) # test prefix-sum-idx actual_len = 8 tree = SegmentTree(actual_len) naive = np.random.rand(actual_len) tree[np.arange(actual_len)] = naive for _ in range(1000): scalar = np.random.rand() * naive.sum() index = tree.get_prefix_sum_idx(scalar) assert naive[:index].sum() <= scalar <= naive[: index + 1].sum() # corner case here naive = np.ones(actual_len, int) tree[np.arange(actual_len)] = naive for scalar in range(actual_len): index = tree.get_prefix_sum_idx(scalar * 1.0) assert naive[:index].sum() <= scalar <= naive[: index + 1].sum() tree = SegmentTree(10) tree[np.arange(3)] = np.array([0.1, 0, 0.1]) assert np.allclose( tree.get_prefix_sum_idx(np.array([0, 0.1, 0.1 + 1e-6, 0.2 - 1e-6])), [0, 0, 2, 2], ) with pytest.raises(AssertionError): tree.get_prefix_sum_idx(0.2) # test large prefix-sum-idx actual_len = 16384 tree = SegmentTree(actual_len) naive = np.random.rand(actual_len) tree[np.arange(actual_len)] = naive for _ in range(1000): scalar = np.random.rand() * naive.sum() index = tree.get_prefix_sum_idx(scalar) assert naive[:index].sum() <= scalar <= naive[: index + 1].sum() def test_pickle() -> None: size = 100 vbuf = ReplayBuffer(size, stack_num=2) pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) rew = np.array([1, 1]) for i in range(4): vbuf.add( cast( RolloutBatchProtocol, Batch( obs=Batch(index=np.array([i])), act=0, rew=rew, terminated=0, truncated=0, ), ), ) for i in range(5): pbuf.add( cast( RolloutBatchProtocol, Batch( obs=Batch(index=np.array([i])), act=2, rew=rew, terminated=0, truncated=0, info=np.random.rand(), ), ), ) # save & load _vbuf = pickle.loads(pickle.dumps(vbuf)) _pbuf = pickle.loads(pickle.dumps(pbuf)) assert len(_vbuf) == len(vbuf) assert np.allclose(_vbuf.act, vbuf.act) assert len(_pbuf) == len(pbuf) assert np.allclose(_pbuf.act, pbuf.act) # make sure the meta var is 
identical assert _vbuf.stack_num == vbuf.stack_num assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))], pbuf.weight[np.arange(len(pbuf))]) def test_hdf5() -> None: size = 100 buffers = { "array": ReplayBuffer(size, stack_num=2), "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), } buffer_types = {k: b.__class__ for k, b in buffers.items()} device = "cuda" if torch.cuda.is_available() else "cpu" info_t = torch.tensor([1.0]).to(device) for i in range(4): kwargs = { "obs": Batch(index=np.array([i])), "act": i, "rew": np.array([1, 2]), "terminated": i % 3 == 2, "truncated": False, "done": i % 3 == 2, "info": {"number": {"n": i, "t": info_t}, "extra": None}, } buffers["array"].add(cast(RolloutBatchProtocol, Batch(kwargs))) buffers["prioritized"].add(cast(RolloutBatchProtocol, Batch(kwargs))) # save paths = {} for k, buf in buffers.items(): f, path = tempfile.mkstemp(suffix=".hdf5") os.close(f) buf.save_hdf5(path) paths[k] = path # load replay buffer _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths} # compare for k in buffers: assert len(_buffers[k]) == len(buffers[k]) assert np.allclose(_buffers[k].act, buffers[k].act) assert _buffers[k].stack_num == buffers[k].stack_num assert _buffers[k].maxsize == buffers[k].maxsize assert np.all(_buffers[k]._indices == buffers[k]._indices) for k in ["array", "prioritized"]: assert _buffers[k]._insertion_idx == buffers[k]._insertion_idx assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: assert np.all(buffers[k][:]["info"].number.n == _buffers[k][:]["info"].number.n) assert np.all(buffers[k][:]["info"]["extra"] == _buffers[k][:]["info"]["extra"]) # raise exception when value cannot be pickled data = {"not_supported": lambda x: x * x} grp = h5py.Group with pytest.raises(NotImplementedError): to_hdf5(data, grp) # type: ignore # ndarray with data type not supported by HDF5 that cannot be pickled data = {"not_supported": np.array(lambda x: x 
* x)} grp = h5py.Group with pytest.raises(RuntimeError): to_hdf5(data, grp) # type: ignore def test_replaybuffermanager() -> None: buf = VectorReplayBuffer(20, 4) batch = cast( RolloutBatchProtocol, Batch( obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], terminated=[0, 0, 1], truncated=[0, 0, 0], ), ) ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2]) assert np.all(ep_len == [0, 0, 1]) assert np.all(ep_rew == [0, 0, 3]) assert np.all(ptr == [0, 5, 10]) assert np.all(ep_idx == [0, 5, 10]) with pytest.raises(NotImplementedError): # ReplayBufferManager cannot be updated buf.update(buf) # sample index / prev / next / unfinished_index indices = buf.sample_indices(11000) assert np.bincount(indices)[[0, 5, 10]].min() >= 3000 # uniform sample batch, indices = buf.sample(0) assert np.allclose(indices, [0, 5, 10]) indices_prev = buf.prev(indices) assert np.allclose(indices_prev, indices), indices_prev indices_next = buf.next(indices) assert np.allclose(indices_next, indices), indices_next assert np.allclose(buf.unfinished_index(), [0, 5]) buf.add( cast( RolloutBatchProtocol, Batch(obs=[4], act=[4], rew=[4], terminated=[1], truncated=[0]), ), buffer_ids=[3], ) assert np.allclose(buf.unfinished_index(), [0, 5]) batch, indices = buf.sample(10) batch, indices = buf.sample(0) assert np.allclose(indices, [0, 5, 10, 15]) indices_prev = buf.prev(indices) assert np.allclose(indices_prev, indices), indices_prev indices_next = buf.next(indices) assert np.allclose(indices_next, indices), indices_next data = np.array([0, 0, 0, 0]) buf.add( cast( RolloutBatchProtocol, Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), ), buffer_ids=[0, 1, 2, 3], ) buf.add( cast( RolloutBatchProtocol, Batch(obs=data, act=data, rew=data, terminated=1 - data, truncated=data), ), buffer_ids=[0, 1, 2, 3], ) assert len(buf) == 12 buf.add( cast( RolloutBatchProtocol, Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), ), buffer_ids=[0, 1, 2, 3], ) buf.add( cast( 
RolloutBatchProtocol,
            Batch(obs=data, act=data, rew=data, terminated=[0, 1, 0, 1], truncated=data),
        ),
        buffer_ids=[0, 1, 2, 3],
    )
    assert len(buf) == 20
    indices = buf.sample_indices(120000)
    assert np.bincount(indices).min() >= 5000
    batch, indices = buf.sample(10)
    indices = buf.sample_indices(0)
    assert np.allclose(indices, np.arange(len(buf)))
    # check the actual data stored in buf._meta
    # (expected arrays are laid out one row per 5-slot sub-buffer)
    assert np.allclose(
        buf.done,
        [
            0, 0, 1, 0, 0,
            0, 0, 1, 0, 1,
            1, 0, 1, 0, 0,
            1, 0, 1, 0, 1,
        ],
    )
    assert np.allclose(
        buf.prev(indices),
        [
            0, 0, 1, 3, 3,
            5, 5, 6, 8, 8,
            10, 11, 11, 13, 13,
            15, 16, 16, 18, 18,
        ],
    )
    assert np.allclose(
        buf.next(indices),
        [
            1, 2, 2, 4, 4,
            6, 7, 7, 9, 9,
            10, 12, 12, 14, 14,
            15, 17, 17, 19, 19,
        ],
    )
    assert np.allclose(buf.unfinished_index(), [4, 14])
    # sub-buffer 2 is full, so this add wraps around to slot 10 and finishes the episode at 13
    ptr, ep_rew, ep_len, ep_idx = buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(obs=[1], act=[1], rew=[1], terminated=[1], truncated=[0]),
        ),
        buffer_ids=[2],
    )
    assert np.all(ep_len == [3])
    assert np.all(ep_rew == [1])
    assert np.all(ptr == [10])
    assert np.all(ep_idx == [13])
    assert np.allclose(buf.unfinished_index(), [4])
    indices = np.array(sorted(buf.sample_indices(0)))
    assert np.allclose(indices, np.arange(len(buf)))
    assert np.allclose(
        buf.prev(indices),
        [
            0, 0, 1, 3, 3,
            5, 5, 6, 8, 8,
            14, 11, 11, 13, 13,
            15, 16, 16, 18, 18,
        ],
    )
    assert np.allclose(
        buf.next(indices),
        [
            1, 2, 2, 4, 4,
            6, 7, 7, 9, 9,
            10, 12, 12, 14, 10,
            15, 17, 17, 19, 19,
        ],
    )
    # corner case: list, int and -1
    assert buf.prev(-1) == buf.prev(np.array([buf.maxsize - 1]))[0]
    assert buf.next(-1) == buf.next(np.array([buf.maxsize - 1]))[0]
    batch = buf._meta
    batch.info = np.ones(buf.maxsize)
    buf.set_batch(batch)
    assert np.allclose(buf.buffers[-1].info, [1] * 5)
    assert buf.sample_indices(-1).tolist() == []
    assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object


def test_cachedbuffer() -> None:
    """Check storage layout, episode bookkeeping and sampling of CachedReplayBuffer."""
    # per the ptr assertions below, slots 10 + 5 * k belong to cached buffer k
    buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    assert buf.sample_indices(0).tolist() == []
    # check the normal function/usage/storage in CachedReplayBuffer
    ptr, ep_rew, ep_len, ep_idx = buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(obs=[1], act=[1], rew=[1], terminated=[0], truncated=[0]),
        ),
        buffer_ids=[1],
    )
    obs = np.zeros(buf.maxsize)
    obs[15] = 1
    indices = buf.sample_indices(0)
    assert np.allclose(indices, [15])
    assert np.allclose(buf.prev(indices), [15])
    assert np.allclose(buf.next(indices), [15])
    assert np.allclose(buf.obs, obs)
    assert np.all(ep_len == [0])
    assert np.all(ep_rew == [0.0])
    assert np.all(ptr == [15])
    assert np.all(ep_idx == [15])
    # a finished episode is copied into the main buffer (slot 0) and stays at 25 in the cache
    ptr, ep_rew, ep_len, ep_idx = buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(obs=[2], act=[2], rew=[2], terminated=[1], truncated=[0]),
        ),
        buffer_ids=[3],
    )
    obs[[0, 25]] = 2
    indices = buf.sample_indices(0)
    assert np.allclose(indices, [0, 15])
    assert np.allclose(buf.prev(indices), [0, 15])
    assert np.allclose(buf.next(indices), [0, 15])
    assert np.allclose(buf.obs, obs)
    assert np.all(ep_len == [1])
    assert np.all(ep_rew == [2.0])
    assert np.all(ptr == [0])
    assert np.all(ep_idx == [0])
    assert np.allclose(buf.unfinished_index(), [15])
    assert np.allclose(buf.sample_indices(0), [0, 15])
    ptr, ep_rew, ep_len, ep_idx = buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], terminated=[0, 1], truncated=[0, 0]),
        ),
        buffer_ids=[3, 1],  # TODO
    )
    assert np.all(ep_len == [0, 2])
    assert np.all(ep_rew == [0, 5.0])
    assert np.all(ptr == [25, 2])
    assert np.all(ep_idx == [25, 1])
    obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
    assert np.allclose(buf.obs, obs)
    assert np.allclose(buf.unfinished_index(), [25])
    indices = buf.sample_indices(0)
    assert np.allclose(indices, [0, 1, 2, 25])
    assert np.allclose(buf.done[indices], [1, 0, 1, 0])
    assert np.allclose(buf.prev(indices), [0, 1, 1, 25])
    assert np.allclose(buf.next(indices), [0, 2, 2, 25])
    indices = buf.sample_indices(10000)
    assert np.bincount(indices)[[0, 1, 2, 25]].min() > 2000  # uniform sample
    # cached buffer with main_buffer size == 0 (no update)
    # used in test_collector
    buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
    data = np.zeros(4)
    # TODO: this doesn't make any sense - why a matrix reward?!
    # See error message in ReplayBuffer._update_state_pre_add
    rew = np.ones([4, 4])
    buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(
                obs=data,
                act=data,
                rew=rew,
                terminated=[0, 0, 1, 1],
                truncated=[0, 0, 0, 0],
            ),
        ),
    )
    buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(
                obs=data,
                act=data,
                rew=rew,
                terminated=[0, 0, 0, 0],
                truncated=[0, 0, 0, 0],
            ),
        ),
    )
    buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(
                obs=data,
                act=data,
                rew=rew,
                terminated=[1, 1, 1, 1],
                truncated=[0, 0, 0, 0],
            ),
        ),
    )
    buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(
                obs=data,
                act=data,
                rew=rew,
                terminated=[0, 0, 0, 0],
                truncated=[0, 0, 0, 0],
            ),
        ),
    )
    ptr, ep_rew, ep_len, ep_idx = buf.add(
        cast(
            RolloutBatchProtocol,
            Batch(
                obs=data,
                act=data,
                rew=rew,
                terminated=[0, 1, 0, 1],
                truncated=[0, 0, 0, 0],
            ),
        ),
    )
    # -1 marks envs whose finished episode was not stored (the main buffer has size 0)
    assert np.all(ptr == [1, -1, 11, -1])
    assert np.all(ep_idx == [0, -1, 10, -1])
    assert np.all(ep_len == [0, 2, 0, 2])
    assert np.all(ep_rew == [data, data + 2, data, data + 2])
    assert np.allclose(
        buf.done,
        [
            0, 0, 1, 0, 0,
            0, 1, 1, 0, 0,
            0, 0, 0, 0, 0,
            0, 1, 0, 0, 0,
        ],
    )
    indices = buf.sample_indices(0)
    assert np.allclose(indices, [0, 1, 10, 11])
    assert np.allclose(buf.prev(indices), [0, 0, 10, 10])
    assert np.allclose(buf.next(indices), [1, 1, 11, 11])


def test_multibuf_stack() -> None: size = 5 bufsize = 9 stack_num = 4 cached_num = 3 env = MoveToRightEnv(size) # test if CachedReplayBuffer can handle stack_num + ignore_obs_next buf4 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), cached_num, size, ) # test if CachedReplayBuffer can handle corner case: # buffer + stack_num + ignore_obs_next + sample_avail buf5 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True, sample_avail=True), cached_num, size, ) obs, info = env.reset(options={"state": 1}) obs = cast(np.ndarray, obs) for i in range(18): obs_next, rew, terminated, truncated, info =
env.step(1) done = terminated or truncated obs_list = np.array([obs + size * i for i in range(cached_num)]) act_list = [1] * cached_num rew_list = [rew] * cached_num terminated_list = [terminated] * cached_num truncated_list = [truncated] * cached_num obs_next_list = -obs_list info_list = [info] * cached_num batch = cast( RolloutBatchProtocol, Batch( obs=obs_list, act=act_list, rew=rew_list, terminated=terminated_list, truncated=truncated_list, obs_next=obs_next_list, info=info_list, ), ) buf5.add(batch) buf4.add(batch) assert np.all(buf4.obs == buf5.obs) assert np.all(buf4.done == buf5.done) assert np.all(buf4.terminated == buf5.terminated) assert np.all(buf4.truncated == buf5.truncated) obs = obs_next if done: # obs is an array, but the env is malformed, so we can't properly type it obs, info = env.reset(options={"state": 1}) # type: ignore[assignment] # check the `add` order is correct assert np.allclose( buf4.obs.reshape(-1), [ 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer 1, 2, 3, 4, 0, # cached_buffer[0] 6, 7, 8, 9, 0, # cached_buffer[1] 11, 12, 13, 14, 0, # cached_buffer[2] ], ), buf4.obs assert np.allclose( buf4.done, [ 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer 0, 0, 0, 1, 0, # cached_buffer[0] 0, 0, 0, 1, 0, # cached_buffer[1] 0, 0, 0, 1, 0, # cached_buffer[2] ], ), buf4.done assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) indices = np.array(sorted(buf4.sample_indices(0))) assert np.allclose(indices, [*list(range(bufsize)), 9, 10, 14, 15, 19, 20]) cur_obs = buf4[indices].obs assert isinstance(cur_obs, np.ndarray) assert np.allclose( cur_obs[..., 0], [ [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14], [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7], [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11], [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6], [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12], ], ) next_obs = buf4[indices].obs_next assert isinstance(next_obs, np.ndarray) assert np.allclose( next_obs[..., 0], [ [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 
14], [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12], [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], ], ) indices = buf5.sample_indices(0) assert np.allclose(sorted(indices), [2, 7]) assert np.all(np.isin(buf5.sample_indices(100), indices)) # manually change the stack num buf5.stack_num = 2 for buf in buf5.buffers: buf.stack_num = 2 indices = buf5.sample_indices(0) assert np.allclose(sorted(indices), [0, 1, 2, 5, 6, 7, 10, 15, 20]) batch_sample, _ = buf5.sample(0) # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next buf6 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True, ignore_obs_next=True), cached_num, size, ) obs = np.random.rand(size, 4, 84, 84) buf6.add( cast( RolloutBatchProtocol, Batch( obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], terminated=[0, 1], truncated=[0, 0], obs_next=[obs[3], obs[1]], ), ), buffer_ids=[1, 2], ) assert buf6.obs.shape == (buf6.maxsize, 84, 84) assert np.allclose(buf6.obs[0], obs[0, -1]) assert np.allclose(buf6.obs[14], obs[2, -1]) assert np.allclose(buf6.obs[19], obs[0, -1]) assert buf6[0].obs.shape == (4, 84, 84) def test_multibuf_hdf5() -> None: size = 100 buffers = { "vector": VectorReplayBuffer(size * 4, 4), "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size), } buffer_types = {k: b.__class__ for k, b in buffers.items()} device = "cuda" if torch.cuda.is_available() else "cpu" info_t = torch.tensor([1.0]).to(device) for i in range(4): kwargs = { "obs": Batch(index=np.array([i])), "act": i, "rew": np.array([1, 2]), "terminated": i % 3 == 2, "truncated": False, "done": i % 3 == 2, "info": {"number": {"n": i, "t": info_t}, "extra": None}, } buffers["vector"].add(Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]) buffers["cached"].add(Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]) # save paths = {} for k, buf in buffers.items(): f, path = 
tempfile.mkstemp(suffix=".hdf5") os.close(f) buf.save_hdf5(path) paths[k] = path # load replay buffer _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths} # compare for k in buffers: assert len(_buffers[k]) == len(buffers[k]) assert np.allclose(_buffers[k].act, buffers[k].act) assert _buffers[k].stack_num == buffers[k].stack_num assert _buffers[k].maxsize == buffers[k].maxsize assert np.all(_buffers[k]._indices == buffers[k]._indices) # check shallow copy in VectorReplayBuffer for k in ["vector", "cached"]: buffers[k].info.number.n[0] = -100 assert buffers[k].buffers[0].info.number.n[0] == -100 # check if still behave normally for k in ["vector", "cached"]: kwargs = { "obs": Batch(index=np.array([5])), "act": 5, "rew": np.array([2, 1]), "terminated": False, "truncated": False, "done": False, "info": {"number": {"n": i}, "Timelimit.truncate": True}, } buffers[k].add(Batch.stack([kwargs, kwargs, kwargs, kwargs])) act = np.zeros(buffers[k].maxsize) if k == "vector": act[np.arange(5)] = np.array([0, 1, 2, 3, 5]) act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5]) act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5]) act[size * 3] = 5 elif k == "cached": act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) act[np.arange(3) + size] = np.array([3, 5, 2]) act[np.arange(3) + size * 2] = np.array([3, 5, 2]) act[np.arange(3) + size * 3] = np.array([3, 5, 2]) act[size * 4] = 5 assert np.allclose(buffers[k].act, act) info_keys = ["number", "extra", "Timelimit.truncate"] assert set(buffers[k].info.keys()) == set(info_keys) for path in paths.values(): os.remove(path) def test_from_data() -> None: obs_data: npt.NDArray[np.uint8] = np.ndarray((10, 3, 3), dtype="uint8") for i in range(10): obs_data[i] = i * np.ones((3, 3), dtype="uint8") obs_next_data = np.zeros_like(obs_data) obs_next_data[:-1] = obs_data[1:] f, path = tempfile.mkstemp(suffix=".hdf5") os.close(f) with h5py.File(path, "w") as f: obs = f.create_dataset("obs", data=obs_data) act = 
f.create_dataset("act", data=np.arange(10, dtype="int32")) rew = f.create_dataset("rew", data=np.arange(10, dtype="float32")) terminated = f.create_dataset("terminated", data=np.zeros(10, dtype="bool")) truncated = f.create_dataset("truncated", data=np.zeros(10, dtype="bool")) done = f.create_dataset("done", data=np.zeros(10, dtype="bool")) obs_next = f.create_dataset("obs_next", data=obs_next_data) buf = ReplayBuffer.from_data(obs, act, rew, terminated, truncated, done, obs_next) assert len(buf) == 10 batch = buf[3] cur_obs = batch.obs assert isinstance(cur_obs, np.ndarray) assert np.array_equal(cur_obs, 3 * np.ones((3, 3), dtype="uint8")) assert batch.act == 3 assert batch.rew == 3.0 assert not batch.done next_obs = batch.obs_next assert isinstance(next_obs, np.ndarray) assert np.array_equal(next_obs, 4 * np.ones((3, 3), dtype="uint8")) os.remove(path) def test_custom_key() -> None: batch = cast( RolloutBatchProtocol, Batch( obs_next=np.array( [ [ 1.174, -0.1151, -0.609, -0.5205, -0.9316, 3.236, -2.418, 0.386, 0.2227, -0.5117, 2.293, ], ], ), rew=np.array([4.28125]), act=np.array([[-0.3088, -0.4636, 0.4956]]), truncated=np.array([False]), obs=np.array( [ [ 1.193, -0.1203, -0.6123, -0.519, -0.9434, 3.32, -2.266, 0.9116, 0.623, 0.1259, 0.363, ], ], ), terminated=np.array([False]), done=np.array([False]), returns=np.array([74.70343082]), info=Batch(), policy=Batch(), ), ) buffer_size = len(batch.rew) buffer = ReplayBuffer(buffer_size) buffer.add(batch) sampled_batch, _ = buffer.sample(1) # Check if they have the same keys assert set(batch.get_keys()) == set( sampled_batch.get_keys(), ), f"Batches have different keys: {set(batch.get_keys())} and {set(sampled_batch.get_keys())}" # Compare the values for each key for key in batch.get_keys(): if isinstance(batch.__dict__[key], np.ndarray) and isinstance( sampled_batch.__dict__[key], np.ndarray, ): assert np.allclose( batch.__dict__[key], sampled_batch.__dict__[key], ), f"Value mismatch for key: {key}" if 
isinstance(batch.__dict__[key], Batch) and isinstance( sampled_batch.__dict__[key], Batch, ): assert len(batch.__dict__[key].get_keys()) == 0 assert len(sampled_batch.__dict__[key].get_keys()) == 0 def test_buffer_dropnull() -> None: size = 10 buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(4): buf.add( cast( RolloutBatchProtocol, Batch( obs={ "mask1": i + 1, "mask2": i + 4, "mask": i, }, act={"act_id": i, "position_id": i + 3}, rew=i, terminated=i % 3 == 0, truncated=False, info={"if": i}, ), ), ) assert len(buf[:3]) == 3 buf.set_array_at_key(np.array([1, 2, 3], float), "newkey", [0, 1, 2]) assert np.array_equal(buf.newkey[:3], np.array([1, 2, 3], float)) assert buf.hasnull() buf.dropnull() assert len(buf[:3]) == 3 assert not buf.hasnull() @pytest.fixture def dummy_rollout_batch() -> RolloutBatchProtocol: return cast( RolloutBatchProtocol, Batch( obs=np.arange(2), obs_next=np.arange(2), act=np.arange(5), rew=1, terminated=False, truncated=False, done=False, info={}, ), ) def test_get_replay_buffer_indices(dummy_rollout_batch: RolloutBatchProtocol) -> None: buffer = ReplayBuffer(5) for _ in range(5): buffer.add(dummy_rollout_batch) assert np.array_equal(buffer.get_buffer_indices(0, 3), [0, 1, 2]) assert np.array_equal(buffer.get_buffer_indices(3, 2), [3, 4, 0, 1]) assert np.array_equal(buffer.get_buffer_indices(0, 5), np.arange(5)) def test_get_vector_replay_buffer_indices( dummy_rollout_batch: RolloutBatchProtocol, ) -> None: stacked_batch = Batch.stack([dummy_rollout_batch, dummy_rollout_batch]) buffer = VectorReplayBuffer(10, 2) for _ in range(5): buffer.add(stacked_batch) assert np.array_equal(buffer.get_buffer_indices(0, 3), [0, 1, 2]) assert np.array_equal(buffer.get_buffer_indices(3, 2), [3, 4, 0, 1]) assert np.array_equal(buffer.get_buffer_indices(6, 9), [6, 7, 8]) assert np.array_equal(buffer.get_buffer_indices(8, 7), [8, 9, 5, 6]) with pytest.raises(ValueError): buffer.get_buffer_indices(3, 6) with pytest.raises(ValueError): 
buffer.get_buffer_indices(6, 3) ================================================ FILE: test/base/test_collector.py ================================================ from collections.abc import Callable, Sequence from typing import Any import gymnasium as gym import numpy as np import pytest import tqdm from test.base.env import MoveToRightEnv, NXEnv from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.data import ( AsyncCollector, Batch, CachedReplayBuffer, Collector, CollectStats, PrioritizedReplayBuffer, ReplayBuffer, VectorReplayBuffer, ) from tianshou.data.batch import BatchProtocol from tianshou.data.collector import ( CollectActionBatchProtocol, EpisodeRolloutHookMCReturn, StepHook, ) from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv try: import envpool except ImportError: envpool = None class MaxActionPolicy(Policy): def __init__( self, action_space: gym.spaces.Space | None = None, dict_state: bool = False, need_state: bool = True, action_shape: Sequence[int] | int | None = None, ) -> None: """Mock policy for testing, will always return an array of ones of the shape of the action space. Note that this doesn't make much sense for discrete action space (the output is then intepreted as logits, meaning all actions would be equally likely). :param action_space: the action space of the environment. If None, a dummy Box space will be used. 
:param bool dict_state: if the observation of the environment is a dict :param bool need_state: if the policy needs the hidden state (for RNN) """ action_space = action_space or gym.spaces.Box(-1, 1, (1,)) super().__init__(action_space=action_space) self.dict_state = dict_state self.need_state = need_state self.action_shape = action_shape def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> Batch: if self.need_state: if state is None: state = np.zeros((len(batch.obs), 2)) elif isinstance(state, np.ndarray | BatchProtocol): state += np.int_(1) elif isinstance(state, dict) and state.get("hidden") is not None: state["hidden"] += np.int_(1) if self.dict_state: if self.action_shape: action_shape = self.action_shape elif isinstance(batch.obs, Batch): action_shape = len(batch.obs["index"]) else: action_shape = len(batch.obs) return Batch(act=np.ones(action_shape), state=state) action_shape = self.action_shape if self.action_shape else len(batch.obs) return Batch(act=np.ones(action_shape), state=state) @pytest.fixture() def collector_with_single_env() -> Collector[CollectStats]: """The env will be a MoveToRightEnv with size 5, sleep 0.""" env = MoveToRightEnv(size=5, sleep=0) policy = MaxActionPolicy() collector = Collector[CollectStats](policy, env, ReplayBuffer(size=100)) collector.reset() return collector def test_collector() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] subproc_venv_4_envs = SubprocVectorEnv(env_fns) dummy_venv_4_envs = DummyVectorEnv(env_fns) policy = MaxActionPolicy() single_env = env_fns[0]() c_single_env = Collector[CollectStats]( policy, single_env, ReplayBuffer(size=100), ) c_single_env.reset() c_single_env.collect(n_step=3) assert len(c_single_env.buffer) == 3 # TODO: direct attr access is an arcane way of using the buffer, it should be never done # The placeholders for entries are all zeros, so buffer.obs is an array filled with 3 # 
observations, and 97 zeros. # However, buffer[:] will have all attributes with length three... The non-filled entries are removed there # See above. For the single env, we start with obs=0, obs_next=1. # We move to obs=1, obs_next=2, # then the env is reset and we move to obs=0 # Making one more step results in obs_next=1 # The final 0 in the buffer.obs is because the buffer is initialized with zeros and the direct attr access assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0]) obs_next = c_single_env.buffer[:].obs_next[..., 0] assert isinstance(obs_next, np.ndarray) assert np.allclose(obs_next, [1, 2, 1]) keys = np.zeros(100) keys[:3] = 1 assert np.allclose(c_single_env.buffer.info["key"], keys) for e in c_single_env.buffer.info["env"][:3]: assert isinstance(e, MoveToRightEnv) assert np.allclose(c_single_env.buffer.info["env_id"], 0) rews = np.zeros(100) rews[:3] = [0, 1, 0] assert np.allclose(c_single_env.buffer.rew, rews) # At this point, the buffer contains obs 0 -> 1 -> 0 # At start we have 3 entries in the buffer # We collect 3 episodes, in addition to the transitions we have collected before # 0 -> 1 -> 0 -> 0 (reset at collection start) -> 1 -> done (0) -> 1 -> done(0) # obs_next: 1 -> 2 -> 1 -> 1 (reset at collection start) -> 2 -> 1 -> 2 -> 1 -> 2 # In total, we will have 3 + 6 = 9 entries in the buffer c_single_env.collect(n_episode=3) assert len(c_single_env.buffer) == 8 assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) obs_next = c_single_env.buffer[:].obs_next[..., 0] assert isinstance(obs_next, np.ndarray) assert np.allclose(obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) assert np.allclose(c_single_env.buffer.info["key"][:8], 1) for e in c_single_env.buffer.info["env"][:8]: assert isinstance(e, MoveToRightEnv) assert np.allclose(c_single_env.buffer.info["env_id"][:8], 0) assert np.allclose(c_single_env.buffer.rew[:8], [0, 1, 0, 1, 0, 1, 0, 1]) c_single_env.collect(n_step=3, random=True) c_subproc_venv_4_envs = 
Collector[CollectStats]( policy, subproc_venv_4_envs, VectorReplayBuffer(total_size=100, buffer_num=4), ) c_subproc_venv_4_envs.reset() # Collect some steps c_subproc_venv_4_envs.collect(n_step=8) obs = np.zeros(100) valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) obs_next = c_subproc_venv_4_envs.buffer[:].obs_next[..., 0] assert isinstance(obs_next, np.ndarray) assert np.allclose(obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) keys = np.zeros(100) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys) for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]: assert isinstance(e, MoveToRightEnv) env_ids = np.zeros(100) env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3] assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids) rews = np.zeros(100) rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0] assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews) # we previously collected 8 steps, 2 from each env, now we collect 4 episodes # each env will contribute an episode, which will be of lens 2 (first env was reset), 1, 2, 3 # So we get 8 + 2+1+2+3 = 16 steps c_subproc_venv_4_envs.collect(n_episode=4) assert len(c_subproc_venv_4_envs.buffer) == 16 valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4] assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) obs_next = c_subproc_venv_4_envs.buffer[:].obs_next[..., 0] assert isinstance(obs_next, np.ndarray) assert np.allclose( obs_next, [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], ) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys) for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]: assert isinstance(e, MoveToRightEnv) env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3] assert 
np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids) rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1] assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews) c_subproc_venv_4_envs.collect(n_episode=4, random=True) c_dummy_venv_4_envs = Collector[CollectStats]( policy, dummy_venv_4_envs, VectorReplayBuffer(total_size=100, buffer_num=4), ) c_dummy_venv_4_envs.reset() c_dummy_venv_4_envs.collect(n_episode=7) obs1 = obs.copy() obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] obs2 = obs.copy() obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3] c2obs = c_dummy_venv_4_envs.buffer.obs[:, 0] assert np.all(c2obs == obs1) or np.all(c2obs == obs2) c_dummy_venv_4_envs.reset_env() c_dummy_venv_4_envs.reset_buffer() assert c_dummy_venv_4_envs.collect(n_episode=8).n_collected_episodes == 8 valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57] obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3] assert np.all(c_dummy_venv_4_envs.buffer.obs[:, 0] == obs) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1] assert np.allclose(c_dummy_venv_4_envs.buffer.info["key"], keys) for e in c_dummy_venv_4_envs.buffer.info["env"][valid_indices]: assert isinstance(e, MoveToRightEnv) env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2] assert np.allclose(c_dummy_venv_4_envs.buffer.info["env_id"], env_ids) rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1] assert np.allclose(c_dummy_venv_4_envs.buffer.rew, rews) c_dummy_venv_4_envs.collect(n_episode=4, random=True) # test corner case with pytest.raises(ValueError): Collector[CollectStats](policy, dummy_venv_4_envs, ReplayBuffer(10)) with pytest.raises(ValueError): Collector[CollectStats](policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) with pytest.raises(ValueError): c_dummy_venv_4_envs.collect() def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]: return lambda: NXEnv(i, t) # test NXEnv for obs_type in ["array", "object"]: envs = SubprocVectorEnv([get_env_factory(i=i, t=obs_type) for i in [5, 10, 15, 20]]) c_suproc_new 
= Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), ) c_suproc_new.reset() c_suproc_new.collect(n_step=6) assert c_suproc_new.buffer.obs.dtype == object @pytest.fixture() def async_collector_and_env_lens() -> tuple[AsyncCollector, list[int]]: env_lens = [2, 3, 4, 5] env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MaxActionPolicy() bufsize = 60 async_collector = AsyncCollector( policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), ) async_collector.reset() return async_collector, env_lens class TestAsyncCollector: def test_collect_without_argument_gives_error( self, async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, env_lens = async_collector_and_env_lens with pytest.raises(ValueError): c1.collect() def test_collect_one_episode_async( self, async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, env_lens = async_collector_and_env_lens result = c1.collect(n_episode=1) assert result.n_collected_episodes >= 1 def test_enough_episodes_two_collection_cycles_n_episode_without_reset( self, async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, env_lens = async_collector_and_env_lens n_episode = 2 result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False) assert result_c1.n_collected_episodes >= n_episode result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False) assert result_c2.n_collected_episodes >= n_episode def test_enough_episodes_two_collection_cycles_n_episode_with_reset( self, async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, env_lens = async_collector_and_env_lens n_episode = 2 result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True) assert result_c1.n_collected_episodes >= n_episode result_c2 = c1.collect(n_episode=n_episode, 
reset_before_collect=True) assert result_c2.n_collected_episodes >= n_episode def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode( self, async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, env_lens = async_collector_and_env_lens ptr = [0, 0, 0, 0] bufsize = 60 for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): result = c1.collect(n_episode=n_episode) assert result.n_collected_episodes >= n_episode # check buffer data, obs and obs_next, env_id for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize ptr[i] = (ptr[i] + total) % bufsize seq = np.arange(env_len) buf = c1.buffer.buffers[i] assert np.all(buf.info.env_id[indices] == i) assert np.all(buf.obs[indices].reshape(count, env_len) == seq) assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step( self, async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, env_lens = async_collector_and_env_lens bufsize = 60 ptr = [0, 0, 0, 0] for n_step in tqdm.trange(1, 15, desc="test async n_step"): result = c1.collect(n_step=n_step) assert result.n_collected_steps >= n_step for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize ptr[i] = (ptr[i] + total) % bufsize seq = np.arange(env_len) buf = c1.buffer.buffers[i] assert np.all(buf.info.env_id[indices] == i) assert np.all(buf.obs[indices].reshape(count, env_len) == seq) assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step( self, async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, 
env_lens = async_collector_and_env_lens bufsize = 60 ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): result = c1.collect(n_episode=n_episode) assert result.n_collected_episodes >= n_episode # check buffer data, obs and obs_next, env_id for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize ptr[i] = (ptr[i] + total) % bufsize seq = np.arange(env_len) buf = c1.buffer.buffers[i] assert np.all(buf.info.env_id[indices] == i) assert np.all(buf.obs[indices].reshape(count, env_len) == seq) assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data, thus no bincount stuff as above for n_step in tqdm.trange(1, 15, desc="test async n_step"): result = c1.collect(n_step=n_step) assert result.n_collected_steps >= n_step for i in range(4): env_len = i + 2 seq = np.arange(env_len) buf = c1.buffer.buffers[i] assert np.all(buf.info.env_id == i) assert np.all(buf.obs.reshape(-1, env_len) == seq) assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) def test_collector_with_dict_state() -> None: env = MoveToRightEnv(size=5, sleep=0, dict_state=True) policy = MaxActionPolicy(dict_state=True) c0 = Collector[CollectStats](policy, env, ReplayBuffer(size=100)) c0.reset() c0.collect(n_step=3) c0.collect(n_episode=2) assert len(c0.buffer) == 10 # 3 + two episodes with 5 steps each env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) envs.seed(666) obs, info = envs.reset() assert not np.isclose(obs[0]["rand"], obs[1]["rand"]) c1 = Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), ) c1.reset() c1.collect(n_step=12) result = c1.collect(n_episode=8) assert result.n_collected_episodes == 8 lens = np.bincount(result.lens) assert (result.n_collected_steps == 21 and 
np.all(lens == [0, 0, 2, 2, 2, 2])) or ( result.n_collected_steps == 20 and np.all(lens == [0, 0, 3, 1, 2, 2]) ) batch, _ = c1.buffer.sample(10) c0.buffer.update(c1.buffer) assert len(c0.buffer) in [42, 43] cur_obs = c0.buffer[:].obs assert isinstance(cur_obs, Batch) if len(c0.buffer) == 42: assert np.all( cur_obs.index[..., 0] == [ 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ], ), cur_obs.index[..., 0] else: assert np.all( cur_obs.index[..., 0] == [ 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 0, 1, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ], ), cur_obs.index[..., 0] c2 = Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), ) c2.reset() c2.collect(n_episode=10) batch, _ = c2.buffer.sample(10) def test_collector_with_multi_agent() -> None: multi_agent_env = MoveToRightEnv(size=5, sleep=0, ma_rew=4) policy = MaxActionPolicy() c_single_env = Collector[CollectStats](policy, multi_agent_env, ReplayBuffer(size=100)) c_single_env.reset() multi_env_returns = c_single_env.collect(n_step=3).returns # c_single_env has length 3 # We have no full episodes, so no returns yet assert len(multi_env_returns) == 0 single_env_returns = c_single_env.collect(n_episode=2).returns # now two episodes. 
Since we have 4 a agents, the returns have shape (2, 4) assert single_env_returns.shape == (2, 4) assert np.all(single_env_returns == 1) env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c_multi_env_ma = Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), ) c_multi_env_ma.reset() multi_env_returns = c_multi_env_ma.collect(n_step=12).returns # each env makes 3 steps, the first two envs are done and result in two finished episodes assert multi_env_returns.shape == (2, 4) and np.all(multi_env_returns == 1), multi_env_returns multi_env_returns = c_multi_env_ma.collect(n_episode=8).returns assert multi_env_returns.shape == (8, 4) assert np.all(multi_env_returns == 1) batch, _ = c_multi_env_ma.buffer.sample(10) print(batch) c_single_env.buffer.update(c_multi_env_ma.buffer) assert len(c_single_env.buffer) in [42, 43] if len(c_single_env.buffer) == 42: multi_env_returns = np.array( [ 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, ], ) else: multi_env_returns = np.array( [ 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, ], ) assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns]) assert np.all(c_single_env.buffer[:].done == multi_env_returns) c2 = Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), ) c2.reset() multi_env_returns = c2.collect(n_episode=10).returns assert multi_env_returns.shape == (10, 4) assert np.all(multi_env_returns == 1) batch, _ = c2.buffer.sample(10) def test_collector_with_atari_setting() -> None: reference_obs = np.zeros([6, 4, 84, 84]) for i in range(6): reference_obs[i, 3, np.arange(84), np.arange(84)] = i reference_obs[i, 2, np.arange(84)] = i reference_obs[i, 1, :, np.arange(84)] = i reference_obs[i, 0] = i # 
atari single buffer env = MoveToRightEnv(size=5, sleep=0, array_state=True) policy = MaxActionPolicy() c0 = Collector[CollectStats](policy, env, ReplayBuffer(size=100)) c0.reset() c0.collect(n_step=6) c0.collect(n_episode=2) assert c0.buffer.obs.shape == (100, 4, 84, 84) assert c0.buffer.obs_next.shape == (100, 4, 84, 84) assert len(c0.buffer) == 15 # 6 + 2 episodes with 5 steps each obs = np.zeros_like(c0.buffer.obs) obs[np.arange(15)] = reference_obs[np.arange(15) % 5] assert np.all(obs == c0.buffer.obs) c1 = Collector[CollectStats](policy, env, ReplayBuffer(size=100, ignore_obs_next=True)) c1.collect(n_episode=3, reset_before_collect=True) assert np.allclose(c0.buffer.obs, c1.buffer.obs) with pytest.raises(AttributeError): c1.buffer.obs_next # noqa: B018 assert np.all(reference_obs[[1, 2, 3, 4, 4] * 3] == c1.buffer[:].obs_next) c2 = Collector[CollectStats]( policy, env, ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True), ) c2.reset() c2.collect(n_step=8) assert c2.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c2.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1] assert np.all(c2.buffer.obs == obs) obs_next = c2.buffer[:].obs_next assert isinstance(obs_next, np.ndarray) assert np.allclose(obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) # atari multi buffer env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c3 = Collector[CollectStats](policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.reset() c3.collect(n_step=12) result_cached_buffer_collect_9_episodes = c3.collect(n_episode=9) assert result_cached_buffer_collect_9_episodes.n_collected_episodes == 9 assert result_cached_buffer_collect_9_episodes.n_collected_steps == 23 assert c3.buffer.obs.shape == (100, 4, 84, 84) obs = np.zeros_like(c3.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]] obs[np.arange(25, 34)] = reference_obs[[0, 1, 2, 0, 1, 
2, 0, 1, 2]] obs[np.arange(50, 58)] = reference_obs[[0, 1, 2, 3, 0, 1, 2, 3]] obs[np.arange(75, 85)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]] assert np.all(obs == c3.buffer.obs) obs_next = np.zeros_like(c3.buffer.obs_next) obs_next[np.arange(8)] = reference_obs[[1, 2, 1, 2, 1, 2, 1, 2]] obs_next[np.arange(25, 34)] = reference_obs[[1, 2, 3, 1, 2, 3, 1, 2, 3]] obs_next[np.arange(50, 58)] = reference_obs[[1, 2, 3, 4, 1, 2, 3, 4]] obs_next[np.arange(75, 85)] = reference_obs[[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]] assert np.all(obs_next == c3.buffer.obs_next) c4 = Collector[CollectStats]( policy, envs, VectorReplayBuffer( total_size=100, buffer_num=4, stack_num=4, ignore_obs_next=True, save_only_last_obs=True, ), ) c4.reset() c4.collect(n_step=12) result_cached_buffer_collect_9_episodes = c4.collect(n_episode=9) assert result_cached_buffer_collect_9_episodes.n_collected_episodes == 9 assert result_cached_buffer_collect_9_episodes.n_collected_steps == 23 assert c4.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c4.buffer.obs) slice_obs = reference_obs[:, -1] obs[np.arange(8)] = slice_obs[[0, 1, 0, 1, 0, 1, 0, 1]] obs[np.arange(25, 34)] = slice_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]] obs[np.arange(50, 58)] = slice_obs[[0, 1, 2, 3, 0, 1, 2, 3]] obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]] assert np.all(c4.buffer.obs == obs) obs_next = np.zeros([len(c4.buffer), 4, 84, 84]) ref_index = np.array( [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 4, 4, 1, 2, 3, 4, 4, ], ) obs_next[:, -1] = slice_obs[ref_index] ref_index -= 1 ref_index[ref_index < 0] = 0 obs_next[:, -2] = slice_obs[ref_index] ref_index -= 1 ref_index[ref_index < 0] = 0 obs_next[:, -3] = slice_obs[ref_index] ref_index -= 1 ref_index[ref_index < 0] = 0 obs_next[:, -4] = slice_obs[ref_index] assert np.all(obs_next == c4.buffer[:].obs_next) buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True) collector_cached_buffer = 
Collector[CollectStats](policy, envs, CachedReplayBuffer(buf, 4, 10)) collector_cached_buffer.reset() result_cached_buffer_collect_12_steps = collector_cached_buffer.collect(n_step=12) assert len(buf) == 5 assert len(collector_cached_buffer.buffer) == 12 result_cached_buffer_collect_9_episodes = collector_cached_buffer.collect(n_episode=9) assert result_cached_buffer_collect_9_episodes.n_collected_episodes == 9 assert result_cached_buffer_collect_9_episodes.n_collected_steps == 23 assert len(buf) == 35 assert np.all( buf.obs[: len(buf)] == slice_obs[ [ 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, ] ], ) assert np.all( buf[:].obs_next[:, -1] == slice_obs[ [ 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4, ] ], ) assert len(buf) == len(collector_cached_buffer.buffer) # test buffer=None collector_default_buffer = Collector[CollectStats](policy, envs) collector_default_buffer.reset() result_default_buffer_collect_12_steps = collector_default_buffer.collect(n_step=12) for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: assert np.allclose( getattr(result_default_buffer_collect_12_steps, key), getattr(result_cached_buffer_collect_12_steps, key), ) result2 = collector_default_buffer.collect(n_episode=9) for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: assert np.allclose( getattr(result2, key), getattr(result_cached_buffer_collect_9_episodes, key), ) @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_collector_envpool_gym_reset_return_info() -> None: envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True) policy = MaxActionPolicy(action_shape=(len(envs), 1)) c0 = Collector[CollectStats]( policy, envs, VectorReplayBuffer(len(envs) * 10, len(envs)), exploration_noise=True, ) c0.reset() c0.collect(n_step=8) env_ids = np.zeros(len(envs) * 
10) env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3] assert np.allclose(c0.buffer.info["env_id"], env_ids) def test_collector_with_vector_env() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] dum = DummyVectorEnv(env_fns) policy = MaxActionPolicy() c2 = Collector[CollectStats]( policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4), ) c2.reset() c1r = c2.collect(n_episode=2) assert np.array_equal(np.array([1, 8]), c1r.lens) c2r = c2.collect(n_episode=10) assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 8, 9, 10]), c2r.lens) c3r = c2.collect(n_step=20) assert np.array_equal(np.array([1, 1, 1, 1, 1]), c3r.lens) c4r = c2.collect(n_step=20) assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens) def test_async_collector_with_vector_env() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] dum = DummyVectorEnv(env_fns) policy = MaxActionPolicy() c1 = AsyncCollector( policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4), ) c1r = c1.collect(n_episode=10, reset_before_collect=True) assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens) c2r = c1.collect(n_step=20) assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens) class StepHookAddFieldToBatch(StepHook): def __call__( self, action_batch: CollectActionBatchProtocol, rollout_batch: RolloutBatchProtocol, ) -> None: rollout_batch.set_array_at_key(np.array([1]), "added_by_hook") class TestCollectStatsAndHooks: @staticmethod def test_on_step_hook(collector_with_single_env: Collector) -> None: collector_with_single_env.set_on_step_hook(StepHookAddFieldToBatch()) collect_stats = collector_with_single_env.collect(n_step=3) assert collect_stats.n_collected_steps == 3 # a was added by the hook assert np.array_equal( collector_with_single_env.buffer[:].added_by_hook, np.array([1, 1, 1]), ) @staticmethod def test_episode_mc_hook(collector_with_single_env: Collector) -> None: 
collector_with_single_env.set_on_episode_done_hook(EpisodeRolloutHookMCReturn()) collector_with_single_env.collect(n_episode=1) collected_batch = collector_with_single_env.buffer[:] return_to_go = collected_batch.get(EpisodeRolloutHookMCReturn.MC_RETURN_TO_GO_KEY) full_return = collected_batch.get(EpisodeRolloutHookMCReturn.FULL_EPISODE_MC_RETURN_KEY) assert np.array_equal(return_to_go, episode_mc_return_to_go(collected_batch.rew)) assert np.array_equal(full_return, np.ones(5) * return_to_go[0]) ================================================ FILE: test/base/test_env.py ================================================ import sys import time from collections.abc import Callable from typing import Any, Literal import gymnasium as gym import numpy as np import pytest from gymnasium.spaces.discrete import Discrete from test.base.env import MoveToRightEnv, NXEnv from tianshou.data import Batch from tianshou.env import ( ContinuousToDiscrete, DummyVectorEnv, MultiDiscreteToDiscrete, RayVectorEnv, ShmemVectorEnv, SubprocVectorEnv, VectorEnvNormObs, ) from tianshou.env.gym_wrappers import TruncatedAsTerminated from tianshou.env.venvs import BaseVectorEnv from tianshou.utils import RunningMeanStd try: import envpool except ImportError: envpool = None def has_ray() -> bool: try: import ray # noqa: F401 return True except ImportError: return False def recurse_comp(a: np.ndarray | list | tuple | dict, b: Any) -> np.bool_ | bool | None: try: if isinstance(a, np.ndarray): if a.dtype == object: return np.array([recurse_comp(m, n) for m, n in zip(a, b, strict=True)]).all() return np.allclose(a, b) if isinstance(a, list | tuple): return np.array([recurse_comp(m, n) for m, n in zip(a, b, strict=True)]).all() if isinstance(a, dict): return np.array([recurse_comp(a[k], b[k]) for k in a]).all() except Exception: return False def test_async_env(size: int = 10000, num: int = 8, sleep: float = 0.1) -> None: # simplify the test case, just keep stepping env_fns = [ lambda i=i: 
MoveToRightEnv(size=i, sleep=sleep, random_sleep=True) for i in range(size, size + num) ] test_cls = [SubprocVectorEnv, ShmemVectorEnv] if has_ray(): test_cls += [RayVectorEnv] for cls in test_cls: v = cls(env_fns, wait_num=num // 2, timeout=1e-3) v.seed(None) v.reset() # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un} # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1} # expectation of v is n / (n + 1) # for a synchronous environment, the following actions should take # about 7 * sleep * num / (num + 1) seconds # for async simulation, the analysis is complicated, but the time cost # should be smaller action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4) current_idx_start = 0 act = action_list[:num] env_ids = list(range(num)) o = [] spent_time = time.time() while current_idx_start < len(action_list): ( A, B, C, D, E, ) = v.step(action=act, id=env_ids) b = Batch({"obs": A, "rew": B, "terminate": C, "truncated": D, "info": E}) env_ids = b.info.env_id o.append(b) current_idx_start += len(act) # len of action may be smaller than len(A) in the end act = action_list[current_idx_start : current_idx_start + len(A)] # truncate env_ids with the first terms # typically len(env_ids) == len(A) == len(action), except for the # last batch when actions are not enough env_ids = env_ids[: len(act)] spent_time = time.time() - spent_time Batch.cat(o) v.close() # assure 1/7 improvement if sys.platform == "linux" and cls != RayVectorEnv: # macOS/Windows cannot pass this check assert spent_time < 6.0 * sleep * num / (num + 1) def test_async_check_id( size: int = 100, num: int = 4, sleep: float = 0.2, timeout: float = 0.7, ) -> None: env_fns = [ lambda: MoveToRightEnv(size=size, sleep=sleep * 2), lambda: MoveToRightEnv(size=size, sleep=sleep * 3), lambda: MoveToRightEnv(size=size, sleep=sleep * 5), lambda: MoveToRightEnv(size=size, sleep=sleep * 7), ] test_cls = [SubprocVectorEnv, ShmemVectorEnv] if has_ray(): test_cls += [RayVectorEnv] total_pass = 0 for cls in 
test_cls: pass_check = 1 v = cls(env_fns, wait_num=num - 1, timeout=timeout) t = time.time() v.reset() t = time.time() - t print(f"{cls} reset {t}") if t > sleep * 9: # huge than maximum sleep time (7 sleep) pass_check = 0 expect_result = [ [0, 1], [0, 1, 2], [0, 1, 3], [0, 1, 2], [0, 1], [0, 2, 3], [0, 1], ] ids = np.arange(num) for res in expect_result: t = time.time() _, _, _, _, info = v.step([1] * len(ids), ids) t = time.time() - t ids = Batch(info).env_id print(ids, t) if not ( len(ids) == len(res) and np.allclose(sorted(ids), res) and (t < timeout) == (len(res) == num - 1) ): pass_check = 0 break total_pass += pass_check if sys.platform == "linux": # Windows/macOS may not pass this check assert total_pass >= 2 def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None: env_fns = [ lambda i=i: MoveToRightEnv(size=i, sleep=sleep, recurse_state=True) for i in range(size, size + num) ] venv = [ DummyVectorEnv(env_fns), SubprocVectorEnv(env_fns), ShmemVectorEnv(env_fns), ] if has_ray() and sys.platform == "linux": venv += [RayVectorEnv(env_fns)] for v in venv: v.seed(0) action_list = [1] * 5 + [0] * 10 + [1] * 20 for a in action_list: o = [] for v in venv: A, B, C, D, E = v.step(np.array([a] * num)) if sum(C + D): A, _ = v.reset(np.where(C + D)[0]) o.append([A, B, C, D, E]) for index, infos in enumerate(zip(*o, strict=True)): if index == 4: # do not check info here continue for info in infos: assert recurse_comp(infos[0], info) def assert_get(v: BaseVectorEnv, expected: list) -> None: assert v.get_env_attr("size") == expected assert v.get_env_attr("size", id=0) == [expected[0]] assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3] for v in venv: assert_get(v, list(range(size, size + num))) assert v.env_num == num assert v.action_space == [Discrete(2)] * num v.set_env_attr("size", 0) assert_get(v, [0] * num) v.set_env_attr("size", 1, 0) assert_get(v, [1] + [0] * (num - 1)) v.set_env_attr("size", 2, [1, 2, 3]) assert_get(v, [1] + [2] * 3 + 
[0] * (num - 4)) for v in venv: v.close() def test_attr_unwrapped() -> None: training_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")]) training_envs.set_env_attr("test_attribute", 1337) assert training_envs.get_env_attr("test_attribute") == [1337] assert hasattr(training_envs.workers[0].env.unwrapped, "test_attribute") # type: ignore def test_env_obs_dtype() -> None: def create_env(i: int, t: str) -> Callable[[], NXEnv]: return lambda: NXEnv(i, t) for obs_type in ["array", "object"]: envs = SubprocVectorEnv([create_env(x, obs_type) for x in [5, 10, 15, 20]]) obs, info = envs.reset() assert obs.dtype == object obs = envs.step(np.array([1, 1, 1, 1]))[0] assert obs.dtype == object def test_env_reset_optional_kwargs(size: int = 10000, num: int = 8) -> None: env_fns = [lambda i=i: MoveToRightEnv(size=i) for i in range(size, size + num)] test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] if has_ray(): test_cls += [RayVectorEnv] for cls in test_cls: v = cls(env_fns, wait_num=num // 2, timeout=1e-3) _, info = v.reset(seed=1) assert len(info) == len(env_fns) assert isinstance(info[0], dict) def test_venv_wrapper_gym(num_envs: int = 4) -> None: # Issue 697 envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)]) envs = VectorEnvNormObs(envs) try: obs, info = envs.reset() except ValueError: obs, info = envs.reset(return_info=True) assert isinstance(obs, np.ndarray) assert isinstance(info, np.ndarray) assert isinstance(info[0], dict) assert obs.shape[0] == len(info) == num_envs def run_align_norm_obs( raw_env: DummyVectorEnv, train_env: VectorEnvNormObs, test_env: VectorEnvNormObs, action_list: list[np.ndarray], ) -> None: def reset_result_to_obs( reset_result: tuple[np.ndarray, dict | list[dict]], ) -> np.ndarray: """Extract observation from reset result (result is possibly a tuple containing info).""" if isinstance(reset_result, tuple) and len(reset_result) == 2: obs, _ = reset_result else: obs = reset_result # type: ignore return 
obs eps = np.finfo(np.float32).eps.item() raw_reset_result = raw_env.reset() train_reset_result = train_env.reset() initial_raw_obs = reset_result_to_obs(raw_reset_result) # type: ignore initial_train_obs = reset_result_to_obs(train_reset_result) # type: ignore raw_obs, train_obs = [initial_raw_obs], [initial_train_obs] for action in action_list: step_result = raw_env.step(action) if len(step_result) == 5: obs, rew, terminated, truncated, info = step_result done = np.logical_or(terminated, truncated) else: obs, rew, done, info = step_result # type: ignore raw_obs.append(obs) if np.any(done): reset_result = raw_env.reset(np.where(done)[0]) obs = reset_result_to_obs(reset_result) # type: ignore raw_obs.append(obs) step_result = train_env.step(action) if len(step_result) == 5: obs, rew, terminated, truncated, info = step_result done = np.logical_or(terminated, truncated) else: obs, rew, done, info = step_result # type: ignore train_obs.append(obs) if np.any(done): reset_result = train_env.reset(np.where(done)[0]) obs = reset_result_to_obs(reset_result) # type: ignore train_obs.append(obs) ref_rms = RunningMeanStd() for ro, to in zip(raw_obs, train_obs, strict=True): ref_rms.update(ro) no = (ro - ref_rms.mean) / np.sqrt(ref_rms.var + eps) assert np.allclose(no, to) assert np.allclose(ref_rms.mean, train_env.get_obs_rms().mean) assert np.allclose(ref_rms.var, train_env.get_obs_rms().var) assert np.allclose(ref_rms.mean, test_env.get_obs_rms().mean) assert np.allclose(ref_rms.var, test_env.get_obs_rms().var) reset_result = test_env.reset() obs = reset_result_to_obs(reset_result) # type: ignore test_obs = [obs] for action in action_list: step_result = test_env.step(action) if len(step_result) == 5: obs, rew, terminated, truncated, info = step_result done = np.logical_or(terminated, truncated) else: obs, rew, done, info = step_result # type: ignore test_obs.append(obs) if np.any(done): reset_result = test_env.reset(np.where(done)[0]) obs = reset_result_to_obs(reset_result) 
# type: ignore test_obs.append(obs) for ro, to in zip(raw_obs, test_obs, strict=True): no = (ro - ref_rms.mean) / np.sqrt(ref_rms.var + eps) assert np.allclose(no, to) def test_venv_norm_obs() -> None: sizes = np.array([5, 10, 15, 20]) action = np.array([1, 1, 1, 1]) total_step = 30 action_list = [action] * total_step env_fns = [lambda i=x: MoveToRightEnv(size=i, array_state=True) for x in sizes] raw = DummyVectorEnv(env_fns) train_env = VectorEnvNormObs(DummyVectorEnv(env_fns)) print(train_env.observation_space) test_env = VectorEnvNormObs(DummyVectorEnv(env_fns), update_obs_rms=False) test_env.set_obs_rms(train_env.get_obs_rms()) run_align_norm_obs(raw, train_env, test_env, action_list) def test_gym_wrappers() -> None: class DummyEnv(gym.Env): def __init__(self) -> None: self.action_space = gym.spaces.Box(low=-1.0, high=2.0, shape=(4,), dtype=np.float32) self.observation_space = gym.spaces.Discrete(2) def step(self, act: Any) -> tuple[Any, Literal[-1], Literal[False], Literal[True], dict]: return self.observation_space.sample(), -1, False, True, {} bsz = 10 action_per_branch = [4, 6, 10, 7] env = DummyEnv() assert isinstance(env.action_space, gym.spaces.Box) original_act = env.action_space.high # convert continous to multidiscrete action space # with different action number per dimension env_m = ContinuousToDiscrete(env, action_per_branch) assert isinstance(env_m.action_space, gym.spaces.MultiDiscrete) # check conversion is working properly for one action np.testing.assert_allclose(env_m.action(env_m.action_space.nvec - 1), original_act) # check conversion is working properly for a batch of actions np.testing.assert_allclose( env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)), np.array([original_act] * bsz), ) # convert multidiscrete with different action number per # dimension to discrete action space env_d = MultiDiscreteToDiscrete(env_m) assert isinstance(env_d.action_space, gym.spaces.Discrete) # check conversion is working properly for one action 
np.testing.assert_allclose( env_d.action(np.array(env_d.action_space.n - 1)), env_m.action_space.nvec - 1, ) # check conversion is working properly for a batch of actions np.testing.assert_allclose( env_d.action(np.array([env_d.action_space.n - 1] * bsz)), np.array([env_m.action_space.nvec - 1] * bsz), ) # check truncate is True when terminated try: env_t = TruncatedAsTerminated(env) except OSError: env_t = None if env_t is not None: _, _, truncated, _, _ = env_t.step(env_t.action_space.sample()) assert truncated # TODO: old gym envs are no longer supported! Replace by Ant-v4 and fix assoticiated tests @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_venv_wrapper_envpool() -> None: raw = envpool.make_gymnasium("Ant-v3", num_envs=4) train = VectorEnvNormObs(envpool.make_gymnasium("Ant-v3", num_envs=4)) test = VectorEnvNormObs(envpool.make_gymnasium("Ant-v3", num_envs=4), update_obs_rms=False) test.set_obs_rms(train.get_obs_rms()) actions = [np.array([raw.action_space.sample() for _ in range(4)]) for i in range(30)] run_align_norm_obs(raw, train, test, actions) @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_venv_wrapper_envpool_gym_reset_return_info() -> None: num_envs = 4 env = VectorEnvNormObs( envpool.make_gymnasium("Ant-v3", num_envs=num_envs, gym_reset_return_info=True), ) obs, info = env.reset() assert obs.shape[0] == num_envs # This is not actually unreachable b/c envpool does not return info in the right format if isinstance(info, dict): # type: ignore[unreachable] for _, v in info.items(): # type: ignore[unreachable] if not isinstance(v, dict): assert v.shape[0] == num_envs else: for _info in info: for _, v in _info.items(): if not isinstance(v, dict): assert v.shape[0] == num_envs ================================================ FILE: test/base/test_env_finite.py ================================================ # see issue #322 for detail import copy from collections 
import Counter from collections.abc import Callable, Iterator, Sequence from typing import Any, cast import gymnasium as gym import numpy as np import torch from gymnasium.spaces import Box from torch.utils.data import DataLoader, Dataset, DistributedSampler from tianshou.algorithm.algorithm_base import Policy from tianshou.data import Batch, Collector, CollectStats from tianshou.data.types import ( ActBatchProtocol, BatchProtocol, ObsBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type class DummyDataset(Dataset): def __init__(self, length: int) -> None: self.length = length self.episodes = [3 * i % 5 + 1 for i in range(self.length)] def __getitem__(self, index: int) -> tuple[int, int]: assert 0 <= index < self.length return index, self.episodes[index] def __len__(self) -> int: return self.length class FiniteEnv(gym.Env): def __init__(self, dataset: Dataset, num_replicas: int | None, rank: int | None) -> None: self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.loader = DataLoader( dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None, ) self.iterator: Iterator | None = None def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None, ) -> tuple[Any, dict[str, Any]]: if self.iterator is None: self.iterator = iter(self.loader) try: self.current_sample, self.step_count = next(self.iterator) self.current_step = 0 return self.current_sample, {} except StopIteration: self.iterator = None return None, {} def step(self, action: int) -> tuple[int, float, bool, bool, dict[str, Any]]: self.current_step += 1 assert self.current_step <= self.step_count return ( 0, 1.0, self.current_step >= self.step_count, False, {"sample": self.current_sample, "action": action, "metric": 2.0}, ) class FiniteVectorEnv(BaseVectorEnv): def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None: 
super().__init__(env_fns, **kwargs) self._alive_env_ids: set[int] = set() self._reset_alive_envs() self._default_obs: np.ndarray | None = None self._default_info: dict | None = None self.tracker: MetricTracker def _reset_alive_envs(self) -> None: if not self._alive_env_ids: # starting or running out self._alive_env_ids = set(range(self.env_num)) # to workaround with tianshou's buffer and batch def _set_default_obs(self, obs: np.ndarray) -> None: if obs is not None and self._default_obs is None: self._default_obs = copy.deepcopy(obs) def _set_default_info(self, info: dict) -> None: if info is not None and self._default_info is None: self._default_info = copy.deepcopy(info) def _get_default_obs(self) -> np.ndarray | None: return copy.deepcopy(self._default_obs) def _get_default_info(self) -> dict | None: return copy.deepcopy(self._default_info) # END def reset( self, env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, ) -> tuple[np.ndarray, np.ndarray]: env_id = self._wrap_id(env_id) self._reset_alive_envs() # ask super to reset alive envs and remap to current index request_id = list(filter(lambda i: i in self._alive_env_ids, env_id)) obs_list: list[np.ndarray | None] = [None] * len(env_id) infos: list[dict | None] = [None] * len(env_id) id2idx = {i: k for k, i in enumerate(env_id)} if request_id: for k, o, info in zip(request_id, *super().reset(request_id), strict=True): obs_list[id2idx[k]] = o infos[id2idx[k]] = info for i, o in zip(env_id, obs_list, strict=True): if o is None and i in self._alive_env_ids: self._alive_env_ids.remove(i) # fill empty observation with default(fake) observation for o in obs_list: self._set_default_obs(o) for i in range(len(obs_list)): if obs_list[i] is None: obs_list[i] = self._get_default_obs() if infos[i] is None: infos[i] = self._get_default_info() if not self._alive_env_ids: self.reset() raise StopIteration obs_list = cast(list[np.ndarray], obs_list) infos = cast(list[dict], infos) return np.stack(obs_list), 
np.array(infos) def step( self, action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: ids: list[int] | np.ndarray = self._wrap_id(id) id2idx = {i: k for k, i in enumerate(ids)} request_id = list(filter(lambda i: i in self._alive_env_ids, ids)) result: list[list] = [[None, 0.0, False, False, None] for _ in range(len(ids))] # ask super to step alive envs and remap to current index assert action is not None if request_id: valid_act = np.stack([action[id2idx[i]] for i in request_id]) for i, (r_obs, r_reward, r_term, r_trunc, r_info) in zip( request_id, zip(*super().step(valid_act, request_id), strict=True), strict=True, ): result[id2idx[i]] = [r_obs, r_reward, r_term, r_trunc, r_info] # logging for i, r in zip(ids, result, strict=True): if i in self._alive_env_ids: self.tracker.log(*r) # fill empty observation/info with default(fake) for _, __, ___, ____, i in result: self._set_default_info(i) for i in range(len(result)): if result[i][0] is None: result[i][0] = self._get_default_obs() if result[i][-1] is None: result[i][-1] = self._get_default_info() obs_list, rew_list, term_list, trunc_list, info_list = zip(*result, strict=True) try: obs_stack = np.stack(obs_list) except ValueError: # different len(obs) obs_stack = np.array(obs_list, dtype=object) return ( obs_stack, np.stack(rew_list), np.stack(term_list), np.stack(trunc_list), np.stack(info_list), ) class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv): pass class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): pass class DummyPolicy(Policy): def __init__(self) -> None: super().__init__(action_space=Box(-1, 1, (1,))) def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ActBatchProtocol: return cast(ActBatchProtocol, Batch(act=np.stack([1] * len(batch)))) def _finite_env_factory(dataset: Dataset, num_replicas: int, rank: int) -> Callable[[], FiniteEnv]: return lambda: 
FiniteEnv(dataset, num_replicas, rank) class MetricTracker: def __init__(self) -> None: self.counter: Counter = Counter() self.finished: set[int] = set() def log(self, obs: Any, rew: float, terminated: bool, truncated: bool, info: dict) -> None: assert rew == 1.0 done = terminated or truncated index = info["sample"] if done: assert index not in self.finished self.finished.add(index) self.counter[index] += 1 def validate(self) -> None: assert len(self.finished) == 100 for k, v in self.counter.items(): assert v == k * 3 % 5 + 1 def test_finite_dummy_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = DummyPolicy() test_collector = Collector[CollectStats](policy, envs, exploration_noise=True) test_collector.reset() for _ in range(3): envs.tracker = MetricTracker() try: # TODO: why on earth 10**18? test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() def test_finite_subproc_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = DummyPolicy() test_collector = Collector[CollectStats](policy, envs, exploration_noise=True) test_collector.reset() for _ in range(3): envs.tracker = MetricTracker() try: test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() ================================================ FILE: test/base/test_logger.py ================================================ from typing import Literal import numpy as np import pytest from torch.utils.tensorboard import SummaryWriter from tianshou.utils import TensorboardLogger class TestTensorBoardLogger: @staticmethod @pytest.mark.parametrize( "input_dict, expected_output", [ ({"a": 1, "b": {"c": 2, "d": {"e": 3}}}, {"a": 1, "b/c": 2, "b/d/e": 3}), ({"a": {"b": {"c": 1}}}, {"a/b/c": 1}), ], ) def test_flatten_dict_basic( input_dict: dict[str, int | dict[str, int | dict[str, 
int]]] | dict[str, dict[str, dict[str, int]]], expected_output: dict[str, int], ) -> None: logger = TensorboardLogger(SummaryWriter("log/logger")) result = logger.prepare_dict_for_logging(input_dict) assert result == expected_output @staticmethod @pytest.mark.parametrize( "input_dict, delimiter, expected_output", [ ({"a": {"b": {"c": 1}}}, "|", {"a|b|c": 1}), ({"a": {"b": {"c": 1}}}, ".", {"a.b.c": 1}), ], ) def test_flatten_dict_custom_delimiter( input_dict: dict[str, dict[str, dict[str, int]]], delimiter: Literal["|", "."], expected_output: dict[str, int], ) -> None: logger = TensorboardLogger(SummaryWriter("log/logger")) result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter) assert result == expected_output @staticmethod @pytest.mark.parametrize( "input_dict, exclude_arrays, expected_output", [ ( {"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}}, False, {"a": np.array([1, 2, 3]), "b/c": np.array([4, 5, 6])}, ), ({"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}}, True, {}), ], ) def test_flatten_dict_exclude_arrays( input_dict: dict[str, np.ndarray | dict[str, np.ndarray]], exclude_arrays: bool, expected_output: dict[str, np.ndarray], ) -> None: logger = TensorboardLogger(SummaryWriter("log/logger")) result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays) assert result.keys() == expected_output.keys() for val1, val2 in zip(result.values(), expected_output.values(), strict=True): assert np.all(val1 == val2) @staticmethod @pytest.mark.parametrize( "input_dict, expected_output", [ ({"a": (1,), "b": {"c": "2", "d": {"e": 3}}}, {"b/d/e": 3}), ], ) def test_flatten_dict_invalid_values_filtered_out( input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]], expected_output: dict[str, int], ) -> None: logger = TensorboardLogger(SummaryWriter("log/logger")) result = logger.prepare_dict_for_logging(input_dict) assert result == expected_output ================================================ 
FILE: test/base/test_policy.py ================================================ import gymnasium as gym import numpy as np import pytest import torch from torch.distributions import Categorical, Distribution, Independent, Normal from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import ( RandomActionPolicy, episode_mc_return_to_go, ) from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Batch from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.net.discrete import DiscreteActor obs_shape = (5,) def _to_hashable(x: np.ndarray | int) -> int | tuple[list]: return x if isinstance(x, int) else tuple(x.tolist()) def test_calculate_discounted_returns() -> None: assert np.all( episode_mc_return_to_go([1, 1, 1], 0.9) == np.array([0.9**2 + 0.9 + 1, 0.9 + 1, 1]), ) assert episode_mc_return_to_go([1, 2, 3], 0.5)[0] == 1 + 0.5 * (2 + 0.5 * 3) @pytest.fixture(params=["continuous", "discrete"]) def algorithm(request: pytest.FixtureRequest) -> PPO: action_type = request.param action_space: gym.spaces.Box | gym.spaces.Discrete actor: DiscreteActor | ContinuousActorProbabilistic if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) actor = ContinuousActorProbabilistic( preprocess_net=Net( state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape, ), action_shape=action_space.shape, ) def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) elif action_type == "discrete": action_space = gym.spaces.Discrete(3) actor = DiscreteActor( preprocess_net=Net( state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n, ), action_shape=action_space.n, ) dist_fn = Categorical else: raise ValueError(f"Unknown action type: 
{action_type}") critic = ContinuousCritic( preprocess_net=Net(state_shape=obs_shape, hidden_sizes=[64, 64]), ) optim = AdamOptimizerFactory(lr=1e-3) algorithm: PPO policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist_fn, action_space=action_space, action_scaling=False, ) algorithm = PPO( policy=policy, critic=critic, optim=optim, ) algorithm.eval() return algorithm class TestPolicyBasics: def test_get_action(self, algorithm: PPO) -> None: policy = algorithm.policy policy.is_within_training_step = False sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False actions = [policy.compute_action(sample_obs) for _ in range(10)] assert all(policy.action_space.contains(a) for a in actions) # check that the actions are different in non-deterministic mode assert len(set(map(_to_hashable, actions))) > 1 policy.deterministic_eval = True actions = [policy.compute_action(sample_obs) for _ in range(10)] # check that the actions are the same in deterministic mode assert len(set(map(_to_hashable, actions))) == 1 @staticmethod def test_random_policy_discrete_actions() -> None: action_space = gym.spaces.Discrete(3) policy = RandomActionPolicy(action_space=action_space) # forward of actor returns discrete probabilities, in compliance with the overall discrete actor action_probs = policy.actor(np.zeros((10, 2)))[0] assert np.allclose(action_probs, 1 / 3 * np.ones((10, 3))) actions = [] for _ in range(10): action = policy.compute_action(np.array([0])) assert action_space.contains(action) actions.append(action) # not all actions are the same assert len(set(actions)) > 1 # test batched forward action_batch = policy(Batch(obs=np.zeros((10, 2)))) assert action_batch.act.shape == (10,) assert len(set(action_batch.act.tolist())) > 1 @staticmethod def test_random_policy_continuous_actions() -> None: action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) policy = RandomActionPolicy(action_space=action_space) actions = [] for _ in range(10): action = 
policy.compute_action(np.array([0])) assert action_space.contains(action) actions.append(action) # not all actions are the same assert len(set(map(_to_hashable, actions))) > 1 # test batched forward action_batch = policy(Batch(obs=np.zeros((10, 2)))) assert action_batch.act.shape == (10, 3) assert len(set(map(_to_hashable, action_batch.act))) > 1 ================================================ FILE: test/base/test_returns.py ================================================ from typing import cast import numpy as np import torch from tianshou.algorithm import Algorithm from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import RolloutBatchProtocol def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: returns = np.zeros_like(batch.rew) last = 0 for i in reversed(range(len(batch.rew))): returns[i] = batch.rew[i] if not batch.done[i]: returns[i] += last * gamma last = returns[i] batch.returns = returns return batch def test_episodic_returns(size: int = 2560) -> None: fn = Algorithm.compute_episodic_return buf = ReplayBuffer(20) batch = cast( RolloutBatchProtocol, Batch( terminated=np.array([1, 0, 0, 1, 0, 0, 0, 1.0]), truncated=np.array([0, 0, 0, 0, 0, 1, 0, 0]), rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.0]), info=Batch( { "TimeLimit.truncated": np.array( [False, False, False, False, False, True, False, False], ), }, ), ), ) for b in iter(batch): b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) assert np.allclose(returns, ans) buf.reset() batch = cast( RolloutBatchProtocol, Batch( terminated=np.array([0, 1, 0, 1, 0, 1, 0.0]), truncated=np.array([0, 0, 0, 0, 0, 0, 0.0]), rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), ), ) for b in iter(batch): b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 
3.4, 4, 5]) assert np.allclose(returns, ans) buf.reset() batch = cast( RolloutBatchProtocol, Batch( terminated=np.array([0, 1, 0, 1, 0, 0, 1.0]), truncated=np.array([0, 0, 0, 0, 0, 0, 0]), rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), ), ) for b in iter(batch): b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) assert np.allclose(returns, ans) buf.reset() batch = cast( RolloutBatchProtocol, Batch( terminated=np.array([0, 0, 0, 1.0, 0, 0, 0, 1, 0, 0, 0, 1]), truncated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), ), ) for b in batch: b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) ground_truth = np.array( [ 454.8344, 376.1143, 291.298, 200.0, 464.5610, 383.1085, 295.387, 201.0, 474.2876, 390.1027, 299.476, 202.0, ], ) assert np.allclose(returns, ground_truth) buf.reset() batch = cast( RolloutBatchProtocol, Batch( terminated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), truncated=np.array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]), rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), info=Batch( { "TimeLimit.truncated": np.array( [ False, False, False, True, False, False, False, True, False, False, False, False, ], ), }, ), ), ) for b in iter(batch): b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) ground_truth = np.array( [ 454.0109, 375.2386, 290.3669, 199.01, 462.9138, 381.3571, 293.5248, 199.02, 474.2876, 390.1027, 299.476, 202.0, ], ) assert np.allclose(returns, ground_truth) def target_q_fn(buffer: ReplayBuffer, indices: np.ndarray) 
-> torch.Tensor: # return the next reward indices = buffer.next(indices) return torch.tensor(-buffer.rew[indices], dtype=torch.float32) def target_q_fn_multidim(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return target_q_fn(buffer, indices).unsqueeze(1).repeat(1, 51) def compute_nstep_return_base( nstep: int, gamma: float, buffer: ReplayBuffer, indices: np.ndarray, ) -> np.ndarray: returns = np.zeros_like(indices, dtype=float) buf_len = len(buffer) for i in range(len(indices)): flag, rew = False, 0.0 real_step_n = nstep for n in range(nstep): idx = (indices[i] + n) % buf_len rew += buffer.rew[idx] * gamma**n if buffer.done[idx]: if not (hasattr(buffer, "info") and buffer.info["TimeLimit.truncated"][idx]): flag = True real_step_n = n + 1 break if not flag: idx = (indices[i] + real_step_n - 1) % buf_len rew += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n returns[i] = rew return returns def test_nstep_returns(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( cast( RolloutBatchProtocol, Batch( obs=0, act=0, rew=i + 1, terminated=i % 4 == 3, truncated=False, ), ), ) batch, indices = buf.sample(0) assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy( Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) .pop("returns") .reshape(-1), ) assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, 0.1, buf, indices) assert np.allclose(returns, r_), (r_, returns) returns_multidim = to_numpy( Algorithm.compute_nstep_return( batch, buf, indices, target_q_fn_multidim, gamma=0.1, n_step=1, ).pop("returns"), ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy( Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) .pop("returns") .reshape(-1), ) 
assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( Algorithm.compute_nstep_return( batch, buf, indices, target_q_fn_multidim, gamma=0.1, n_step=2, ).pop("returns"), ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 returns = to_numpy( Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) .pop("returns") .reshape(-1), ) assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) r_ = compute_nstep_return_base(10, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( Algorithm.compute_nstep_return( batch, buf, indices, target_q_fn_multidim, gamma=0.1, n_step=10, ).pop("returns"), ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) def test_nstep_returns_with_timelimit(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( cast( RolloutBatchProtocol, Batch( obs=0, act=0, rew=i + 1, terminated=i % 4 == 3 and i != 3, truncated=i == 3, info={"TimeLimit.truncated": i == 3}, ), ), ) batch, indices = buf.sample(0) assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy( Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) .pop("returns") .reshape(-1), ) assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, 0.1, buf, indices) assert np.allclose(returns, r_), (r_, returns) returns_multidim = to_numpy( Algorithm.compute_nstep_return( batch, buf, indices, target_q_fn_multidim, gamma=0.1, n_step=1, ).pop("returns"), ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy( Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, 
gamma=0.1, n_step=2) .pop("returns") .reshape(-1), ) assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( Algorithm.compute_nstep_return( batch, buf, indices, target_q_fn_multidim, gamma=0.1, n_step=2, ).pop("returns"), ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 returns = to_numpy( Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) .pop("returns") .reshape(-1), ) assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) r_ = compute_nstep_return_base(10, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( Algorithm.compute_nstep_return( batch, buf, indices, target_q_fn_multidim, gamma=0.1, n_step=10, ).pop("returns"), ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) ================================================ FILE: test/base/test_stats.py ================================================ from typing import cast import numpy as np import pytest import torch from torch.distributions import Categorical, Normal from tianshou.algorithm.algorithm_base import TrainingStats, TrainingStatsWrapper from tianshou.data import Batch, CollectStats from tianshou.data.collector import CollectStepBatchProtocol, get_stddev_from_dist class DummyTrainingStatsWrapper(TrainingStatsWrapper): def __init__(self, wrapped_stats: TrainingStats, *, dummy_field: int) -> None: self.dummy_field = dummy_field super().__init__(wrapped_stats) class TestStats: @staticmethod def test_training_stats_wrapper() -> None: train_stats = TrainingStats(train_time=1.0) setattr(train_stats, "loss_field", 12) # noqa: B010 wrapped_train_stats = DummyTrainingStatsWrapper(train_stats, dummy_field=42) # basic readout assert wrapped_train_stats.train_time == 1.0 assert wrapped_train_stats.loss_field == 12 # mutation of 
TrainingStats fields wrapped_train_stats.train_time = 2.0 wrapped_train_stats.smoothed_loss["foo"] = 50 assert wrapped_train_stats.train_time == 2.0 assert wrapped_train_stats.smoothed_loss["foo"] == 50 # loss stats dict assert wrapped_train_stats.get_loss_stats_dict() == { "loss_field": 12, "dummy_field": 42, } # new fields can't be added with pytest.raises(AttributeError): wrapped_train_stats.new_loss_field = 90 # existing fields, wrapped and not-wrapped, can be mutated wrapped_train_stats.loss_field = 13 wrapped_train_stats.dummy_field = 43 assert hasattr( wrapped_train_stats.wrapped_stats, "loss_field", ), "Attribute `loss_field` not found in `wrapped_train_stats.wrapped_stats`." assert hasattr( wrapped_train_stats, "loss_field", ), "Attribute `loss_field` not found in `wrapped_train_stats`." assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13 @staticmethod @pytest.mark.parametrize( "act,dist", ( (np.array(1), Categorical(probs=torch.tensor([0.5, 0.5]))), (np.array([1, 2, 3]), Normal(torch.zeros(3), torch.ones(3))), ), ) def test_collect_stats_update_at_step( act: np.ndarray, dist: torch.distributions.Distribution, ) -> None: step_batch = cast( CollectStepBatchProtocol, Batch( info={}, obs=np.array([1, 2, 3]), obs_next=np.array([4, 5, 6]), act=act, rew=np.array(1.0), done=np.array(False), terminated=np.array(False), dist=dist, ).to_at_least_2d(), ) stats = CollectStats() for _ in range(10): stats.update_at_step_batch(step_batch) stats.refresh_all_sequence_stats() assert stats.n_collected_steps == 10 assert stats.pred_dist_std_array is not None assert np.allclose(stats.pred_dist_std_array, get_stddev_from_dist(dist)) assert stats.pred_dist_std_array_stat is not None assert stats.pred_dist_std_array_stat[0].mean == get_stddev_from_dist(dist)[0].item() ================================================ FILE: test/base/test_utils.py ================================================ from typing import cast import numpy as np 
import pytest import torch import torch.distributions as dist from gymnasium import spaces from torch import nn from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic from tianshou.utils.torch_utils import create_uniform_action_dist, torch_train_mode def test_noise() -> None: noise = GaussianNoise() size = (3, 4, 5) assert np.allclose(noise(size).shape, size) noise = OUNoise() noise.reset() assert np.allclose(noise(size).shape, size) def test_moving_average() -> None: stat = MovAvg(10) assert np.allclose(stat.get(), 0) assert np.allclose(stat.mean(), 0) assert np.allclose(stat.std() ** 2, 0) stat.add(torch.tensor([1])) stat.add(np.array([2])) stat.add([3, 4]) stat.add(5.0) assert np.allclose(stat.get(), 3) assert np.allclose(stat.mean(), 3) assert np.allclose(stat.std() ** 2, 2) def test_rms() -> None: rms = RunningMeanStd() assert np.allclose(rms.mean, 0) assert np.allclose(rms.var, 1) rms.update(np.array([[[1, 2], [3, 5]]])) rms.update(np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]])) assert np.allclose(rms.mean, np.array([[1, 2], [2, 3]]), atol=1e-3) assert np.allclose(rms.var, np.array([[0, 0], [2, 14 / 3.0]]), atol=1e-3) def test_net() -> None: # here test the networks that does not appear in the other script bsz = 64 # MLP data = torch.rand([bsz, 3]) mlp = MLP(input_dim=3, output_dim=6, hidden_sizes=[128]) assert list(mlp(data).shape) == [bsz, 6] # output == 0 and len(hidden_sizes) == 0 means identity model mlp = MLP(input_dim=6, output_dim=0) assert data.shape == mlp(data).shape # common net state_shape = (10, 2) action_shape = (5,) data = torch.rand([bsz, *state_shape]) expect_output_shape = [bsz, *action_shape] net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128], norm_layer=torch.nn.LayerNorm, activation=None, ) assert list(net(data)[0].shape) == 
expect_output_shape assert str(net).count("LayerNorm") == 2 assert str(net).count("ReLU") == 0 Q_param = V_param = {"hidden_sizes": [128, 128]} net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128], dueling_param=(Q_param, V_param), ) assert list(net(data)[0].shape) == expect_output_shape # concat net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], concat=True, ) data = torch.rand([bsz, int(np.prod(state_shape)) + int(np.prod(action_shape))]) expect_output_shape = [bsz, 128] assert list(net(data)[0].shape) == expect_output_shape net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], concat=True, dueling_param=(Q_param, V_param), ) assert list(net(data)[0].shape) == expect_output_shape # recurrent actor/critic data = torch.rand([bsz, *state_shape]).flatten(1) expect_output_shape = [bsz, *action_shape] net = RecurrentActorProb(layer_num=3, state_shape=state_shape, action_shape=action_shape) mu, sigma = net(data)[0] assert mu.shape == sigma.shape assert list(mu.shape) == [bsz, 5] net = RecurrentCritic(layer_num=3, state_shape=state_shape, action_shape=action_shape) data = torch.rand([bsz, 8, int(np.prod(state_shape))]) act = torch.rand(expect_output_shape) assert list(net(data, act).shape) == [bsz, 1] def test_in_eval_mode() -> None: module = nn.Linear(3, 4) module.train() with torch_train_mode(module, False): assert not module.training assert module.training def test_in_train_mode() -> None: module = nn.Linear(3, 4) module.eval() with torch_train_mode(module): assert module.training assert not module.training class TestCreateActionDistribution: @classmethod def setup_class(cls) -> None: # Set random seeds for reproducibility torch.manual_seed(0) np.random.seed(0) @pytest.mark.parametrize( "action_space, batch_size", [ (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 1), (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 5), (spaces.Discrete(5), 1), (spaces.Discrete(5), 5), ], ) def 
test_distribution_properties( self, action_space: spaces.Box | spaces.Discrete, batch_size: int, ) -> None: distribution = create_uniform_action_dist(action_space, batch_size) # Correct distribution type if isinstance(action_space, spaces.Box): assert isinstance(distribution, dist.Uniform) elif isinstance(action_space, spaces.Discrete): assert isinstance(distribution, dist.Categorical) # Samples are within correct range samples = distribution.sample() if isinstance(action_space, spaces.Box): low = torch.tensor(action_space.low, dtype=torch.float32) high = torch.tensor(action_space.high, dtype=torch.float32) assert torch.all(samples >= low) assert torch.all(samples <= high) elif isinstance(action_space, spaces.Discrete): assert torch.all(samples >= 0) assert torch.all(samples < action_space.n) @pytest.mark.parametrize( "action_space, batch_size", [ (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 1), (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 5), (spaces.Discrete(5), 1), (spaces.Discrete(5), 5), ], ) def test_distribution_uniformity( self, action_space: spaces.Box | spaces.Discrete, batch_size: int, ) -> None: distribution = create_uniform_action_dist(action_space, batch_size) # Test 7: Uniform distribution (statistical test) large_sample = distribution.sample(torch.Size((10000,))) if isinstance(action_space, spaces.Box): # For Box, check if mean is close to 0 and std is close to 1/sqrt(3) assert torch.allclose(large_sample.mean(), torch.tensor(0.0), atol=0.1) assert torch.allclose(large_sample.std(), torch.tensor(1 / 3**0.5), atol=0.1) elif isinstance(action_space, spaces.Discrete): # For Discrete, check if all actions are roughly equally likely n_actions = cast(int, action_space.n) counts = torch.bincount(large_sample.flatten(), minlength=n_actions).float() expected_count = 10000 * batch_size / n_actions assert torch.allclose(counts, torch.tensor(expected_count).float(), rtol=0.1) def test_unsupported_space(self) -> None: # Test 6: Raises ValueError for unsupported 
space with pytest.raises(ValueError): create_uniform_action_dist(spaces.MultiBinary(5)) # type: ignore @pytest.mark.parametrize( "space, batch_size, expected_shape, distribution_type", [ (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 1, (1, 3), dist.Uniform), (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 5, (5, 3), dist.Uniform), (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 10, (10, 3), dist.Uniform), (spaces.Discrete(5), 1, (1,), dist.Categorical), (spaces.Discrete(5), 5, (5,), dist.Categorical), (spaces.Discrete(5), 10, (10,), dist.Categorical), ], ) def test_batch_sizes( self, space: spaces.Box | spaces.Discrete, batch_size: int, expected_shape: tuple[int, ...], distribution_type: type[dist.Distribution], ) -> None: distribution = create_uniform_action_dist(space, batch_size) # Check distribution type assert isinstance(distribution, distribution_type) # Check sample shape samples = distribution.sample() assert samples.shape == expected_shape # Check internal distribution shapes if isinstance(space, spaces.Box): distribution = cast(dist.Uniform, distribution) assert distribution.low.shape == expected_shape assert distribution.high.shape == expected_shape elif isinstance(space, spaces.Discrete): distribution = cast(dist.Categorical, distribution) assert distribution.probs.shape == (batch_size, space.n) ================================================ FILE: test/continuous/__init__.py ================================================ ================================================ FILE: test/continuous/test_ddpg.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import DDPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import 
AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--actor_lr", type=float, default=1e-4) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--exploration_noise", type=float, default=0.1) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=20000) parser.add_argument("--collection_step_num_env_steps", type=int, default=8) parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_training_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--return_scaling", action="store_true", default=False) parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_ddpg(args: 
argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic = ContinuousCritic(preprocess_net=net).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, ) policy_optim = AdamOptimizerFactory(lr=args.actor_lr) algorithm: DDPG = DDPG( policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, n_step_return_horizon=args.n_step, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ddpg") writer = SummaryWriter(log_path) 
logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_ddpg_determinism() -> None: main_fn = lambda args: test_ddpg(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_ddpg", main_fn, get_args()).run() ================================================ FILE: test/continuous/test_npg.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from torch import nn from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import NPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> 
argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=50000) parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) parser.add_argument( "--update_step_num_repetitions", type=int, default=2 ) # theoretically it should be 1 parser.add_argument("--batch_size", type=int, default=99999) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) # npg special parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--optim_critic_iters", type=int, default=5) parser.add_argument("--trust_region_size", type=float, default=0.5) return parser.parse_known_args()[0] def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} 
args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( preprocess_net=Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ), ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, deterministic_eval=True, ) algorithm: NPG = NPG( policy=policy, critic=critic, optim=AdamOptimizerFactory(lr=args.lr), gamma=args.gamma, return_scaling=args.return_scaling, advantage_normalization=args.advantage_normalization, gae_lambda=args.gae_lambda, optim_critic_iters=args.optim_critic_iters, trust_region_size=args.trust_region_size, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), ) test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "npg") writer = SummaryWriter(log_path) logger = 
TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_npg_determinism() -> None: main_fn = lambda args: test_npg(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_npg", main_fn, get_args()).run() ================================================ FILE: test/continuous/test_ppo.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = 
argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=150000) parser.add_argument("--collection_step_num_episodes", type=int, default=16) parser.add_argument("--update_step_num_repetitions", type=int, default=2) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special parser.add_argument("--vf_coef", type=float, default=0.25) parser.add_argument("--ent_coef", type=float, default=0.0) parser.add_argument("--eps_clip", type=float, default=0.2) parser.add_argument("--max_grad_norm", type=float, default=0.5) parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--dual_clip", type=float, default=None) parser.add_argument("--value_clip", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--resume", action="store_true") parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: """ Train PPO on ``args.task`` and, if ``enable_assertions``, assert that the best test reward reaches ``args.reward_threshold``. If ``args.resume`` is set, the algorithm state is first restored from ``<log_path>/checkpoint.pth`` when that file exists. Note: ``args`` is mutated in place (``state_shape``, ``action_shape``, ``max_action`` and, if unset, ``reward_threshold``); the default ``args`` is created once, when the function is defined. """ env
= gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), ).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) optim = AdamOptimizerFactory(lr=args.lr) # replace DiagGaussian with Independent(Normal), which is equivalent # dist_fn receives a (loc, scale) tuple (presumably the actor's output) and unpacks it def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: """ Build a diagonal Gaussian: Normal(loc, scale) with the last dimension reinterpreted as the event dimension. """ loc, scale = loc_scale return Independent(Normal(loc, scale), 1) policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, ) algorithm: PPO = PPO( policy=policy, critic=critic, optim=optim, gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, return_scaling=args.return_scaling, advantage_normalization=args.advantage_normalization,
recompute_advantage=args.recompute_adv, dual_clip=args.dual_clip, value_clip=args.value_clip, gae_lambda=args.gae_lambda, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), ) test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_best_fn(policy: Algorithm) -> None: """ Save the state dict of the given algorithm to ``<log_path>/policy.pth``. """ torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: """ Return True once the mean test reward reaches ``args.reward_threshold``. """ return mean_rewards >= args.reward_threshold def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: """ Save the full algorithm state to ``<log_path>/checkpoint.pth`` and return that path; ``epoch``, ``env_step`` and ``gradient_step`` are currently unused. """ # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( algorithm.state_dict(), ckpt_path, ) return ckpt_path if args.resume: # load from existing checkpoint print(f"Loading agent under {log_path}") ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") # train result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn,
test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_ppo_resume(args: argparse.Namespace = get_args()) -> None: """ Re-run test_ppo with ``args.resume`` set, i.e. restoring from an existing checkpoint (if present) and resuming from the log. """ args.resume = True test_ppo(args) def test_ppo_determinism() -> None: """ Check PPO determinism via AlgorithmDeterminismTest (assertions disabled). """ main_fn = lambda args: test_ppo(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_ppo", main_fn, get_args()).run() ================================================ FILE: test/continuous/test_redq.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch import torch.nn as nn from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import REDQ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.redq import REDQPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: """ Parse the REDQ test's hyperparameters; unknown command-line arguments are ignored (parse_known_args). """ parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--ensemble_size", type=int, default=4) parser.add_argument("--subset_size", type=int, default=2) parser.add_argument("--actor_lr", type=float, default=1e-4) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) parser.add_argument("--auto_alpha", action="store_true", default=False) parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--start_timesteps", type=int, default=1000) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1) parser.add_argument("--update_per_step", type=int, default=3) parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--target_mode", type=str, choices=("min", "mean"), default="min") parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: """ Train REDQ on ``args.task`` (asserted to have a Box action space) after collecting ``args.start_timesteps`` warm-up steps with ``random=True``; if ``enable_assertions``, assert that the best test reward reaches ``args.reward_threshold``. ``args`` is mutated in place (shapes, possibly ``reward_threshold`` and ``alpha``). """ env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # you can also use tianshou.env.SubprocVectorEnv # training_envs = gym.make(args.task) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs =
DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) def linear(x: int, y: int) -> nn.Module: """ Layer factory for the critic: an EnsembleLinear mapping x -> y features with ``args.ensemble_size`` members. """ return EnsembleLinear(args.ensemble_size, x, y) net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, linear_layer=linear, ) critic = ContinuousCritic(preprocess_net=net_c, linear_layer=linear, flatten_input=False).to( args.device, ) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim log_alpha = 0.0 alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = REDQPolicy( actor=actor, action_space=env.action_space, ) algorithm: REDQ = REDQ( policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, ensemble_size=args.ensemble_size, subset_size=args.subset_size, tau=args.tau, gamma=args.gamma, alpha=args.alpha, n_step_return_horizon=args.n_step, actor_delay=args.update_per_step, target_mode=args.target_mode, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs) # warm-up: collect args.start_timesteps steps with random=True before training training_collector.reset() training_collector.collect(n_step=args.start_timesteps, random=True) # log log_path = os.path.join(args.logdir, args.task, "redq") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy:
Algorithm) -> None: """ Save the state dict of the given algorithm to ``<log_path>/policy.pth``. """ torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: """ Return True once the mean test reward reaches ``args.reward_threshold``. """ return mean_rewards >= args.reward_threshold # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_redq_determinism() -> None: """ Check REDQ determinism via AlgorithmDeterminismTest (assertions disabled), ignoring ``Params[actor_old]`` trace messages. """ main_fn = lambda args: test_redq(args, enable_assertions=False) ignored_messages = [ "Params[actor_old]", ] # actor_old only present in v1 (due to flawed inheritance) AlgorithmDeterminismTest( "continuous_redq", main_fn, get_args(), ignored_messages=ignored_messages, ).run() ================================================ FILE: test/continuous/test_sac_with_il.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import SAC, OffPolicyImitationLearning from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.imitation_base import ImitationPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ( ContinuousActorDeterministic,
ContinuousActorProbabilistic, ContinuousCritic, ) from tianshou.utils.space_info import SpaceInfo try: import envpool except ImportError: envpool = None def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--actor_lr", type=float, default=1e-3) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--il_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) parser.add_argument("--auto_alpha", type=int, default=1) parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=24000) parser.add_argument("--il_step_per_epoch", type=int, default=500) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--imitation_hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_sac_with_il( args: argparse.Namespace = get_args(), enable_assertions: 
bool = True, skip_il: bool = False, ) -> None: # if you want to use python vector env, please refer to other test scripts # training_envs = env = envpool.make_gymnasium(args.task, num_envs=args.num_training_envs, seed=args.seed) # test_envs = envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed) env = gym.make(args.task) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed + args.num_training_envs) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim log_alpha = 0.0 alpha_optim = 
AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) policy = SACPolicy( actor=actor, action_space=env.action_space, ) algorithm: SAC = SAC( policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, alpha=args.alpha, n_step_return_horizon=args.n_step, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs) # training_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) if skip_il: return # here we define an imitation collector with a trivial policy if args.task.startswith("Pendulum"): args.reward_threshold -= 50 # lower the goal il_net = Net( state_shape=args.state_shape, hidden_sizes=args.imitation_hidden_sizes, ) il_actor = ContinuousActorDeterministic( preprocess_net=il_net, action_shape=args.action_shape, max_action=args.max_action, ).to(args.device) optim = 
AdamOptimizerFactory(lr=args.il_lr) il_policy = ImitationPolicy( actor=il_actor, action_space=env.action_space, action_scaling=True, action_bound_method="clip", ) il_algorithm: OffPolicyImitationLearning = OffPolicyImitationLearning( policy=il_policy, optim=optim, ) il_test_env = gym.make(args.task) il_test_env.reset(seed=args.seed + args.num_training_envs + args.num_test_envs) il_test_collector = Collector[CollectStats]( il_algorithm, # envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed), il_test_env, ) training_collector.reset() result = il_algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=il_test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_sac_determinism() -> None: main_fn = lambda args: test_sac_with_il(args, enable_assertions=False, skip_il=True) AlgorithmDeterminismTest("continuous_sac", main_fn, get_args()).run() ================================================ FILE: test/continuous/test_td3.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import TD3 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.trainer import OffPolicyTrainerParams from 
tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: """ Parse the TD3 test's hyperparameters; unknown command-line arguments are ignored (parse_known_args). """ parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--actor_lr", type=float, default=1e-4) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--exploration_noise", type=float, default=0.1) parser.add_argument("--policy_noise", type=float, default=0.2) parser.add_argument("--noise_clip", type=float, default=0.5) parser.add_argument("--update_actor_freq", type=int, default=2) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=20000) parser.add_argument("--collection_step_num_env_steps", type=int, default=8) parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_training_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: """ Train TD3 (deterministic actor with Gaussian exploration noise and two critics) on ``args.task`` and, if ``enable_assertions``, assert that the best test reward reaches ``args.reward_threshold``. ``args`` is mutated in place (shapes, ``max_action`` and, if unset, ``reward_threshold``). """ env = gym.make(args.task) space_info =
SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # you can also use tianshou.env.SubprocVectorEnv # training_envs = gym.make(args.task) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = ContinuousDeterministicPolicy( actor=actor, action_space=env.action_space, exploration_noise=GaussianNoise(sigma=args.exploration_noise), ) algorithm: TD3 = TD3( policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, n_step_return_horizon=args.n_step, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs) # training_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "td3") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: """ Save the state dict of the given algorithm to ``<log_path>/policy.pth``. """ torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: """ Return True once the mean test reward reaches ``args.reward_threshold``. """ return mean_rewards >= args.reward_threshold # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_td3_determinism() -> None: """ Check TD3 determinism via AlgorithmDeterminismTest (assertions disabled). """ main_fn = lambda args: test_td3(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_td3", main_fn, get_args()).run() ================================================ FILE: test/continuous/test_trpo.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from torch import nn from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import
ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: """ Parse the TRPO test's hyperparameters; unknown command-line arguments are ignored (parse_known_args). """ parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=50000) parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) parser.add_argument( "--update_step_num_repetitions", type=int, default=2 ) # theoretically it should be 1 parser.add_argument("--batch_size", type=int, default=99999) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) # trpo special parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--optim_critic_iters", type=int, default=5) parser.add_argument("--max_kl", type=float,
default=0.005) parser.add_argument("--backtrack_coeff", type=float, default=0.8) parser.add_argument("--max_backtracks", type=int, default=10) return parser.parse_known_args()[0] def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: """ Train TRPO (Tanh-activated actor/critic with orthogonal init and zero biases) on ``args.task`` and, if ``enable_assertions``, assert that the best test reward reaches ``args.reward_threshold``. ``args`` is mutated in place (shapes and, if unset, ``reward_threshold``). """ env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # you can also use tianshou.env.SubprocVectorEnv # training_envs = gym.make(args.task) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( preprocess_net=Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ), ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) optim = AdamOptimizerFactory(lr=args.lr) # replace DiagGaussian with Independent(Normal), which is equivalent # dist_fn receives a (loc, scale) tuple (presumably the actor's output) and unpacks it def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: """ Build a diagonal Gaussian: Normal(loc, scale) with the last dimension reinterpreted as the event dimension. """ loc, scale = loc_scale return Independent(Normal(loc, scale), 1) policy = ProbabilisticActorPolicy(
actor=actor, dist_fn=dist, action_space=env.action_space, ) algorithm: TRPO = TRPO( policy=policy, critic=critic, optim=optim, gamma=args.gamma, return_scaling=args.return_scaling, advantage_normalization=args.advantage_normalization, gae_lambda=args.gae_lambda, optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, backtrack_coeff=args.backtrack_coeff, max_backtracks=args.max_backtracks, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), ) test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "trpo") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: """ Save the state dict of the given algorithm to ``<log_path>/policy.pth``. """ torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: """ Return True once the mean test reward reaches ``args.reward_threshold``. """ return mean_rewards >= args.reward_threshold # train result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_trpo_determinism() -> None: """ Check TRPO determinism via AlgorithmDeterminismTest (assertions disabled). """ main_fn = lambda args: test_trpo(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_trpo", main_fn, get_args()).run() ================================================ FILE: test/determinism_test.py ================================================ from argparse import Namespace from collections.abc import Callable, Sequence from pathlib import Path from typing import Any import pytest import torch from tianshou.utils.determinism import TraceDeterminismTest,
TraceLoggerContext class TorchDeterministicModeContext: def __init__(self, mode: str | int = "default") -> None: self.new_mode = mode self.original_mode: str | int | None = None def __enter__(self) -> None: self.original_mode = torch.get_deterministic_debug_mode() torch.set_deterministic_debug_mode(self.new_mode) def __exit__(self, exc_type, exc_value, traceback): # type: ignore assert self.original_mode is not None torch.set_deterministic_debug_mode(self.original_mode) class AlgorithmDeterminismTest: """ Represents a determinism test for Tianshou's RL algorithms. A test using this class should be added for every algorithm in Tianshou. Then, when making changes to one or more algorithms (e.g. refactoring), run the respective tests on the old branch (creating snapshots) and then on the new branch that contains the changes (comparing with the snapshots). Intended usage is therefore: 1. On the old branch: Set ENABLED=True and FORCE_SNAPSHOT_UPDATE=True and run the tests. 2. On the new branch: Set ENABLED=True and FORCE_SNAPSHOT_UPDATE=False and run the tests. 3. Inspect determinism_tests.log """ ENABLED = False """ whether determinism tests are enabled. """ FORCE_SNAPSHOT_UPDATE = False """ whether to force the update/creation of snapshots for every test. Enable this when running on the "old" branch and you want to prepare the snapshots for a comparison with the "new" branch. """ PASS_IF_CORE_MESSAGES_UNCHANGED = True """ whether to pass the test if only the core messages are unchanged. If this is False, then the full log is required to be equivalent, whereas if it is True, only the core messages need to be equivalent. The core messages test whether the algorithm produces the same network parameters. """ def __init__( self, name: str, main_fn: Callable[[Namespace], Any], args: Namespace, is_offline: bool = False, ignored_messages: Sequence[str] = (), ): """ :param name: the (unique!) 
name of the test :param main_fn: the function to be called for the test :param args: the arguments to be passed to the main function (some of which are overridden for the test) :param is_offline: whether the algorithm being tested is an offline algorithm and therefore does not configure the number of training environments (`num_training_envs`) :param ignored_messages: message fragments to ignore in the trace log (if any) """ self.determinism_test = TraceDeterminismTest( base_path=Path(__file__).parent / "resources" / "determinism", log_filename="determinism_tests.log", core_messages=["Params"], ignored_messages=ignored_messages, ) self.name = name def set(attr: str, value: Any) -> None: old_value = getattr(args, attr) if old_value is None: raise ValueError(f"Attribute '{attr}' is not defined for args: {args}") setattr(args, attr, value) set("epoch", 3) set("epoch_num_steps", 100) set("device", "cpu") if not is_offline: set("num_training_envs", 1) set("num_test_envs", 1) self.args = args self.main_fn = main_fn def run(self, update_snapshot: bool = False) -> None: """ :param update_snapshot: whether to update to snapshot (may be centrally overridden by FORCE_SNAPSHOT_UPDATE) """ if not self.ENABLED: pytest.skip("Algorithm determinism tests are disabled.") if self.FORCE_SNAPSHOT_UPDATE: update_snapshot = True # run the actual process with TraceLoggerContext() as trace: with TorchDeterministicModeContext(): self.main_fn(self.args) log = trace.get_log() self.determinism_test.check( log, self.name, create_reference_result=update_snapshot, pass_if_core_messages_unchanged=self.PASS_IF_CORE_MESSAGES_UNCHANGED, ) ================================================ FILE: test/discrete/__init__.py ================================================ ================================================ FILE: test/discrete/test_a2c_with_il.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from 
gymnasium.spaces import Box
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import A2C, Algorithm, OffPolicyImitationLearning
from tianshou.algorithm.imitation.imitation_base import ImitationPolicy
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.trainer import OffPolicyTrainerParams, OnPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic

# envpool is optional; without it the test falls back to DummyVectorEnv
try:
    import envpool
except ImportError:
    envpool = None


def get_args() -> argparse.Namespace:
    """Build the hyperparameter namespace for the A2C + imitation-learning test.

    Uses ``parse_known_args``, so unrecognized command-line options are ignored rather than
    rejected (presumably so that pytest's own command-line flags are tolerated).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--buffer_size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--il_lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--epoch_num_steps", type=int, default=50000)
    parser.add_argument("--il_step_per_epoch", type=int, default=1000)
    parser.add_argument("--collection_step_num_episodes", type=int, default=16)
    parser.add_argument("--collection_step_num_env_steps", type=int, default=16)
    parser.add_argument("--update_per_step", type=float, default=1 / 16)
    parser.add_argument("--update_step_num_repetitions", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64])
    parser.add_argument("--imitation_hidden_sizes", type=int, nargs="*", default=[128])
    parser.add_argument("--num_training_envs", type=int, default=16)
    parser.add_argument("--num_test_envs", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    # a2c special
    parser.add_argument("--vf_coef", type=float, default=0.5)
    parser.add_argument("--ent_coef", type=float, default=0.0)
    parser.add_argument("--max_grad_norm", type=float, default=None)
    parser.add_argument("--gae_lambda", type=float, default=1.0)
    parser.add_argument("--return_scaling", action="store_true", default=False)
    return parser.parse_known_args()[0]


def test_a2c_with_il(
    args: argparse.Namespace = get_args(),
    enable_assertions: bool = True,
    skip_il: bool = False,
) -> None:
    """Train A2C on ``args.task``, then train an imitation-learning policy from the A2C collector.

    :param args: hyperparameters. NOTE: the default Namespace is created once, at import time,
        and this function writes ``state_shape``, ``action_shape`` and (if unset)
        ``reward_threshold`` into it.
    :param enable_assertions: whether to assert that each training phase reaches
        ``args.reward_threshold``
    :param skip_il: whether to return right after the A2C phase (used by the determinism test)
    """
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if envpool is not None:
        # envpool environments receive the seed at construction time
        training_envs = env = envpool.make(
            args.task,
            env_type="gymnasium",
            num_envs=args.num_training_envs,
            seed=args.seed,
        )
        test_envs = envpool.make(
            args.task,
            env_type="gymnasium",
            num_envs=args.num_test_envs,
            seed=args.seed,
        )
    else:
        env = gym.make(args.task)
        training_envs = DummyVectorEnv(
            [lambda: gym.make(args.task) for _ in range(args.num_training_envs)]
        )
        test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)])
        training_envs.seed(args.seed)
        test_envs.seed(args.seed)
    # fall back to `.n` when the space's shape is empty/falsy (e.g. a Discrete space)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v1": 195}
        # NOTE(review): unlike sibling tests, env.spec is not guarded against None here
        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
    # model
    # actor and critic share the same preprocessing network `net`
    net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes)
    actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device)
    critic = DiscreteCritic(preprocess_net=net).to(args.device)
    optim = AdamOptimizerFactory(lr=args.lr)
    dist = torch.distributions.Categorical
    policy = ProbabilisticActorPolicy(
        actor=actor,
        dist_fn=dist,
        # enabled only when the action space is a Box
        action_scaling=isinstance(env.action_space, Box),
        action_space=env.action_space,
    )
    algorithm: A2C = A2C(
        policy=policy,
        critic=critic,
        optim=optim,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        max_grad_norm=args.max_grad_norm,
        return_scaling=args.return_scaling,
    )
    # collector
    training_collector = Collector[CollectStats](
        algorithm,
        training_envs,
        VectorReplayBuffer(args.buffer_size, len(training_envs)),
    )
    training_collector.reset()
    test_collector = Collector[CollectStats](algorithm, test_envs)
    test_collector.reset()
    # log
    log_path = os.path.join(args.logdir, args.task, "a2c")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: Algorithm) -> None:
        # NOTE(review): this callback is reused by the IL phase below, so the IL algorithm's best
        # state is written to the same "a2c/policy.pth" file — confirm that is intended.
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    # trainer
    result = algorithm.run_training(
        OnPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            update_step_num_repetitions=args.update_step_num_repetitions,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            # this on-policy phase collects by episodes, not by env steps
            collection_step_num_episodes=args.collection_step_num_episodes,
            collection_step_num_env_steps=None,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)

    if skip_il:
        return

    # Imitation-learning phase: a fresh actor is trained off-policy on data gathered by
    # `training_collector`, which was constructed around the A2C `algorithm` above.
    # here we define an imitation collector with a trivial policy
    # if args.task == 'CartPole-v1':
    #     env.spec.reward_threshold = 190  # lower the goal
    net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes)
    actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device)
    optim = AdamOptimizerFactory(lr=args.il_lr)
    il_policy = ImitationPolicy(
        actor=actor,
        action_space=env.action_space,
    )
    il_algorithm: OffPolicyImitationLearning = OffPolicyImitationLearning(
        policy=il_policy,
        optim=optim,
    )
    if envpool is not None:
        il_env = envpool.make(
            args.task,
            env_type="gymnasium",
            num_envs=args.num_test_envs,
            seed=args.seed,
        )
    else:
        il_env = DummyVectorEnv(
            [lambda: gym.make(args.task) for _ in range(args.num_test_envs)],
        )
        il_env.seed(args.seed)
    il_test_collector = Collector[CollectStats](
        il_algorithm,
        il_env,
    )
    training_collector.reset()
    result = il_algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=il_test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_ppo_determinism() -> None:
    """Determinism check for the A2C phase only (IL is skipped); snapshot name "discrete_a2c"."""
    # NOTE(review): despite its name, this checks A2C, not PPO; consider renaming to
    # test_a2c_determinism.
    main_fn = lambda args: test_a2c_with_il(args, enable_assertions=False, skip_il=True)
    AlgorithmDeterminismTest("discrete_a2c", main_fn, get_args()).run()


================================================
FILE: test/discrete/test_bdqn.py
================================================
import argparse

import gymnasium as gym
import numpy as np
import torch

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import BDQN
from tianshou.algorithm.modelfree.bdqn import BDQNPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import ContinuousToDiscrete, DummyVectorEnv
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils.net.common import BranchingNet
from tianshou.utils.torch_utils import policy_within_training_step


def get_args() -> argparse.Namespace:
    """Build the hyperparameter namespace for the BDQN test (unknown CLI options are ignored)."""
    parser = argparse.ArgumentParser()
    # task
    parser.add_argument("--task",
type=str, default="Pendulum-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    # network architecture
    parser.add_argument("--common_hidden_sizes", type=int, nargs="*", default=[64, 64])
    parser.add_argument("--action_hidden_sizes", type=int, nargs="*", default=[64])
    parser.add_argument("--value_hidden_sizes", type=int, nargs="*", default=[64])
    parser.add_argument("--action_per_branch", type=int, default=40)
    # training hyperparameters
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps_test", type=float, default=0.01)
    parser.add_argument("--eps_train", type=float, default=0.76)
    parser.add_argument("--eps_decay", type=float, default=1e-4)
    parser.add_argument("--buffer_size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--target_update_freq", type=int, default=200)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--epoch_num_steps", type=int, default=80000)
    parser.add_argument("--collection_step_num_env_steps", type=int, default=10)
    parser.add_argument("--update_per_step", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--num_training_envs", type=int, default=10)
    parser.add_argument("--num_test_envs", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_known_args()[0]


def test_bdq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
    """Train Branching DQN on a discretized version of ``args.task``.

    :param args: hyperparameters. NOTE: the default Namespace is created once, at import time,
        and this function writes ``state_shape``, ``num_branches`` and (if unset)
        ``reward_threshold`` into it.
    :param enable_assertions: whether to assert that the best test reward reaches
        ``args.reward_threshold``
    """
    env = gym.make(args.task)
    # wrap the env so that its action space is MultiDiscrete (asserted below) with
    # `action_per_branch` choices per branch
    env = ContinuousToDiscrete(env, args.action_per_branch)

    # NOTE(review): for observation spaces other than Box/Discrete, args.state_shape is not set here
    if isinstance(env.observation_space, gym.spaces.Box):
        args.state_shape = env.observation_space.shape
    elif isinstance(env.observation_space, gym.spaces.Discrete):
        args.state_shape = int(env.observation_space.n)
    assert isinstance(env.action_space, gym.spaces.MultiDiscrete)
    # one branch per entry of the MultiDiscrete action space
    args.num_branches = env.action_space.shape[0]

    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )

    print("Observations shape:", args.state_shape)
    print("Num branches:", args.num_branches)
    print("Actions per branch:", args.action_per_branch)

    training_envs = DummyVectorEnv(
        [
            lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
            for _ in range(args.num_training_envs)
        ],
    )
    test_envs = DummyVectorEnv(
        [
            lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
            for _ in range(args.num_test_envs)
        ],
    )

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    training_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = BranchingNet(
        state_shape=args.state_shape,
        num_branches=args.num_branches,
        action_per_branch=args.action_per_branch,
        common_hidden_sizes=args.common_hidden_sizes,
        value_hidden_sizes=args.value_hidden_sizes,
        action_hidden_sizes=args.action_hidden_sizes,
    ).to(args.device)
    optim = AdamOptimizerFactory(lr=args.lr)
    policy = BDQNPolicy(
        model=net,
        action_space=env.action_space,  # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces?
        eps_training=args.eps_train,
        eps_inference=args.eps_test,
    )
    algorithm: BDQN = BDQN(
        policy=policy,
        optim=optim,
        gamma=args.gamma,
        target_update_freq=args.target_update_freq,
    )
    # collector
    training_collector = Collector[CollectStats](
        algorithm,
        training_envs,
        VectorReplayBuffer(args.buffer_size, args.num_training_envs),
        exploration_noise=True,
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False)

    # initial data collection: pre-fill the buffer with batch_size * num_training_envs steps
    with policy_within_training_step(policy):
        training_collector.reset()
        training_collector.collect(n_step=args.batch_size * args.num_training_envs)

    def train_fn(epoch: int, env_step: int) -> None:
        # exp decay: eps_train * (1 - eps_decay) ** env_step, never below eps_test
        eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)
        policy.set_eps_training(eps)

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            update_step_num_gradient_steps_per_sample=args.update_per_step,
            training_fn=train_fn,
            stop_fn=stop_fn,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_bdq_determinism() -> None:
    """Determinism check for BDQN (snapshot name "discrete_bdq")."""
    main_fn = lambda args: test_bdq(args, enable_assertions=False)
    AlgorithmDeterminismTest("discrete_bdq", main_fn, get_args()).run()


================================================
FILE: test/discrete/test_c51.py
================================================
import argparse
import os
import pickle

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import C51
from tianshou.algorithm.algorithm_base import Algorithm
from
tianshou.algorithm.modelfree.c51 import C51Policy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import (
    Collector,
    CollectStats,
    PrioritizedVectorReplayBuffer,
    ReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env import DummyVectorEnv
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.space_info import SpaceInfo
from tianshou.utils.torch_utils import policy_within_training_step


def get_args() -> argparse.Namespace:
    """Build the hyperparameter namespace for the C51 tests (unknown CLI options are ignored)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps_test", type=float, default=0.05)
    parser.add_argument("--eps_train", type=float, default=0.1)
    parser.add_argument("--buffer_size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--num_atoms", type=int, default=51)
    parser.add_argument("--v_min", type=float, default=-10.0)
    parser.add_argument("--v_max", type=float, default=10.0)
    parser.add_argument("--n_step", type=int, default=3)
    parser.add_argument("--target_update_freq", type=int, default=320)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--epoch_num_steps", type=int, default=8000)
    parser.add_argument("--collection_step_num_env_steps", type=int, default=8)
    parser.add_argument("--update_per_step", type=float, default=0.125)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128])
    parser.add_argument("--num_training_envs", type=int, default=8)
    parser.add_argument("--num_test_envs", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument("--prioritized_replay", action="store_true", default=False)
    parser.add_argument("--alpha", type=float, default=0.6)
    parser.add_argument("--beta", type=float, default=0.4)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--save_interval", type=int, default=4)
    return parser.parse_known_args()[0]


def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
    """Train C51 on ``args.task``, optionally with prioritized replay and checkpoint resumption.

    :param args: hyperparameters. NOTE: the default Namespace is created once, at import time,
        and this function writes ``state_shape``, ``action_shape`` and (if unset)
        ``reward_threshold`` into it.
    :param enable_assertions: whether to assert that the best test reward reaches
        ``args.reward_threshold``
    """
    env = gym.make(args.task)
    assert isinstance(env.action_space, gym.spaces.Discrete)
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v1": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )
    training_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.num_training_envs)]
    )
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    training_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        softmax=True,
        num_atoms=args.num_atoms,
    )
    optim = AdamOptimizerFactory(lr=args.lr)
    policy = C51Policy(
        model=net,
        action_space=env.action_space,
        observation_space=env.observation_space,
        num_atoms=args.num_atoms,
        v_min=args.v_min,
        v_max=args.v_max,
        eps_training=args.eps_train,
        eps_inference=args.eps_test,
    )
    algorithm: C51 = C51(
        policy=policy,
        optim=optim,
        gamma=args.gamma,
        n_step_return_horizon=args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    buf: ReplayBuffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(training_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(training_envs))
    # collectors
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buf, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
    # initial data collection: pre-fill the buffer with batch_size * num_training_envs steps
    with policy_within_training_step(policy):
        training_collector.reset()
        training_collector.collect(n_step=args.batch_size * args.num_training_envs)
    # logger
    log_path = os.path.join(args.logdir, args.task, "c51")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer, save_interval=args.save_interval)

    # the parameter name shadows the outer `algorithm`; the trainer passes in the algorithm to save
    def save_best_fn(algorithm: Algorithm) -> None:
        torch.save(algorithm.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    def train_fn(epoch: int, env_step: int) -> None:
        # eps annealing, just a demo: constant eps_train up to 10k env steps, then linear decay
        # reaching 0.1 * eps_train at 50k env steps, constant afterwards
        if env_step <= 10000:
            policy.set_eps_training(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train)
            policy.set_eps_training(eps)
        else:
            policy.set_eps_training(0.1 * args.eps_train)

    def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
        """Save the algorithm state and pickle the training buffer; return the checkpoint path."""
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        # Example: saving by epoch num
        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
        torch.save(
            algorithm.state_dict(),
            ckpt_path,
        )
        buffer_path = os.path.join(log_path, "train_buffer.pkl")
        with open(buffer_path, "wb") as f:
            pickle.dump(training_collector.buffer, f)
        return ckpt_path

    if args.resume:
        # load from existing checkpoint
        print(f"Loading agent under {log_path}")
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=args.device)
            algorithm.load_state_dict(checkpoint)
            print("Successfully restore policy and optim.")
        else:
            print("Fail to restore policy and optim.")
        buffer_path = os.path.join(log_path, "train_buffer.pkl")
        if os.path.exists(buffer_path):
            # pickle can execute arbitrary code: only load buffers written by save_checkpoint_fn
            with open(buffer_path, "rb") as f:
                training_collector.buffer = pickle.load(f)
            print("Successfully restore buffer.")
        else:
            print("Fail to restore buffer.")

    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            update_step_num_gradient_steps_per_sample=args.update_per_step,
            training_fn=train_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            resume_from_log=args.resume,
            save_checkpoint_fn=save_checkpoint_fn,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_c51_resume(args: argparse.Namespace = get_args()) -> None:
    """Run the C51 test with ``resume`` enabled (loads checkpoint/buffer from log dir if present)."""
    args.resume = True
    test_c51(args)


def test_pc51(args: argparse.Namespace = get_args()) -> None:
    """Run the C51 test with prioritized replay (and its own gamma/seed)."""
    args.prioritized_replay = True
    args.gamma = 0.95
    args.seed = 1
    test_c51(args)


def test_c51_determinism() -> None:
    """Determinism check for C51 (snapshot name "discrete_c51")."""
    main_fn = lambda args: test_c51(args, enable_assertions=False)
    AlgorithmDeterminismTest("discrete_c51", main_fn, get_args()).run()


================================================
FILE: test/discrete/test_discrete_sac.py
================================================
import argparse
import os

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import DiscreteSAC
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.discrete_sac import (
    DiscreteSACPolicy,
)
from tianshou.algorithm.modelfree.sac import AutoAlpha
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--actor_lr", type=float, default=1e-4) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.05) parser.add_argument("--auto_alpha", action="store_true", default=False) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_discrete_sac( args: argparse.Namespace = get_args(), 
enable_assertions: bool = True, ) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 170} # lower the goal args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model obs_dim = space_info.observation_info.obs_dim action_dim = space_info.action_info.action_dim net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, softmax_output=False ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) critic1 = DiscreteCritic(preprocess_net=net_c1, last_size=action_dim).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net(state_shape=obs_dim, hidden_sizes=args.hidden_sizes) critic2 = DiscreteCritic(preprocess_net=net_c2, last_size=action_dim).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # better not to use auto alpha in CartPole if args.auto_alpha: target_entropy = 0.98 * np.log(action_dim) log_alpha = 0.0 alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = DiscreteSACPolicy( actor=actor, action_space=env.action_space, ) algorithm = DiscreteSAC( policy=policy, policy_optim=actor_optim, critic=critic1, 
critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, alpha=args.alpha, n_step_return_horizon=args.n_step, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), ) test_collector = Collector[CollectStats](algorithm, test_envs) # training_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "discrete_sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_training=False, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_discrete_sac_determinism() -> None: main_fn = lambda args: test_discrete_sac(args, enable_assertions=False) ignored_messages = [ "Params[actor_old]", # actor_old only present in v1 (due to flawed inheritance) ] AlgorithmDeterminismTest( "discrete_sac", main_fn, get_args(), ignored_messages=ignored_messages ).run() ================================================ FILE: test/discrete/test_dqn.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import DQN from 
tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import (
    Collector,
    CollectStats,
    PrioritizedVectorReplayBuffer,
    ReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env import DummyVectorEnv
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.space_info import SpaceInfo
from tianshou.utils.torch_utils import policy_within_training_step


def get_args() -> argparse.Namespace:
    """Build the hyperparameter namespace for the DQN tests (unknown CLI options are ignored)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps_test", type=float, default=0.05)
    parser.add_argument("--eps_train", type=float, default=0.1)
    parser.add_argument("--buffer_size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--n_step", type=int, default=3)
    parser.add_argument("--target_update_freq", type=int, default=320)
    parser.add_argument("--epoch", type=int, default=20)
    parser.add_argument("--epoch_num_steps", type=int, default=10000)
    parser.add_argument("--collection_step_num_env_steps", type=int, default=10)
    parser.add_argument("--update_per_step", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128])
    parser.add_argument("--num_training_envs", type=int, default=10)
    parser.add_argument("--num_test_envs", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument("--prioritized_replay", action="store_true", default=False)
    parser.add_argument("--alpha", type=float, default=0.6)
    parser.add_argument("--beta", type=float, default=0.4)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_known_args()[0]


def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
    """Train DQN on ``args.task``, optionally with a prioritized replay buffer.

    :param args: hyperparameters. NOTE: the default Namespace is created once, at import time,
        and this function writes ``state_shape``, ``action_shape`` and (if unset)
        ``reward_threshold`` into it.
    :param enable_assertions: whether to assert that the best test reward reaches
        ``args.reward_threshold``
    """
    env = gym.make(args.task)
    assert isinstance(env.action_space, gym.spaces.Discrete)
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v1": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )
    training_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.num_training_envs)]
    )
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    training_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # Q_param = V_param = {"hidden_sizes": [128]}
    # model
    net = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        # dueling=(Q_param, V_param),
    ).to(args.device)
    optim = AdamOptimizerFactory(lr=args.lr)
    policy = DiscreteQLearningPolicy(
        model=net,
        action_space=env.action_space,
        observation_space=env.observation_space,
        eps_training=args.eps_train,
        eps_inference=args.eps_test,
    )
    algorithm: DQN = DQN(
        policy=policy,
        optim=optim,
        gamma=args.gamma,
        n_step_return_horizon=args.n_step,
        target_update_freq=args.target_update_freq,
    )
    # buffer
    buf: ReplayBuffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(training_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(training_envs))
    # collectors
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buf, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
    # initial data collection: pre-fill the buffer with batch_size * num_training_envs steps
    with policy_within_training_step(policy):
        training_collector.reset()
        training_collector.collect(n_step=args.batch_size * args.num_training_envs)
    # logger
    log_path = os.path.join(args.logdir, args.task, "dqn")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: Algorithm) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    def train_fn(epoch: int, env_step: int) -> None:
        # eps annealing, just a demo: constant eps_train up to 10k env steps, then linear decay
        # reaching 0.1 * eps_train at 50k env steps, constant afterwards
        if env_step <= 10000:
            policy.set_eps_training(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train)
            policy.set_eps_training(eps)
        else:
            policy.set_eps_training(0.1 * args.eps_train)

    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            update_step_num_gradient_steps_per_sample=args.update_per_step,
            training_fn=train_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_dqn_determinism() -> None:
    """Determinism check for DQN (snapshot name "discrete_dqn")."""
    main_fn = lambda args: test_dqn(args, enable_assertions=False)
    AlgorithmDeterminismTest("discrete_dqn", main_fn, get_args()).run()


def test_pdqn(args: argparse.Namespace = get_args()) -> None:
    """Run the DQN test with prioritized replay (and its own gamma/seed)."""
    args.prioritized_replay = True
    args.gamma = 0.95
    args.seed = 1
    test_dqn(args)


================================================
FILE: test/discrete/test_drqn.py
================================================
import argparse
import os
import
gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import DQN
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Recurrent
from tianshou.utils.space_info import SpaceInfo
from tianshou.utils.torch_utils import policy_within_training_step


def get_args() -> argparse.Namespace:
    """Build the recurrent-DQN test hyperparameters from the command line.

    ``parse_known_args`` is used, so unrecognised flags (presumably those of the
    test runner) are ignored instead of raising an error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--eps_test", type=float, default=0.05)
    parser.add_argument("--eps_train", type=float, default=0.1)
    parser.add_argument("--buffer_size", type=int, default=20000)
    parser.add_argument("--stack_num", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.95)
    parser.add_argument("--n_step", type=int, default=3)
    parser.add_argument("--target_update_freq", type=int, default=320)
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--epoch_num_steps", type=int, default=20000)
    parser.add_argument("--update_per_step", type=float, default=1 / 16)
    parser.add_argument("--collection_step_num_env_steps", type=int, default=16)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--layer_num", type=int, default=2)
    parser.add_argument("--num_training_envs", type=int, default=16)
    parser.add_argument("--num_test_envs", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_known_args()[0]


def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
    """Train DQN with a ``Recurrent`` Q-network on ``args.task``.

    :param args: hyperparameters from :func:`get_args`. The function writes
        ``state_shape``, ``action_shape`` and (if unset) ``reward_threshold``
        into it.
    :param enable_assertions: if True, assert that the best test reward reached
        ``args.reward_threshold``.
    """
    env = gym.make(args.task)
    assert isinstance(env.action_space, gym.spaces.Discrete)
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    if args.reward_threshold is None:
        # explicit per-task goal first, otherwise fall back to the env spec's value
        default_reward_threshold = {"CartPole-v1": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )
    # training_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    training_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.num_training_envs)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    training_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Recurrent(
        layer_num=args.layer_num,
        state_shape=args.state_shape,
        action_shape=args.action_shape,
    ).to(
        args.device,
    )
    optim = AdamOptimizerFactory(lr=args.lr)
    policy = DiscreteQLearningPolicy(
        model=net,
        action_space=env.action_space,
        eps_training=args.eps_train,
        eps_inference=args.eps_test,
    )
    algorithm: DQN = DQN(
        policy=policy,
        optim=optim,
        gamma=args.gamma,
        n_step_return_horizon=args.n_step,
        target_update_freq=args.target_update_freq,
    )
    # collector
    # the stack_num is for RNN training: sample framestack obs
    # (ignore_obs_next=True: the buffer is told not to store obs_next separately)
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(training_envs),
        stack_num=args.stack_num,
        ignore_obs_next=True,
    )
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buffer, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
    # initial data collection (batch_size * num_training_envs steps)
    with policy_within_training_step(policy):
        training_collector.reset()
        training_collector.collect(n_step=args.batch_size * args.num_training_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, "drqn")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: Algorithm) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            update_step_num_gradient_steps_per_sample=args.update_per_step,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_drqn_determinism() -> None:
    """Run ``test_drqn`` through AlgorithmDeterminismTest with assertions disabled."""
    main_fn = lambda args: test_drqn(args, enable_assertions=False)
    AlgorithmDeterminismTest("discrete_drqn", main_fn, get_args()).run()

================================================
FILE: test/discrete/test_fqf.py
================================================
import argparse
import os

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import FQF
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.fqf import FQFPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory
from tianshou.data import (
    Collector,
    CollectStats,
    PrioritizedVectorReplayBuffer,
    ReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env import DummyVectorEnv
from
tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
from tianshou.utils.space_info import SpaceInfo
from tianshou.utils.torch_utils import policy_within_training_step


def get_args() -> argparse.Namespace:
    """Build the FQF test hyperparameters from the command line.

    ``parse_known_args`` is used, so unrecognised flags (presumably those of the
    test runner) are ignored instead of raising an error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--eps_test", type=float, default=0.05)
    parser.add_argument("--eps_train", type=float, default=0.1)
    parser.add_argument("--buffer_size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--fraction_lr", type=float, default=2.5e-9)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--num_fractions", type=int, default=32)
    parser.add_argument("--num_cosines", type=int, default=64)
    parser.add_argument("--ent_coef", type=float, default=10.0)
    parser.add_argument("--n_step", type=int, default=3)
    parser.add_argument("--target_update_freq", type=int, default=320)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--epoch_num_steps", type=int, default=10000)
    parser.add_argument("--collection_step_num_env_steps", type=int, default=10)
    parser.add_argument("--update_per_step", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64, 64])
    parser.add_argument("--num_training_envs", type=int, default=10)
    parser.add_argument("--num_test_envs", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument("--prioritized_replay", action="store_true", default=False)
    # alpha/beta are only used when --prioritized_replay is set
    parser.add_argument("--alpha", type=float, default=0.6)
    parser.add_argument("--beta", type=float, default=0.4)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_known_args()[0]


def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
    """Train FQF (fully parameterized quantile function) on ``args.task``.

    :param args: hyperparameters from :func:`get_args`. The function writes
        ``state_shape``, ``action_shape`` and (if unset) ``reward_threshold``
        into it.
    :param enable_assertions: if True, assert that the best test reward reached
        ``args.reward_threshold``.
    """
    env = gym.make(args.task)
    space_info = SpaceInfo.from_env(env)
    assert isinstance(env.action_space, gym.spaces.Discrete)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    if args.reward_threshold is None:
        # explicit per-task goal first, otherwise fall back to the env spec's value
        default_reward_threshold = {"CartPole-v1": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )
    training_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.num_training_envs)]
    )
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    training_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    # ``action_shape`` is set to the last hidden size, so this Net presumably acts as
    # a feature extractor whose output feeds the quantile function below
    feature_net = Net(
        state_shape=args.state_shape,
        action_shape=args.hidden_sizes[-1],
        hidden_sizes=args.hidden_sizes[:-1],
        softmax=False,
    )
    net = FullQuantileFunction(
        preprocess_net=feature_net,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        num_cosines=args.num_cosines,
    )
    optim = AdamOptimizerFactory(lr=args.lr)
    # the fraction proposal network has its own (RMSprop) optimizer and learning rate
    fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
    fraction_optim = RMSpropOptimizerFactory(lr=args.fraction_lr)
    policy = FQFPolicy(
        model=net,
        fraction_model=fraction_net,
        action_space=env.action_space,
        eps_training=args.eps_train,
        eps_inference=args.eps_test,
    )
    algorithm: FQF = FQF(
        policy=policy,
        optim=optim,
        fraction_optim=fraction_optim,
        gamma=args.gamma,
        num_fractions=args.num_fractions,
        ent_coef=args.ent_coef,
        n_step_return_horizon=args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer: one sub-buffer per training env, prioritized if requested
    buf: ReplayBuffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(training_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(training_envs))
    # collectors
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buf, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
    # initial data collection (batch_size * num_training_envs steps)
    with policy_within_training_step(policy):
        training_collector.reset()
        training_collector.collect(n_step=args.batch_size * args.num_training_envs)
    # logger
    log_path = os.path.join(args.logdir, args.task, "fqf")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: Algorithm) -> None:
        # the parameter shadows the outer ``policy``; it receives the object to save
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    def train_fn(epoch: int, env_step: int) -> None:
        # eps annealing, just a demo: constant eps_train for the first 10k steps,
        # then a linear decay to 0.1 * eps_train at 50k steps, constant afterwards
        if env_step <= 10000:
            policy.set_eps_training(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train)
            policy.set_eps_training(eps)
        else:
            policy.set_eps_training(0.1 * args.eps_train)

    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            training_fn=train_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            update_step_num_gradient_steps_per_sample=args.update_per_step,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_pfqf(args: argparse.Namespace = get_args()) -> None:
    """Run ``test_fqf`` with a prioritized replay buffer and gamma=0.95."""
    args.prioritized_replay = True
    args.gamma = 0.95
    test_fqf(args)


def test_fqf_determinism() -> None:
    """Run ``test_fqf`` through AlgorithmDeterminismTest with assertions disabled."""
    main_fn = lambda args: test_fqf(args, enable_assertions=False)
    AlgorithmDeterminismTest("discrete_fqf", main_fn, get_args()).run()

================================================
FILE: test/discrete/test_iqn.py
================================================
import argparse
import os

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import IQN
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.iqn import IQNPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import (
    Collector,
    CollectStats,
    PrioritizedVectorReplayBuffer,
    ReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env import DummyVectorEnv
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
from tianshou.utils.space_info import SpaceInfo
from tianshou.utils.torch_utils import policy_within_training_step


def get_args() -> argparse.Namespace:
    """Build the IQN test hyperparameters from the command line.

    ``parse_known_args`` is used, so unrecognised flags (presumably those of the
    test runner) are ignored instead of raising an error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--eps_test", type=float, default=0.05)
    parser.add_argument("--eps_train", type=float, default=0.1)
    parser.add_argument("--buffer_size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--sample_size", type=int, default=32)
    parser.add_argument("--online_sample_size", type=int, default=8)
    parser.add_argument("--target_sample_size", type=int, default=8)
    parser.add_argument("--num_cosines", type=int, default=64)
    parser.add_argument("--n_step", type=int, default=3)
    parser.add_argument("--target_update_freq", type=int,
default=320)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--epoch_num_steps", type=int, default=10000)
    parser.add_argument("--collection_step_num_env_steps", type=int, default=10)
    parser.add_argument("--update_per_step", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64, 64])
    parser.add_argument("--num_training_envs", type=int, default=10)
    parser.add_argument("--num_test_envs", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument("--prioritized_replay", action="store_true", default=False)
    # alpha/beta are only used when --prioritized_replay is set
    parser.add_argument("--alpha", type=float, default=0.6)
    parser.add_argument("--beta", type=float, default=0.4)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_known_args()[0]


def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
    """Train IQN (implicit quantile network) on ``args.task``.

    :param args: hyperparameters from :func:`get_args`. The function writes
        ``state_shape``, ``action_shape`` and (if unset) ``reward_threshold``
        into it.
    :param enable_assertions: if True, assert that the best test reward reached
        ``args.reward_threshold``.
    """
    env = gym.make(args.task)
    space_info = SpaceInfo.from_env(env)
    assert isinstance(env.action_space, gym.spaces.Discrete)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    if args.reward_threshold is None:
        # explicit per-task goal first, otherwise fall back to the env spec's value
        default_reward_threshold = {"CartPole-v1": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )
    training_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.num_training_envs)]
    )
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    training_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    # ``action_shape`` is set to the last hidden size, so this Net presumably acts as
    # a feature extractor whose output feeds the quantile network below
    feature_net = Net(
        state_shape=args.state_shape,
        action_shape=args.hidden_sizes[-1],
        hidden_sizes=args.hidden_sizes[:-1],
        softmax=False,
    )
    net = ImplicitQuantileNetwork(
        preprocess_net=feature_net,
        action_shape=args.action_shape,
        num_cosines=args.num_cosines,
    )
    optim = AdamOptimizerFactory(lr=args.lr)
    policy = IQNPolicy(
        model=net,
        action_space=env.action_space,
        sample_size=args.sample_size,
        online_sample_size=args.online_sample_size,
        target_sample_size=args.target_sample_size,
        eps_training=args.eps_train,
        eps_inference=args.eps_test,
    )
    algorithm: IQN = IQN(
        policy=policy,
        optim=optim,
        gamma=args.gamma,
        n_step_return_horizon=args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer: one sub-buffer per training env, prioritized if requested
    buf: ReplayBuffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(training_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(training_envs))
    # collectors
    training_collector = Collector[CollectStats](
        algorithm, training_envs, buf, exploration_noise=True
    )
    test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
    # initial data collection (batch_size * num_training_envs steps)
    with policy_within_training_step(policy):
        training_collector.reset()
        training_collector.collect(n_step=args.batch_size * args.num_training_envs)
    # logger
    log_path = os.path.join(args.logdir, args.task, "iqn")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: Algorithm) -> None:
        # the parameter shadows the outer ``policy``; it receives the object to save
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    def train_fn(epoch: int, env_step: int) -> None:
        # eps annealing, just a demo: constant eps_train for the first 10k steps,
        # then a linear decay to 0.1 * eps_train at 50k steps, constant afterwards
        if env_step <= 10000:
            policy.set_eps_training(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train)
            policy.set_eps_training(eps)
        else:
            policy.set_eps_training(0.1 * args.eps_train)

    # train
    result = algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            training_fn=train_fn,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            update_step_num_gradient_steps_per_sample=args.update_per_step,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_piqn(args: argparse.Namespace = get_args()) -> None:
    """Run ``test_iqn`` with a prioritized replay buffer and gamma=0.95."""
    args.prioritized_replay = True
    args.gamma = 0.95
    test_iqn(args)


def test_iqn_determinism() -> None:
    """Run ``test_iqn`` through AlgorithmDeterminismTest with assertions disabled."""
    main_fn = lambda args: test_iqn(args, enable_assertions=False)
    AlgorithmDeterminismTest("discrete_iqn", main_fn, get_args()).run()

================================================
FILE: test/discrete/test_ppo_discrete.py
================================================
import argparse
import os

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import PPO
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.trainer import OnPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import (
    ActionReprNet,
    ActionReprNetDataParallelWrapper,
    ActorCritic,
    DataParallelNet,
    Net,
)
from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic
from tianshou.utils.space_info import SpaceInfo


def get_args() -> argparse.Namespace:
    """Build the discrete-PPO test hyperparameters from the command line.

    ``parse_known_args`` is used, so unrecognised flags (presumably those of the
    test runner) are ignored instead of raising an error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--buffer_size",
type=int, default=20000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--epoch_num_steps", type=int, default=50000)
    parser.add_argument("--collection_step_num_env_steps", type=int, default=2000)
    parser.add_argument("--update_step_num_repetitions", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64])
    parser.add_argument("--num_training_envs", type=int, default=20)
    parser.add_argument("--num_test_envs", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    # ppo special
    # (return_scaling / advantage_normalization / recompute_adv / value_clip are
    # int flags with 0 meaning off, passed straight to PPO below)
    parser.add_argument("--vf_coef", type=float, default=0.5)
    parser.add_argument("--ent_coef", type=float, default=0.0)
    parser.add_argument("--eps_clip", type=float, default=0.2)
    parser.add_argument("--max_grad_norm", type=float, default=0.5)
    parser.add_argument("--gae_lambda", type=float, default=0.95)
    parser.add_argument("--return_scaling", type=int, default=0)
    parser.add_argument("--advantage_normalization", type=int, default=0)
    parser.add_argument("--recompute_adv", type=int, default=0)
    parser.add_argument("--dual_clip", type=float, default=None)
    parser.add_argument("--value_clip", type=int, default=0)
    return parser.parse_known_args()[0]


def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
    """Train PPO with a discrete actor and a critic on ``args.task``.

    :param args: hyperparameters from :func:`get_args`. The function writes
        ``state_shape``, ``action_shape`` and (if unset) ``reward_threshold``
        into it.
    :param enable_assertions: if True, assert that the best test reward reached
        ``args.reward_threshold``.
    """
    env = gym.make(args.task)
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    if args.reward_threshold is None:
        # explicit per-task goal first, otherwise fall back to the env spec's value
        default_reward_threshold = {"CartPole-v1": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )
    # training_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    training_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.num_training_envs)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    training_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model: actor and critic share the same preprocessing ``net`` instance
    net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes)
    critic: DiscreteCritic | DataParallelNet
    actor: ActionReprNet
    if torch.cuda.is_available():
        # with CUDA, wrap actor and critic in the data-parallel wrappers
        actor = ActionReprNetDataParallelWrapper(
            DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device)
        )
        critic = DataParallelNet(DiscreteCritic(preprocess_net=net).to(args.device))
    else:
        actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device)
        critic = DiscreteCritic(preprocess_net=net).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    # orthogonal initialization of every Linear layer's weight, zero bias
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = AdamOptimizerFactory(lr=args.lr)
    dist = torch.distributions.Categorical
    policy = DiscreteActorPolicy(
        actor=actor,
        dist_fn=dist,
        action_space=env.action_space,
        deterministic_eval=True,
    )
    algorithm: PPO = PPO(
        policy=policy,
        critic=critic,
        optim=optim,
        gamma=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        gae_lambda=args.gae_lambda,
        return_scaling=args.return_scaling,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        advantage_normalization=args.advantage_normalization,
        recompute_advantage=args.recompute_adv,
    )
    # collector
    training_collector = Collector[CollectStats](
        algorithm,
        training_envs,
        VectorReplayBuffer(args.buffer_size, len(training_envs)),
    )
    test_collector = Collector[CollectStats](algorithm, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, "ppo")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: Algorithm) -> None:
        # the parameter shadows the outer ``policy``; it receives the object to save
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    # trainer
    result = algorithm.run_training(
        OnPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            update_step_num_repetitions=args.update_step_num_repetitions,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_ppo_determinism() -> None:
    """Run ``test_ppo`` through AlgorithmDeterminismTest with assertions disabled."""
    main_fn = lambda args: test_ppo(args, enable_assertions=False)
    AlgorithmDeterminismTest("discrete_ppo", main_fn, get_args()).run()

================================================
FILE: test/discrete/test_qrdqn.py
================================================
import argparse
import os

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from tianshou.algorithm import QRDQN
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import (
    Collector,
    CollectStats,
    PrioritizedVectorReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env import DummyVectorEnv
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.space_info import SpaceInfo
from tianshou.utils.torch_utils import policy_within_training_step


def get_args() ->
argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps_test", type=float, default=0.05) parser.add_argument("--eps_train", type=float, default=0.1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--num_quantiles", type=int, default=200) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = 
space_info.action_info.action_shape if args.task == "CartPole-v1" and env.spec: env.spec.reward_threshold = 190 # lower the goal if args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # training_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax=False, num_atoms=args.num_quantiles, ) optim = AdamOptimizerFactory(lr=args.lr) policy = QRDQNPolicy( model=net, action_space=env.action_space, observation_space=env.observation_space, eps_training=args.eps_train, eps_inference=args.eps_test, ) algorithm: QRDQN = QRDQN( policy=policy, optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, buffer_num=len(training_envs), alpha=args.alpha, beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(training_envs)) # collectors training_collector = Collector[CollectStats]( algorithm, training_envs, buf, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # initial data collection with policy_within_training_step(policy): training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) 
# logger log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(algo: Algorithm) -> None: torch.save(algo.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) policy.set_eps_training(eps) else: policy.set_eps_training(0.1 * args.eps_train) # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, training_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_pqrdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_qrdqn(args) def test_qrdqn_determinism() -> None: main_fn = lambda args: test_qrdqn(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_qrdqn", main_fn, get_args()).run() ================================================ FILE: test/discrete/test_rainbow.py ================================================ import argparse import os import pickle import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import RainbowDQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import 
C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps_test", type=float, default=0.05) parser.add_argument("--eps_train", type=float, default=0.1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--num_atoms", type=int, default=51) parser.add_argument("--v_min", type=float, default=-10.0) parser.add_argument("--v_max", type=float, default=10.0) parser.add_argument("--noisy_std", type=float, default=0.1) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=8000) parser.add_argument("--collection_step_num_env_steps", type=int, default=8) parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_training_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, 
default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument("--beta_final", type=float, default=1.0) parser.add_argument("--resume", action="store_true") parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] def test_rainbow(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # training_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) def noisy_linear(x: int, y: int) -> NoisyLinear: return NoisyLinear(x, y, args.noisy_std) # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax=True, num_atoms=args.num_atoms, dueling_param=({"linear_layer": noisy_linear}, {"linear_layer": noisy_linear}), ) optim = AdamOptimizerFactory(lr=args.lr) policy = C51Policy( model=net, action_space=env.action_space, num_atoms=args.num_atoms, 
v_min=args.v_min, v_max=args.v_max, eps_training=args.eps_train, eps_inference=args.eps_test, ) algorithm: RainbowDQN = RainbowDQN( policy=policy, optim=optim, gamma=args.gamma, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, buffer_num=len(training_envs), alpha=args.alpha, beta=args.beta, weight_norm=True, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(training_envs)) # collectors training_collector = Collector[CollectStats]( algorithm, training_envs, buf, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # initial data collection with policy_within_training_step(policy): training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # logger log_path = os.path.join(args.logdir, args.task, "rainbow") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold def train_fn(epoch: int, env_step: int) -> None: # eps annealing, just a demo if env_step <= 10000: policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) policy.set_eps_training(eps) else: policy.set_eps_training(0.1 * args.eps_train) # beta annealing, just a demo if args.prioritized_replay: if env_step <= 10000: beta = args.beta elif env_step <= 50000: beta = args.beta - (env_step - 10000) / 40000 * (args.beta - args.beta_final) else: beta = args.beta_final buf.set_beta(beta) def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: 
https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( algorithm.state_dict(), ckpt_path, ) buffer_path = os.path.join(log_path, "train_buffer.pkl") with open(buffer_path, "wb") as f: pickle.dump(training_collector.buffer, f) return ckpt_path if args.resume: # load from existing checkpoint print(f"Loading agent under {log_path}") ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") buffer_path = os.path.join(log_path, "train_buffer.pkl") if os.path.exists(buffer_path): with open(buffer_path, "rb") as f: training_collector.buffer = pickle.load(f) print("Successfully restore buffer.") else: print("Fail to restore buffer.") # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, training_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, test_in_training=True, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_rainbow(args) def test_prainbow(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 args.seed = 1 test_rainbow(args) def test_rainbow_determinism() -> None: main_fn = lambda args: test_rainbow(args, 
enable_assertions=False) AlgorithmDeterminismTest("discrete_rainbow", main_fn, get_args()).run() ================================================ FILE: test/discrete/test_reinforce.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import Reinforce from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=40000) parser.add_argument("--collection_step_num_episodes", type=int, default=8) parser.add_argument("--update_step_num_repetitions", type=int, default=2) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") 
parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_reinforce(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax=True, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) dist_fn = torch.distributions.Categorical policy = ProbabilisticActorPolicy( actor=net, dist_fn=dist_fn, action_space=env.action_space, action_scaling=isinstance(env.action_space, Box), ) algorithm: Reinforce = Reinforce( policy=policy, optim=optim, gamma=args.gamma, return_standardization=args.return_scaling, ) for m in net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), ) test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "pg") writer = 
SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(algorithm: Algorithm) -> None: torch.save(algorithm.policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # train training_config = OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) result = algorithm.run_training(training_config) if enable_assertions: assert stop_fn(result.best_reward) def test_reinforce_determinism() -> None: main_fn = lambda args: test_reinforce(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_reinforce", main_fn, get_args()).run() ================================================ FILE: test/highlevel/__init__.py ================================================ ================================================ FILE: test/highlevel/env_factory.py ================================================ from tianshou.highlevel.env import ( EnvFactoryRegistered, VectorEnvType, ) class DiscreteTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: super().__init__( task="CartPole-v1", venv_type=VectorEnvType.DUMMY, ) class ContinuousTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: super().__init__( task="Pendulum-v1", venv_type=VectorEnvType.DUMMY, ) ================================================ FILE: test/highlevel/test_experiment_builder.py ================================================ import pytest from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory from tianshou.highlevel.config import ( 
OffPolicyTrainingConfig, OnPolicyTrainingConfig, ) from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DDPGExperimentBuilder, DiscreteSACExperimentBuilder, DQNExperimentBuilder, ExperimentBuilder, ExperimentConfig, IQNExperimentBuilder, OffPolicyExperimentBuilder, OnPolicyExperimentBuilder, PPOExperimentBuilder, REDQExperimentBuilder, ReinforceExperimentBuilder, SACExperimentBuilder, TD3ExperimentBuilder, TRPOExperimentBuilder, ) def create_training_config( builder_cls: type[ExperimentBuilder], num_epochs: int = 1, epoch_num_steps: int = 100, num_training_envs: int = 2, num_test_envs: int = 2, ) -> OffPolicyTrainingConfig | OnPolicyTrainingConfig: if issubclass(builder_cls, OffPolicyExperimentBuilder): return OffPolicyTrainingConfig( max_epochs=num_epochs, epoch_num_steps=epoch_num_steps, num_training_envs=num_training_envs, num_test_envs=num_test_envs, ) elif issubclass(builder_cls, OnPolicyExperimentBuilder): return OnPolicyTrainingConfig( max_epochs=num_epochs, epoch_num_steps=epoch_num_steps, num_training_envs=num_training_envs, num_test_envs=num_test_envs, ) else: raise ValueError @pytest.mark.parametrize( "builder_cls", [ PPOExperimentBuilder, A2CExperimentBuilder, SACExperimentBuilder, DDPGExperimentBuilder, TD3ExperimentBuilder, # NPGExperimentBuilder, # TODO test fails non-deterministically REDQExperimentBuilder, TRPOExperimentBuilder, ReinforceExperimentBuilder, ], ) def test_experiment_builder_continuous_default_params( builder_cls: type[ExperimentBuilder], ) -> None: env_factory = ContinuousTestEnvFactory() training_config = create_training_config( builder_cls, num_epochs=1, epoch_num_steps=100, num_training_envs=2, num_test_envs=2, ) experiment_config = ExperimentConfig(persistence_enabled=False) builder = builder_cls( experiment_config=experiment_config, env_factory=env_factory, training_config=training_config, ) experiment = builder.build() experiment.run(run_name="test") print(experiment) @pytest.mark.parametrize( "builder_cls", [ 
ReinforceExperimentBuilder, PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder, DiscreteSACExperimentBuilder, IQNExperimentBuilder, ], ) def test_experiment_builder_discrete_default_params( builder_cls: type[ExperimentBuilder], ) -> None: env_factory = DiscreteTestEnvFactory() training_config = create_training_config( builder_cls, num_epochs=1, epoch_num_steps=100, num_training_envs=2, num_test_envs=2, ) builder = builder_cls( experiment_config=ExperimentConfig(persistence_enabled=False), env_factory=env_factory, training_config=training_config, ) experiment = builder.build() experiment.run(run_name="test") print(experiment) ================================================ FILE: test/modelbased/__init__.py ================================================ ================================================ FILE: test/modelbased/test_dqn_icm.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import DQN, Algorithm, ICMOffPolicyWrapper from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.discrete import IntrinsicCuriosityModule from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps_test", type=float, default=0.05) parser.add_argument("--eps_train", 
type=float, default=0.1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=20) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument( "--lr_scale", type=float, default=1.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( "--reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( "--forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", ) return parser.parse_known_args()[0] def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: 
default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, # dueling=(Q_param, V_param), ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, eps_inference=args.eps_test, ) algorithm: DQN = DQN( policy=policy, optim=optim, gamma=args.gamma, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) # ICM wrapper feature_dim = args.hidden_sizes[-1] obs_dim = space_info.observation_info.obs_dim feature_net = MLP( input_dim=obs_dim, output_dim=feature_dim, hidden_sizes=args.hidden_sizes[:-1], ) action_dim = space_info.action_info.action_dim icm_net = IntrinsicCuriosityModule( feature_net=feature_net, feature_dim=feature_dim, action_dim=action_dim, hidden_sizes=args.hidden_sizes[-1:], ).to(args.device) icm_optim = AdamOptimizerFactory(lr=args.lr) icm_algorithm = ICMOffPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=args.lr_scale, reward_scale=args.reward_scale, forward_loss_weight=args.forward_loss_weight, ) # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, buffer_num=len(training_envs), alpha=args.alpha, beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(training_envs)) # collector training_collector = 
Collector[CollectStats]( icm_algorithm, training_envs, buf, exploration_noise=True ) test_collector = Collector[CollectStats](icm_algorithm, test_envs, exploration_noise=True) training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # log log_path = str(os.path.join(args.logdir, args.task, "dqn_icm")) writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) policy.set_eps_training(eps) else: policy.set_eps_training(0.1 * args.eps_train) # train result = icm_algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, training_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) assert stop_fn(result.best_reward) ================================================ FILE: test/modelbased/test_ppo_icm.py ================================================ import argparse import os import gymnasium as gym import numpy as np import torch from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.modelfree.reinforce import 
ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net from tianshou.utils.net.discrete import ( DiscreteActor, DiscreteCritic, IntrinsicCuriosityModule, ) from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=50000) parser.add_argument("--collection_step_num_env_steps", type=int, default=2000) parser.add_argument("--update_step_num_repetitions", type=int, default=10) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=20) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special parser.add_argument("--vf_coef", type=float, default=0.5) parser.add_argument("--ent_coef", type=float, default=0.0) parser.add_argument("--eps_clip", type=float, default=0.2) parser.add_argument("--max_grad_norm", type=float, default=0.5) parser.add_argument("--gae_lambda", type=float, default=0.95) 
parser.add_argument("--return_scaling", type=int, default=0) parser.add_argument("--advantage_normalization", type=int, default=0) parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--dual_clip", type=float, default=None) parser.add_argument("--value_clip", type=int, default=0) parser.add_argument( "--lr_scale", type=float, default=1.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( "--reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( "--forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", ) return parser.parse_known_args()[0] def test_ppo(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) critic = DiscreteCritic(preprocess_net=net).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) # base algorithm: PPO optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical policy 
= ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), action_space=env.action_space, deterministic_eval=True, ) algorithm: PPO = PPO( policy=policy, critic=critic, optim=optim, gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, gae_lambda=args.gae_lambda, return_scaling=args.return_scaling, dual_clip=args.dual_clip, value_clip=args.value_clip, advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) # ICM wrapper feature_dim = args.hidden_sizes[-1] feature_net = MLP( input_dim=space_info.observation_info.obs_dim, output_dim=feature_dim, hidden_sizes=args.hidden_sizes[:-1], ) action_dim = space_info.action_info.action_dim icm_net = IntrinsicCuriosityModule( feature_net=feature_net, feature_dim=feature_dim, action_dim=action_dim, hidden_sizes=args.hidden_sizes[-1:], ).to(args.device) icm_optim = AdamOptimizerFactory(lr=args.lr) icm_algorithm = ICMOnPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=args.lr_scale, reward_scale=args.reward_scale, forward_loss_weight=args.forward_loss_weight, ) # collector training_collector = Collector[CollectStats]( icm_algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), ) test_collector = Collector[CollectStats](icm_algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo_icm") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(alg: Algorithm) -> None: torch.save(alg.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # train result = icm_algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, 
update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, test_in_training=True, ) ) assert stop_fn(result.best_reward) ================================================ FILE: test/modelbased/test_psrl.py ================================================ import argparse import os import numpy as np import pytest import torch from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import PSRL from tianshou.algorithm.modelbased.psrl import PSRLPolicy from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger try: import envpool except ImportError: envpool = None def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="NChain-v0") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=1000) parser.add_argument("--collection_step_num_episodes", type=int, default=1) parser.add_argument("--num_training_envs", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--rew_mean_prior", type=float, default=0.0) parser.add_argument("--rew_std_prior", type=float, default=1.0) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--eps", type=float, default=0.01) parser.add_argument("--add_done_loop", action="store_true", default=False) parser.add_argument( 
"--logger", type=str, default="none", # TODO: Change to "wandb" once wandb supports Gym >=0.26.0 choices=["wandb", "tensorboard", "none"], ) return parser.parse_known_args()[0] @pytest.mark.skipif( envpool is None, reason="EnvPool is not installed. If on linux, please install it (e.g. as poetry extra)", ) def test_psrl(args: argparse.Namespace = get_args()) -> None: training_envs = env = envpool.make_gymnasium( args.task, num_envs=args.num_training_envs, seed=args.seed ) test_envs = envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed) if args.reward_threshold is None: default_reward_threshold = {"NChain-v0": 3400} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) print("reward threshold:", args.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # seed np.random.seed(args.seed) torch.manual_seed(args.seed) # model n_action = args.action_shape n_state = args.state_shape trans_count_prior = np.ones((n_state, n_action, n_state)) rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) policy = PSRLPolicy( trans_count_prior=trans_count_prior, rew_mean_prior=rew_mean_prior, rew_std_prior=rew_std_prior, action_space=env.action_space, discount_factor=args.gamma, epsilon=args.eps, ) algorithm: PSRL = PSRL( policy=policy, add_done_loop=args.add_done_loop, ) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) training_collector.reset() test_collector = Collector[CollectStats](algorithm, test_envs) test_collector.reset() # Logger log_path = os.path.join(args.logdir, args.task, "psrl") writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger: WandbLogger | TensorboardLogger | LazyLogger if args.logger 
== "wandb": logger = WandbLogger(save_interval=1, project="psrl", name="wandb_test", config=args) logger.load(writer) elif args.logger == "tensorboard": logger = TensorboardLogger(writer) else: logger = LazyLogger() def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold training_collector.collect(n_step=args.buffer_size, random=True) # train result = algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=1, test_step_num_episodes=args.num_test_envs, batch_size=0, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, logger=logger, test_in_training=False, ) ) assert result.best_reward >= args.reward_threshold ================================================ FILE: test/offline/__init__.py ================================================ ================================================ FILE: test/offline/gather_cartpole_data.py ================================================ import argparse import os import pickle import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import QRDQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo def expert_file_name() -> str: return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl") def get_args() -> argparse.Namespace: parser = 
argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps_test", type=float, default=0.05) parser.add_argument("--eps_train", type=float, default=0.1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--num_quantiles", type=int, default=200) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument("--save_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: args = get_args() env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = 
space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 190} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # training_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax=False, num_atoms=args.num_quantiles, ) optim = AdamOptimizerFactory(lr=args.lr) policy = QRDQNPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, eps_inference=args.eps_test, ) algorithm: QRDQN = QRDQN( policy=policy, optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # buffer buf: VectorReplayBuffer | PrioritizedVectorReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, buffer_num=len(training_envs), alpha=args.alpha, beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(training_envs)) # collector training_collector = Collector[CollectStats]( algorithm, training_envs, buf, exploration_noise=True ) training_collector.reset() test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) test_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) logger = 
TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) policy.set_eps_training(eps) else: policy.set_eps_training(0.1 * args.eps_train) # train result = algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, training_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_training=True, ) ) assert stop_fn(result.best_reward) # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) policy.set_eps_inference(0.2) collector = Collector[CollectStats](algorithm, test_envs, buf, exploration_noise=True) collector.reset() collector_stats = collector.collect(n_step=args.buffer_size) if args.save_buffer_name.endswith(".hdf5"): buf.save_hdf5(args.save_buffer_name) else: with open(args.save_buffer_name, "wb") as f: pickle.dump(buf, f) print(collector_stats) return buf ================================================ FILE: test/offline/gather_pendulum_data.py ================================================ import argparse import os import pickle import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import SAC from tianshou.algorithm.algorithm_base import Algorithm from 
tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def expert_file_name() -> str: return os.path.join(os.path.dirname(__file__), "expert_SAC_Pendulum-v1.pkl") def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--actor_lr", type=float, default=1e-3) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=7) parser.add_argument("--epoch_num_steps", type=int, default=8000) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--gamma", default=0.99) parser.add_argument("--tau", default=0.005) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", 
default=False, action="store_true", help="watch the play of pre-trained policy only", ) # sac: parser.add_argument("--alpha", type=float, default=0.2) parser.add_argument("--auto_alpha", type=int, default=1) parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--save_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] def gather_data() -> VectorReplayBuffer: """Return expert buffer data.""" args = get_args() env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # you can also use tianshou.env.SubprocVectorEnv # training_envs = gym.make(args.task) training_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_training_envs)] ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic = ContinuousCritic(preprocess_net=net_c).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim log_alpha = 0.0 alpha_optim = 
AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = SACPolicy( actor=actor, action_space=env.action_space, ) algorithm: SAC[SACTrainingStats] = SAC( policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, alpha=args.alpha, n_step_return_horizon=args.n_step, ) # collector buffer = VectorReplayBuffer(args.buffer_size, len(training_envs)) training_collector = Collector[CollectStats]( algorithm, training_envs, buffer, exploration_noise=True ) test_collector = Collector[CollectStats](algorithm, test_envs) # training_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, save_best_fn=save_best_fn, stop_fn=stop_fn, logger=logger, test_in_training=True, ) ) training_collector.reset() collector_stats = training_collector.collect(n_step=args.buffer_size) print(collector_stats) if args.save_buffer_name.endswith(".hdf5"): buffer.save_hdf5(args.save_buffer_name) else: with open(args.save_buffer_name, "wb") as f: pickle.dump(buffer, f) return buffer ================================================ FILE: test/offline/test_bcq.py ================================================ import argparse import datetime import os import pickle import gymnasium as 
gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data from tianshou.algorithm import BCQ, Algorithm from tianshou.algorithm.imitation.bcq import BCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64]) parser.add_argument("--actor_lr", type=float, default=1e-3) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument("--vae_hidden_sizes", type=int, nargs="*", default=[32, 32]) # default to 2 * action_dim parser.add_argument("--latent_dim", type=int, default=None) parser.add_argument("--gamma", default=0.99) parser.add_argument("--tau", default=0.005) # Weighting for Clipped Double Q-learning in BCQ parser.add_argument("--lmbda", default=0.75) # Max perturbation hyper-parameter for BCQ parser.add_argument("--phi", default=0.05) 
parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument("--show_progress", action="store_true") return parser.parse_known_args()[0] def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) else: buffer = gather_data() env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action args.state_dim = space_info.observation_info.obs_dim args.action_dim = space_info.action_info.action_dim if args.reward_threshold is None: # too low? 
default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # perturbation network net_a = MLP( input_dim=args.state_dim + args.action_dim, output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, ) actor_perturbation = Perturbation( preprocess_net=net_a, max_action=args.max_action, phi=args.phi ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic = ContinuousCritic(preprocess_net=net_c).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae # output_dim = 0, so the last Module in the encoder is ReLU vae_encoder = MLP( input_dim=args.state_dim + args.action_dim, hidden_sizes=args.vae_hidden_sizes, ) if not args.latent_dim: args.latent_dim = args.action_dim * 2 vae_decoder = MLP( input_dim=args.state_dim + args.latent_dim, output_dim=args.action_dim, hidden_sizes=args.vae_hidden_sizes, ) vae = VAE( encoder=vae_encoder, decoder=vae_decoder, hidden_dim=args.vae_hidden_sizes[-1], latent_dim=args.latent_dim, max_action=args.max_action, ).to(args.device) vae_optim = AdamOptimizerFactory() policy = BCQPolicy( actor_perturbation=actor_perturbation, critic=critic, vae=vae, action_space=env.action_space, ) algorithm = BCQ( policy=policy, actor_perturbation_optim=actor_optim, critic_optim=critic_optim, vae_optim=vae_optim, gamma=args.gamma, tau=args.tau, lmbda=args.lmbda, ).to(args.device) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # 
collector # buffer has been gathered # training_collector = Collector[CollectStats](policy, training_envs, buffer, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs) # logger t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_bcq" log_path = os.path.join(args.logdir, args.task, "bcq", log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold def watch() -> None: algorithm.load_state_dict( torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) # train result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, logger=logger, show_progress=args.show_progress, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_bcq_determinism() -> None: main_fn = lambda args: test_bcq(args, enable_assertions=False) AlgorithmDeterminismTest("offline_bcq", main_fn, get_args(), is_offline=True).run() ================================================ FILE: test/offline/test_cql.py ================================================ import argparse import datetime import os import pickle import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data from tianshou.algorithm import CQL, Algorithm from 
tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--actor_lr", type=float, default=1e-3) parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--alpha", type=float, default=0.2) parser.add_argument("--auto_alpha", default=True, action="store_true") parser.add_argument("--alpha_lr", type=float, default=1e-3) parser.add_argument("--cql_alpha_lr", type=float, default=1e-3) parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--cql_weight", type=float, default=1.0) parser.add_argument("--with_lagrange", type=bool, default=True) parser.add_argument("--lagrange_threshold", type=float, default=10.0) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--eval_freq", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) 
parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) else: buffer = gather_data() env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.min_action = space_info.action_info.min_action args.max_action = space_info.action_info.max_action args.state_dim = space_info.observation_info.obs_dim args.action_dim = space_info.action_info.action_dim if args.reward_threshold is None: # too low? 
default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model # actor network net_a = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic = ContinuousCritic(preprocess_net=net_c).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = float(-np.prod(args.action_shape)) log_alpha = 0.0 alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) policy = SACPolicy( actor=actor, # CQL seems to perform better without action scaling # TODO: investigate why action_scaling=False, action_space=env.action_space, ) algorithm = CQL( policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, cql_alpha_lr=args.cql_alpha_lr, cql_weight=args.cql_weight, tau=args.tau, gamma=args.gamma, alpha=args.alpha, temperature=args.temperature, with_lagrange=args.with_lagrange, lagrange_threshold=args.lagrange_threshold, min_action=args.min_action, max_action=args.max_action, ) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector # buffer has been gathered # training_collector = 
Collector[CollectStats](policy, training_envs, buffer, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_cql" log_path = os.path.join(args.logdir, args.task, "cql", log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, logger=logger, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_cql_determinism() -> None: main_fn = lambda args: test_cql(args, enable_assertions=False) AlgorithmDeterminismTest("offline_cql", main_fn, get_args(), is_offline=True).run() ================================================ FILE: test/offline/test_discrete_bcq.py ================================================ import argparse import os import pickle import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import expert_file_name, gather_data from tianshou.algorithm import Algorithm, DiscreteBCQ from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from 
tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import DiscreteActor from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps_test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--unlikely_action_threshold", type=float, default=0.6) parser.add_argument("--imitation_logits_penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=2000) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume", action="store_true") parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] def test_discrete_bcq( args: argparse.Namespace = get_args(), enable_assertions: bool = True, ) -> None: # envs env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if 
args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 185} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) policy_net = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ).to(args.device) imitation_net = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, action_space=env.action_space, unlikely_action_threshold=args.unlikely_action_threshold, eps_inference=args.eps_test, ) algorithm: DiscreteBCQ = DiscreteBCQ( policy=policy, optim=optim, gamma=args.gamma, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, imitation_logits_penalty=args.imitation_logits_penalty, ) # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) else: buffer = gather_data() # collector test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # logger log_path = os.path.join(args.logdir, args.task, "discrete_bcq") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 
args.reward_threshold def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( algorithm.state_dict(), ckpt_path, ) return ckpt_path if args.resume: # load from existing checkpoint print(f"Loading agent under {log_path}") ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") # train result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: test_discrete_bcq() args.resume = True test_discrete_bcq(args) def test_discrete_bcq_determinism() -> None: main_fn = lambda args: test_discrete_bcq(args, enable_assertions=False) AlgorithmDeterminismTest("offline_discrete_bcq", main_fn, get_args(), is_offline=True).run() ================================================ FILE: test/offline/test_discrete_cql.py ================================================ import argparse import os import pickle import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import expert_file_name, gather_data from tianshou.algorithm import Algorithm, 
DiscreteCQL from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps_test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--num_quantiles", type=int, default=200) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--min_q_weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=1000) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64]) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_discrete_cql( args: argparse.Namespace = get_args(), enable_assertions: bool = True, ) -> None: # envs env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = 
SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 170} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax=False, num_atoms=args.num_quantiles, ) optim = AdamOptimizerFactory(lr=args.lr) policy = QRDQNPolicy( model=net, action_space=env.action_space, ) algorithm: DiscreteCQL = DiscreteCQL( policy=policy, optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, min_q_weight=args.min_q_weight, ).to(args.device) # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) else: buffer = gather_data() # collector test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_cql") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # train result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, 
test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_discrete_cql_determinism() -> None: main_fn = lambda args: test_discrete_cql(args, enable_assertions=False) AlgorithmDeterminismTest("offline_discrete_cql", main_fn, get_args(), is_offline=True).run() ================================================ FILE: test/offline/test_discrete_crr.py ================================================ import argparse import os import pickle import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import expert_file_name, gather_data from tianshou.algorithm import Algorithm, DiscreteCRR from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=7e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, 
default=1000) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser.parse_known_args()[0] def test_discrete_crr( args: argparse.Namespace = get_args(), enable_assertions: bool = True, ) -> None: # envs env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v1": 180} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model and algorithm net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax_output=False, ) action_dim = space_info.action_info.action_dim critic = DiscreteCritic( preprocess_net=net, hidden_sizes=args.hidden_sizes, last_size=action_dim, ) optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteActorPolicy( actor=actor, action_space=env.action_space, ) algorithm: DiscreteCRR = DiscreteCRR( policy=policy, critic=critic, optim=optim, gamma=args.gamma, target_update_freq=args.target_update_freq, ).to(args.device) # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if 
os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) else: buffer = gather_data() # collector test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_crr") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # train result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_discrete_crr_determinism() -> None: main_fn = lambda args: test_discrete_crr(args, enable_assertions=False) AlgorithmDeterminismTest("offline_discrete_crr", main_fn, get_args(), is_offline=True).run() ================================================ FILE: test/offline/test_gail.py ================================================ import argparse import os import pickle import gymnasium as gym import numpy as np import torch from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data from tianshou.algorithm import GAIL, Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, 
CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--disc_lr", type=float, default=5e-4) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=150000) parser.add_argument("--collection_step_num_episodes", type=int, default=16) parser.add_argument("--update_step_num_repetitions", type=int, default=2) parser.add_argument("--disc_update_num", type=int, default=2) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special parser.add_argument("--vf_coef", type=float, default=0.25) parser.add_argument("--ent_coef", type=float, default=0.0) parser.add_argument("--eps_clip", type=float, default=0.2) parser.add_argument("--max_grad_norm", type=float, default=0.5) parser.add_argument("--gae_lambda", type=float, default=0.95) 
# NOTE(review): this chunk opens inside test_gail.py's get_args(); the lines
# below are its remaining PPO/GAIL options and its return statement.
    parser.add_argument("--return_scaling", type=int, default=1)
    parser.add_argument("--dual_clip", type=float, default=None)
    parser.add_argument("--value_clip", type=int, default=1)
    parser.add_argument("--advantage_normalization", type=int, default=1)
    parser.add_argument("--recompute_adv", type=int, default=0)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--save_interval", type=int, default=4)
    parser.add_argument("--load_buffer_name", type=str, default=expert_file_name())
    # parse_known_args ignores unrecognized command-line arguments
    return parser.parse_known_args()[0]


def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
    """Train GAIL (PPO-style actor-critic plus a discriminator) against an expert buffer.

    The expert buffer is read from ``args.load_buffer_name`` (HDF5 when the
    name ends in ``.hdf5``, otherwise a pickle), or built with
    ``gather_data()`` when that file does not exist.

    :param args: hyperparameters from :func:`get_args`. Mutated in place:
        ``state_shape``, ``action_shape``, ``max_action`` and, if unset,
        ``reward_threshold``.
    :param enable_assertions: if True, assert that the best test reward
        reached ``args.reward_threshold``.
    """
    # expert data: load from disk if present, otherwise collect it now
    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(args.load_buffer_name, "rb") as f:
                buffer = pickle.load(f)
    else:
        buffer = gather_data()
    env = gym.make(args.task)
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100}
        # fall back to the gym spec's threshold for tasks not listed above
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    args.max_action = space_info.action_info.max_action
    training_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.num_training_envs)]
    )
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    training_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    # NOTE(review): orthogonal_ below draws from the torch RNG seeded above, so
    # the order of construction/initialization affects results (relevant to
    # test_gail_determinism); keep it as is.
    net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes)
    actor = ContinuousActorProbabilistic(
        preprocess_net=net,
        action_shape=args.action_shape,
        max_action=args.max_action,
    ).to(
        args.device,
    )
    critic = ContinuousCritic(
        preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes),
    ).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    # orthogonal initialization
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = AdamOptimizerFactory(lr=args.lr)
    # discriminator: a critic over concatenated (state, action) inputs
    disc_net = ContinuousCritic(
        preprocess_net=Net(
            state_shape=args.state_shape,
            action_shape=args.action_shape,
            hidden_sizes=args.hidden_sizes,
            activation=torch.nn.Tanh,
            concat=True,
        ),
    ).to(args.device)
    for m in disc_net.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    disc_optim = AdamOptimizerFactory(lr=args.disc_lr)

    # replace DiagGaussian with Independent(Normal) which is equivalent
    # pass *logits to be consistent with policy.forward
    def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
        # Independent(..., 1) treats the last dimension as one event, so
        # log_prob sums over action dimensions (a diagonal Gaussian).
        loc, scale = loc_scale
        return Independent(Normal(loc, scale), 1)

    policy = ProbabilisticActorPolicy(
        actor=actor,
        dist_fn=dist,
        action_space=env.action_space,
    )
    algorithm: GAIL = GAIL(
        policy=policy,
        critic=critic,
        optim=optim,
        expert_buffer=buffer,
        disc_net=disc_net,
        disc_optim=disc_optim,
        disc_update_num=args.disc_update_num,
        gamma=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        return_scaling=args.return_scaling,
        advantage_normalization=args.advantage_normalization,
        recompute_advantage=args.recompute_adv,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        gae_lambda=args.gae_lambda,
    )
    # collector
    training_collector = Collector[CollectStats](
        algorithm,
        training_envs,
        VectorReplayBuffer(args.buffer_size, len(training_envs)),
    )
    test_collector = Collector[CollectStats](algorithm, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, "gail")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer, save_interval=args.save_interval)

    def save_best_fn(policy: Algorithm) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        # Example: saving by epoch num
        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
        torch.save(
            algorithm.state_dict(),
            ckpt_path,
        )
        return ckpt_path

    if args.resume:
        # load from existing checkpoint
        print(f"Loading agent under {log_path}")
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=args.device)
            algorithm.load_state_dict(checkpoint)
            print("Successfully restore policy and optim.")
        else:
            print("Fail to restore policy and optim.")

    # trainer
    result = algorithm.run_training(
        OnPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            update_step_num_repetitions=args.update_step_num_repetitions,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            collection_step_num_episodes=args.collection_step_num_episodes,
            collection_step_num_env_steps=None,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            logger=logger,
            resume_from_log=args.resume,
            save_checkpoint_fn=save_checkpoint_fn,
            test_in_training=True,
        )
    )
    if enable_assertions:
        assert stop_fn(result.best_reward)


def test_gail_determinism() -> None:
    """Run the GAIL test through AlgorithmDeterminismTest (reward assertion off)."""
    main_fn = lambda args: test_gail(args, enable_assertions=False)
    AlgorithmDeterminismTest("offline_gail", main_fn, get_args(), is_offline=True).run()

================================================
FILE: test/offline/test_td3_bc.py
================================================
import argparse
import datetime
import os
import pickle

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from test.determinism_test import AlgorithmDeterminismTest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
from tianshou.algorithm import TD3BC
from tianshou.algorithm.algorithm_base import Algorithm
from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.trainer import OfflineTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic
from tianshou.utils.space_info import SpaceInfo


def get_args() -> argparse.Namespace:
    """Parse the hyperparameters of the offline TD3+BC test (Pendulum by default)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Pendulum-v1")
    parser.add_argument("--reward_threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64])
    parser.add_argument("--actor_lr", type=float, default=1e-3)
    parser.add_argument("--critic_lr", type=float, default=1e-3)
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--epoch_num_steps", type=int, default=500)
    parser.add_argument("--n_step", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--alpha", type=float, default=2.5)
    parser.add_argument("--exploration_noise", type=float, default=0.1)
    parser.add_argument("--policy_noise", type=float, default=0.2)
    parser.add_argument("--noise_clip", type=float, default=0.5)
    parser.add_argument("--update_actor_freq", type=int, default=2)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--eval_freq", type=int, default=1)
parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: with open(args.load_buffer_name, "rb") as f: buffer = pickle.load(f) else: buffer = gather_data() env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action if args.reward_threshold is None: # too low? 
default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, ) args.state_dim = space_info.action_info.action_dim args.action_dim = space_info.observation_info.obs_dim test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # actor network net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, ) actor = ContinuousActorDeterministic( preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic networks net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # policy and algorithm policy = ContinuousDeterministicPolicy( actor=actor, action_space=env.action_space, exploration_noise=GaussianNoise(sigma=args.exploration_noise), ) algorithm: TD3BC = TD3BC( policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, alpha=args.alpha, n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector # buffer has been gathered # 
training_collector = Collector[CollectStats](policy, training_envs, buffer, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs) # logger t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_td3_bc" log_path = os.path.join(args.logdir, args.task, "td3_bc", log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # train result = algorithm.run_training( OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, logger=logger, ) ) if enable_assertions: assert stop_fn(result.best_reward) def test_td3_bc_determinism() -> None: main_fn = lambda args: test_td3_bc(args, enable_assertions=False) AlgorithmDeterminismTest("offline_td3_bc", main_fn, get_args(), is_offline=True).run() ================================================ FILE: test/pettingzoo/pistonball.py ================================================ import argparse import os import warnings import gymnasium as gym import numpy as np import torch from pettingzoo.butterfly import pistonball_v6 from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import DQN, Algorithm, MultiAgentOffPolicyAlgorithm from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv 
from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps_test", type=float, default=0.05) parser.add_argument("--eps_train", type=float, default=0.1) parser.add_argument("--buffer_size", type=int, default=2000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument( "--gamma", type=float, default=0.9, help="a smaller gamma favors earlier win", ) parser.add_argument( "--n_pistons", type=int, default=3, help="Number of pistons(agents) in the env", ) parser.add_argument("--n_step", type=int, default=100) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=3) parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=100) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( "--watch", default=False, action="store_true", help="no training, watch the play of pre-trained models", ) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser def get_args() -> argparse.Namespace: parser = get_parser() return parser.parse_known_args()[0] def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: return PettingZooEnv(pistonball_v6.env(continuous=False, n_pistons=args.n_pistons)) def get_agents( args: 
argparse.Namespace = get_args(), agents: list[OffPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[Algorithm, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] if isinstance(env.observation_space, gym.spaces.Dict) else env.observation_space ) args.state_shape = observation_space.shape or int(observation_space.n) args.action_shape = env.action_space.shape or int(env.action_space.n) if agents is not None: algorithms = agents else: algorithms = [] optims = [] for _ in range(args.n_pistons): # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, eps_inference=args.eps_test, ) agent: DQN = DQN( policy=policy, optim=optim, gamma=args.gamma, n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) algorithms.append(agent) optims.append(optim) ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=algorithms, env=env) return ma_algorithm, optims, env.agents def train_agent( args: argparse.Namespace = get_args(), agents: list[OffPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[InfoStats, Algorithm]: training_envs = DummyVectorEnv([get_env for _ in range(args.num_training_envs)]) test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) marl_algorithm, optim, agents = get_agents(args, agents=agents, optims=optims) # collector training_collector = Collector[CollectStats]( marl_algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](marl_algorithm, 
test_envs, exploration_noise=True) training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: pass def stop_fn(mean_rewards: float) -> bool: return False def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # trainer result = marl_algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, update_step_num_gradient_steps_per_sample=args.update_per_step, logger=logger, test_in_training=False, multi_agent_return_reduction=reward_metric, ) ) return result, marl_algorithm def watch(args: argparse.Namespace = get_args(), policy: Algorithm | None = None) -> None: env = DummyVectorEnv([get_env]) if not policy: warnings.warn( "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) result.pprint_asdict() ================================================ FILE: test/pettingzoo/pistonball_continuous.py ================================================ import argparse import os import warnings from typing import Any import gymnasium as gym import numpy as np import torch from pettingzoo.butterfly import pistonball_v6 from torch import nn from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import PPO, Algorithm from 
tianshou.algorithm.algorithm_base import OnPolicyAlgorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.multiagent.marl import MultiAgentOnPolicyAlgorithm from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ModuleWithVectorOutput from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic class DQNet(ModuleWithVectorOutput): """Reference: Human-level control through deep reinforcement learning.""" def __init__( self, c: int, h: int, w: int, device: str | int | torch.device = "cpu", ) -> None: net = nn.Sequential( nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True), nn.Flatten(), ) with torch.no_grad(): output_dim = np.prod(net(torch.zeros(1, c, h, w)).shape[1:]) super().__init__(int(output_dim)) self.device = device self.c = c self.h = h self.w = w self.net = net def forward( self, x: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, Any]: r"""Mapping: x -> Q(x, \*).""" if info is None: info = {} x = torch.as_tensor(x, device=self.device, dtype=torch.float32) return self.net(x.reshape(-1, self.c, self.w, self.h)), state def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps_test", type=float, default=0.05) parser.add_argument("--eps_train", type=float, default=0.1) parser.add_argument("--buffer_size", type=int, default=2000) 
parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument( "--gamma", type=float, default=0.9, help="a smaller gamma favors earlier win", ) parser.add_argument( "--n_pistons", type=int, default=3, help="Number of pistons(agents) in the env", ) parser.add_argument("--n_step", type=int, default=100) parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--collection_step_num_episodes", type=int, default=16) parser.add_argument("--update_step_num_repetitions", type=int, default=2) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument( "--watch", default=False, action="store_true", help="no training, watch the play of pre-trained models", ) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special parser.add_argument("--vf_coef", type=float, default=0.25) parser.add_argument("--ent_coef", type=float, default=0.0) parser.add_argument("--eps_clip", type=float, default=0.2) parser.add_argument("--max_grad_norm", type=float, default=0.5) parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--dual_clip", type=float, default=None) parser.add_argument("--value_clip", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--resume", 
action="store_true") parser.add_argument("--save_interval", type=int, default=4) parser.add_argument("--render", type=float, default=0.0) return parser def get_args() -> argparse.Namespace: parser = get_parser() return parser.parse_known_args()[0] def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: return PettingZooEnv(pistonball_v6.env(continuous=True, n_pistons=args.n_pistons)) def get_agents( args: argparse.Namespace = get_args(), agents: list[OnPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[Algorithm, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] if isinstance(env.observation_space, gym.spaces.Dict) else env.observation_space ) args.state_shape = observation_space.shape or observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] if agents is not None: algorithms = agents else: algorithms = [] optims = [] for _ in range(args.n_pistons): # model net = DQNet( observation_space.shape[2], observation_space.shape[1], observation_space.shape[0], device=args.device, ).to(args.device) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, ).to(args.device) net2 = DQNet( observation_space.shape[2], observation_space.shape[1], observation_space.shape[0], device=args.device, ).to(args.device) critic = ContinuousCritic(preprocess_net=net2).to(args.device) for m in set(actor.modules()).union(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) optim = AdamOptimizerFactory(lr=args.lr) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, action_scaling=True, 
action_bound_method="clip", ) algorithm: PPO = PPO( policy=policy, critic=critic, optim=optim, gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, return_scaling=args.return_scaling, advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, # dual_clip=args.dual_clip, # dual clip cause monotonically increasing log_std :) value_clip=args.value_clip, gae_lambda=args.gae_lambda, ) algorithms.append(algorithm) optims.append(optim) ma_algorithm = MultiAgentOnPolicyAlgorithm( algorithms=algorithms, env=env, ) return ma_algorithm, optims, env.agents def train_agent( args: argparse.Namespace = get_args(), agents: list[OnPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[InfoStats, Algorithm]: training_envs = DummyVectorEnv([get_env for _ in range(args.num_training_envs)]) test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) marl_algorithm, optim, agents = get_agents(args, agents=agents, optims=optims) # collector training_collector = Collector[CollectStats]( marl_algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=False, # True ) test_collector = Collector[CollectStats](marl_algorithm, test_envs) # training_collector.collect(n_step=args.batch_size * args.num_training_envs, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) def save_best_fn(policy: Algorithm) -> None: pass def stop_fn(mean_rewards: float) -> bool: return False def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # train result = marl_algorithm.run_training( OnPolicyTrainerParams( training_collector=training_collector, 
test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, test_in_training=True, ) ) return result, marl_algorithm def watch(args: argparse.Namespace = get_args(), policy: Algorithm | None = None) -> None: env = DummyVectorEnv([get_env]) if not policy: warnings.warn( "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) collector = Collector[CollectStats](policy, env) collector_result = collector.collect(n_episode=1, render=args.render) collector_result.pprint_asdict() ================================================ FILE: test/pettingzoo/test_pistonball.py ================================================ import argparse import pytest from pistonball import get_args, train_agent, watch @pytest.mark.skip(reason="Performance bound was never tested, no point in running this for now") def test_piston_ball(args: argparse.Namespace = get_args()) -> None: if args.watch: watch(args) return train_agent(args) # assert result.best_reward >= args.win_rate ================================================ FILE: test/pettingzoo/test_pistonball_continuous.py ================================================ import argparse import pytest from pistonball_continuous import get_args, train_agent, watch @pytest.mark.skip(reason="runtime too long and unstable result") def test_piston_ball_continuous(args: argparse.Namespace = get_args()) -> None: if args.watch: watch(args) return result, agent = train_agent(args) # assert result.best_reward >= 30.0 ================================================ FILE: test/pettingzoo/test_tic_tac_toe.py 
================================================ import argparse from tic_tac_toe import get_args, train_agent, watch def test_tic_tac_toe(args: argparse.Namespace = get_args()) -> None: if args.watch: watch(args) return result, agent = train_agent(args) assert result.best_reward >= args.win_rate ================================================ FILE: test/pettingzoo/tic_tac_toe.py ================================================ import argparse import os from copy import deepcopy from functools import partial import gymnasium import numpy as np import torch from pettingzoo.classic import tictactoe_v3 from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import ( DQN, Algorithm, MARLRandomDiscreteMaskedOffPolicyAlgorithm, MultiAgentOffPolicyAlgorithm, ) from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net def get_env(render_mode: str | None = None) -> PettingZooEnv: return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode)) def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps_test", type=float, default=0.05) parser.add_argument("--eps_train", type=float, default=0.1) parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument( "--gamma", type=float, default=0.9, help="a smaller gamma favors earlier win", ) parser.add_argument("--n_step", type=int, default=3) 
parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=50) parser.add_argument("--epoch_num_steps", type=int, default=1000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_training_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.1) parser.add_argument( "--win_rate", type=float, default=0.6, help="the expected winning rate: Optimal policy can get 0.7", ) parser.add_argument( "--watch", default=False, action="store_true", help="no training, watch the play of pre-trained models", ) parser.add_argument( "--agent_id", type=int, default=2, help="the learned agent plays as the agent_id-th player. 
Choices are 1 and 2.", ) parser.add_argument( "--resume_path", type=str, default="", help="the path of agent pth file for resuming from a pre-trained agent", ) parser.add_argument( "--opponent_path", type=str, default="", help="the path of opponent agent pth file for resuming from a pre-trained agent", ) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) return parser def get_args() -> argparse.Namespace: parser = get_parser() return parser.parse_known_args()[0] def get_agents( args: argparse.Namespace = get_args(), agent_learn: OffPolicyAlgorithm | None = None, agent_opponent: OffPolicyAlgorithm | None = None, optim: OptimizerFactory | None = None, ) -> tuple[MultiAgentOffPolicyAlgorithm, torch.optim.Optimizer | None, list]: env = get_env() observation_space = ( env.observation_space.spaces["observation"] if isinstance(env.observation_space, gymnasium.spaces.Dict) else env.observation_space ) args.state_shape = observation_space.shape or int(observation_space.n) args.action_shape = env.action_space.shape or int(env.action_space.n) if agent_learn is None: # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ).to(args.device) if optim is None: optim = AdamOptimizerFactory(lr=args.lr) algorithm = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, eps_inference=args.eps_test, ) agent_learn = DQN( policy=algorithm, optim=optim, n_step_return_horizon=args.n_step, gamma=args.gamma, target_update_freq=args.target_update_freq, ) if args.resume_path: agent_learn.load_state_dict(torch.load(args.resume_path)) if agent_opponent is None: if args.opponent_path: agent_opponent = deepcopy(agent_learn) agent_opponent.load_state_dict(torch.load(args.opponent_path)) else: agent_opponent = MARLRandomDiscreteMaskedOffPolicyAlgorithm( action_space=env.action_space ) if args.agent_id == 1: agents = [agent_learn, agent_opponent] else: 
agents = [agent_opponent, agent_learn] ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) return ma_algorithm, optim, env.agents def train_agent( args: argparse.Namespace = get_args(), agent_learn: OffPolicyAlgorithm | None = None, agent_opponent: OffPolicyAlgorithm | None = None, optim: OptimizerFactory | None = None, ) -> tuple[InfoStats, OffPolicyAlgorithm]: training_envs = DummyVectorEnv([get_env for _ in range(args.num_training_envs)]) test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) training_envs.seed(args.seed) test_envs.seed(args.seed) marl_algorithm, optim, agents = get_agents( args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim, ) # collector training_collector = Collector[CollectStats]( marl_algorithm, training_envs, VectorReplayBuffer(args.buffer_size, len(training_envs)), exploration_noise=True, ) test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True) training_collector.reset() training_collector.collect(n_step=args.batch_size * args.num_training_envs) # log log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn") writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) player_agent_id = agents[args.agent_id - 1] def save_best_fn(policy: Algorithm) -> None: if hasattr(args, "model_save_path"): model_save_path = args.model_save_path else: model_save_path = os.path.join(args.logdir, "tic_tac_toe", "dqn", "policy.pth") torch.save(policy.get_algorithm(player_agent_id).state_dict(), model_save_path) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.win_rate def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, args.agent_id - 1] # trainer result = marl_algorithm.run_training( OffPolicyTrainerParams( training_collector=training_collector, test_collector=test_collector, max_epochs=args.epoch, 
epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, update_step_num_gradient_steps_per_sample=args.update_per_step, logger=logger, test_in_training=False, multi_agent_return_reduction=reward_metric, ) ) return result, marl_algorithm.get_algorithm(player_agent_id) def watch( args: argparse.Namespace = get_args(), agent_learn: OffPolicyAlgorithm | None = None, agent_opponent: OffPolicyAlgorithm | None = None, ) -> None: env = DummyVectorEnv([partial(get_env, render_mode="human")]) policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render, reset_before_collect=True) result.pprint_asdict() ================================================ FILE: tianshou/__init__.py ================================================ # isort: skip_file # NOTE: Import order is important to avoid circular import errors! 
from tianshou import data, env, exploration, algorithm, trainer, utils __version__ = "2.0.0" def _register_log_config_callback() -> None: from sensai.util import logging def configure() -> None: logging.getLogger("numba").setLevel(logging.INFO) logging.set_configure_callback(configure) _register_log_config_callback() __all__ = [ "algorithm", "data", "env", "exploration", "trainer", "utils", ] ================================================ FILE: tianshou/algorithm/__init__.py ================================================ """Algorithm package.""" # isort:skip_file from tianshou.algorithm.algorithm_base import Algorithm, TrainingStats from tianshou.algorithm.modelfree.reinforce import Reinforce from tianshou.algorithm.modelfree.dqn import DQN from tianshou.algorithm.modelfree.ddpg import DDPG from tianshou.algorithm.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm from tianshou.algorithm.modelfree.bdqn import BDQN from tianshou.algorithm.modelfree.c51 import C51 from tianshou.algorithm.modelfree.rainbow import RainbowDQN from tianshou.algorithm.modelfree.qrdqn import QRDQN from tianshou.algorithm.modelfree.iqn import IQN from tianshou.algorithm.modelfree.fqf import FQF from tianshou.algorithm.modelfree.a2c import A2C from tianshou.algorithm.modelfree.npg import NPG from tianshou.algorithm.modelfree.ppo import PPO from tianshou.algorithm.modelfree.trpo import TRPO from tianshou.algorithm.modelfree.td3 import TD3 from tianshou.algorithm.modelfree.sac import SAC from tianshou.algorithm.modelfree.redq import REDQ from tianshou.algorithm.modelfree.discrete_sac import DiscreteSAC from tianshou.algorithm.imitation.imitation_base import OffPolicyImitationLearning from tianshou.algorithm.imitation.bcq import BCQ from tianshou.algorithm.imitation.cql import CQL from tianshou.algorithm.imitation.td3_bc import TD3BC from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQ from tianshou.algorithm.imitation.discrete_cql import DiscreteCQL from 
tianshou.algorithm.imitation.discrete_crr import DiscreteCRR from tianshou.algorithm.imitation.gail import GAIL from tianshou.algorithm.modelbased.psrl import PSRL from tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.multiagent.marl import MultiAgentOffPolicyAlgorithm ================================================ FILE: tianshou/algorithm/algorithm_base.py ================================================ import logging import time from abc import ABC, abstractmethod from collections.abc import Callable, Mapping from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np import torch from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete from numba import njit from numpy.typing import ArrayLike from overrides import override from sensai.util.hash import pickle_hash from sensai.util.helper import mark_used from torch import nn from torch.nn.modules.module import ( _IncompatibleKeys, # we have to do this since we override load_state_dict ) from torch.optim.lr_scheduler import LRScheduler from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as from tianshou.data.batch import Batch, BatchProtocol, TArr from tianshou.data.buffer.buffer_base import TBuffer from tianshou.data.types import ( ActBatchProtocol, ActStateBatchProtocol, BatchWithReturnsProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.utils.determinism import TraceLogger from tianshou.utils.lagged_network import ( EvalModeModuleWrapper, LaggedNetworkCollection, ) from tianshou.utils.net.common import RandomActor from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode if TYPE_CHECKING: from tianshou.data.stats import 
InfoStats from tianshou.trainer import ( OfflineTrainer, OfflineTrainerParams, OffPolicyTrainer, OffPolicyTrainerParams, OnPolicyTrainer, OnPolicyTrainerParams, Trainer, TrainerParams, ) mark_used(TrainerParams) logger = logging.getLogger(__name__) TArrOrActBatch = TypeVar("TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") @dataclass(kw_only=True) class TrainingStats(DataclassPPrintMixin): _non_loss_fields = ("train_time", "smoothed_loss") train_time: float = 0.0 """The time for learning models.""" # TODO: modified in the trainer but not used anywhere else. Should be refactored. smoothed_loss: dict = field(default_factory=dict) """The smoothed loss statistics of the policy learn step.""" # Mainly so that we can override this in the TrainingStatsWrapper def _get_self_dict(self) -> dict[str, Any]: return self.__dict__ def get_loss_stats_dict(self) -> dict[str, float]: """Return loss statistics as a dict for logging. Returns a dict with all fields except train_time and smoothed_loss. Moreover, fields with value None excluded, and instances of SequenceSummaryStats are replaced by their mean. 
""" result = {} for k, v in self._get_self_dict().items(): if k.startswith("_"): logger.debug(f"Skipping {k=} as it starts with an underscore.") continue if k in self._non_loss_fields or v is None: continue if isinstance(v, SequenceSummaryStats): result[k] = v.mean else: result[k] = v return result class TrainingStatsWrapper(TrainingStats): _setattr_frozen = False _training_stats_public_fields = TrainingStats.__dataclass_fields__.keys() def __init__(self, wrapped_stats: TrainingStats) -> None: """In this particular case, super().__init__() should be called LAST in the subclass init.""" self._wrapped_stats = wrapped_stats # HACK: special sauce for the existing attributes of the base TrainingStats class # for some reason, delattr doesn't work here, so we need to delegate their handling # to the wrapped stats object by always keeping the value there and in self in sync # see also __setattr__ for k in self._training_stats_public_fields: super().__setattr__(k, getattr(self._wrapped_stats, k)) self._setattr_frozen = True @override def _get_self_dict(self) -> dict[str, Any]: return {**self._wrapped_stats._get_self_dict(), **self.__dict__} @property def wrapped_stats(self) -> TrainingStats: return self._wrapped_stats def __getattr__(self, name: str) -> Any: return getattr(self._wrapped_stats, name) def __setattr__(self, name: str, value: Any) -> None: """Setattr logic for wrapper of a dataclass with default values. 1. If name exists directly in self, set it there. 2. If it exists in self._wrapped_stats, set it there instead. 3. Special case: if name is in the base TrainingStats class, keep it in sync between self and the _wrapped_stats. 4. If name doesn't exist in either and attribute setting is frozen, raise an AttributeError. 
""" # HACK: special sauce for the existing attributes of the base TrainingStats class, see init # Need to keep them in sync with the wrapped stats object if name in self._training_stats_public_fields: setattr(self._wrapped_stats, name, value) super().__setattr__(name, value) return if not self._setattr_frozen: super().__setattr__(name, value) return if not hasattr(self, name): raise AttributeError( f"Setting new attributes on StatsWrappers outside of init is not allowed. " f"Tried to set {name=}, {value=} on {self.__class__.__name__}. \n" f"NOTE: you may get this error if you call super().__init__() in your subclass init too early! " f"The call to super().__init__() should be the last call in your subclass init.", ) if hasattr(self._wrapped_stats, name): setattr(self._wrapped_stats, name, value) else: super().__setattr__(name, value) class Policy(nn.Module, ABC): """Represents a policy, which provides the fundamental mapping from observations to actions.""" def __init__( self, action_space: gym.Space, observation_space: gym.Space | None = None, action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", ): """ :param action_space: the environment's action_space. :param observation_space: the environment's observation space. :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. 
Should be disabled if the actor model already produces outputs in the correct range. :param action_bound_method: the method used for bounding actions in continuous action spaces to the range [-1, 1] before scaling them to the environment's action space (provided that `action_scaling` is enabled). This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None for discrete spaces. When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly constrains outputs to [-1, 1] while preserving gradients. The choice of bounding method affects both training dynamics and exploration behavior. Clipping provides hard boundaries but may create plateau regions in the gradient landscape, while tanh provides smoother transitions but can compress sensitivity near the boundaries. Should be set to None if the actor model inherently produces bounded outputs. Typically used together with `action_scaling=True`. """ allowed_action_bound_methods = ("clip", "tanh") if ( action_bound_method is not None and action_bound_method not in allowed_action_bound_methods ): raise ValueError( f"Got invalid {action_bound_method=}. 
" f"Valid values are: {allowed_action_bound_methods}.", ) if action_scaling and not isinstance(action_space, Box): raise ValueError( f"action_scaling can only be True when action_space is Box but got: {action_space}", ) super().__init__() self.observation_space = observation_space self.action_space = action_space if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary): action_type = "discrete" elif isinstance(action_space, Box): action_type = "continuous" else: raise ValueError(f"Unsupported action space: {action_space}.") self._action_type = cast(Literal["discrete", "continuous"], action_type) self.agent_id = 0 self.action_scaling = action_scaling self.action_bound_method = action_bound_method self.is_within_training_step = False """ flag indicating whether we are currently within a training step, which encompasses data collection for training (in online RL algorithms) and the policy update (gradient steps). It can be used, for example, to control whether a flag controlling deterministic evaluation should indeed be applied, because within a training step, we typically always want to apply stochastic evaluation (even if such a flag is enabled), as well as stochastic action computation for q-targets (e.g. in SAC based algorithms). This flag should normally remain False and should be set to True only by the algorithm which performs training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer, the user should ensure that this flag is set correctly. 
""" self._compile() @property def action_type(self) -> Literal["discrete", "continuous"]: return self._action_type @staticmethod def _action_to_numpy(act: TArr) -> np.ndarray: act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch if not isinstance(act, np.ndarray): raise ValueError( f"act should have been be a numpy.ndarray, but got {type(act)}.", ) return act def map_action( self, act: TArr, ) -> np.ndarray: """Map raw network output to action range in gym's env.action_space. This function is called in :meth:`~tianshou.data.Collector.collect` and only affects action sending to env. Remapped action will not be stored in buffer and thus can be viewed as a part of env (a black box action transformation). Action mapping includes 2 standard procedures: bounding and scaling. Bounding procedure expects original action range is (-inf, inf) and maps it to [-1, 1], while scaling procedure expects original action range is (-1, 1) and maps it to [action_space.low, action_space.high]. Bounding procedure is applied first. :param act: a data batch or numpy.ndarray which is the action taken by policy.forward. :return: action in the same form of input "act" but remap to the target action space. """ act = self._action_to_numpy(act) if isinstance(self.action_space, gym.spaces.Box): if self.action_bound_method == "clip": act = np.clip(act, -1.0, 1.0) elif self.action_bound_method == "tanh": act = np.tanh(act) if self.action_scaling: assert np.min(act) >= -1.0 and np.max(act) <= 1.0, ( f"action scaling only accepts raw action range = [-1, 1], but got: {act}" ) low, high = self.action_space.low, self.action_space.high act = low + (high - low) * (act + 1.0) / 2.0 return act def map_action_inverse( self, act: TArr, ) -> np.ndarray: """Inverse operation to :meth:`map_action`. This function is called in :meth:`~tianshou.data.Collector.collect` for random initial steps. It scales [action_space.low, action_space.high] to the value ranges of policy.forward. 
:param act: a data batch, list or numpy.ndarray which is the action taken by gym.spaces.Box.sample(). :return: action remapped. """ act = self._action_to_numpy(act) if isinstance(self.action_space, gym.spaces.Box): if self.action_scaling: low, high = self.action_space.low, self.action_space.high scale = high - low eps = np.finfo(np.float32).eps.item() scale[scale < eps] += eps act = (act - low) * 2.0 / scale - 1.0 if self.action_bound_method == "tanh": act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 return act def compute_action( self, obs: ArrayLike, info: dict[str, Any] | None = None, state: dict | BatchProtocol | np.ndarray | None = None, ) -> np.ndarray | int: """Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info. :param obs: observation from the gym's env. :param info: information given by the gym's env. :param state: the hidden state of RNN policy, used for recurrent policy. :return: action as int (for discrete env's) or array (for continuous ones). """ obs = np.array(obs) # convert array-like to array (e.g. 
LazyFrames) obs = obs[None, :] # add batch dimension obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info)) act = self.forward(obs_batch, state=state).act.squeeze() if isinstance(act, torch.Tensor): act = act.detach().cpu().numpy() act = self.map_action(act) if isinstance(self.action_space, Discrete): # could be an array of shape (), easier to just convert to int act = int(act) # type: ignore return act @staticmethod def _compile() -> None: f64 = np.array([0, 1], dtype=np.float64) f32 = np.array([0, 1], dtype=np.float32) b = np.array([False, True], dtype=np.bool_) i64 = np.array([[0, 1]], dtype=np.int64) _gae(f64, f64, f64, b, 0.1, 0.1) _gae(f32, f32, f64, b, 0.1, 0.1) _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") def add_exploration_noise( self, act: _TArrOrActBatch, batch: ObsBatchProtocol, ) -> _TArrOrActBatch: """(Optionally) adds noise to an actions computed by the policy's forward method for exploration purposes. NOTE: The base implementation does not add any noise, but subclasses can override this method to add appropriate mechanisms for adding noise. :param act: a data batch or numpy.ndarray containing actions computed by the policy's forward method. :param batch: the corresponding input batch that was passed to forward; provided for advanced usage. :return: actions in the same format as the input `act` but with added exploration noise (if implemented - otherwise returns `act` unchanged). """ return act class LaggedNetworkAlgorithmMixin(ABC): """ Base class for an algorithm mixin which adds support for lagged networks (target networks) whose weights are updated periodically. """ def __init__(self) -> None: self._lagged_networks = LaggedNetworkCollection() def _add_lagged_network(self, src: torch.nn.Module) -> EvalModeModuleWrapper: """ Adds a lagged network to the collection, returning the target network, which is forced to eval mode. 
The target network is a copy of the source network, which, however, supports only the forward method (hence the type torch.nn.Module); attribute access is not supported. :param src: the source network whose parameters are to be copied to the target network :return: the target network, which supports only the forward method and is forced to eval mode """ return self._lagged_networks.add_lagged_network(src) @abstractmethod def _update_lagged_network_weights(self) -> None: pass class LaggedNetworkFullUpdateAlgorithmMixin(LaggedNetworkAlgorithmMixin): """ Algorithm mixin which adds support for lagged networks (target networks) where weights are updated by fully copying the weights of the source network to the target network. """ def _update_lagged_network_weights(self) -> None: self._lagged_networks.full_parameter_update() class LaggedNetworkPolyakUpdateAlgorithmMixin(LaggedNetworkAlgorithmMixin): """ Algorithm mixin which adds support for lagged networks (target networks) where weights are updated via Polyak averaging (soft update using a convex combination of the parameters of the source and target networks with weight `tau` and `1-tau` respectively). """ def __init__(self, tau: float) -> None: """ :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being the fraction with which to retain the target network's parameters. """ super().__init__() self.tau = tau def _update_lagged_network_weights(self) -> None: self._lagged_networks.polyak_parameter_update(self.tau) TPolicy = TypeVar("TPolicy", bound=Policy) TTrainerParams = TypeVar("TTrainerParams", bound="TrainerParams") class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainerParams], ABC): """ The base class for reinforcement learning algorithms in Tianshou. An algorithm critically defines how to update the parameters of neural networks based on a batch data, optionally applying pre-processing and post-processing to the data. 
The actual update step is highly algorithm-specific and thus is defined in subclasses. """ _STATE_DICT_KEY_OPTIMIZERS = "_optimizers" def __init__( self, *, policy: TPolicy, ) -> None: """:param policy: the policy""" super().__init__() self.policy: TPolicy = policy self.lr_schedulers: list[LRScheduler] = [] self._optimizers: list[Algorithm.Optimizer] = [] """ list of optimizers associated with the algorithm (created via `_create_optimizer`), whose states will be returned when calling `state_dict` and which will be restored when calling `load_state_dict` accordingly """ class Optimizer: """Wrapper for a torch optimizer that optionally performs gradient clipping.""" def __init__( self, optim: torch.optim.Optimizer, module: torch.nn.Module, max_grad_norm: float | None = None, ) -> None: """ :param optim: the optimizer :param module: the module whose parameters are being affected by `optim` :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. This prevents exploding gradients and stabilizes training by limiting the magnitude of parameter updates. Set to None to disable gradient clipping. """ super().__init__() self._optim = optim self._module = module self._max_grad_norm = max_grad_norm def step( self, loss: torch.Tensor, retain_graph: bool | None = None, create_graph: bool = False, ) -> None: """Performs an optimizer step, optionally applying gradient clipping (if configured at construction). 
:param loss: the loss to backpropagate :param retain_graph: passed on to `backward` :param create_graph: passed on to `backward` """ self._optim.zero_grad() loss.backward(retain_graph=retain_graph, create_graph=create_graph) if self._max_grad_norm is not None: nn.utils.clip_grad_norm_(self._module.parameters(), max_norm=self._max_grad_norm) self._optim.step() def state_dict(self) -> dict: """Returns the `state_dict` of the wrapped optimizer.""" return self._optim.state_dict() def load_state_dict(self, state_dict: dict) -> None: """Loads the given `state_dict` into the wrapped optimizer.""" self._optim.load_state_dict(state_dict) def _create_optimizer( self, module: torch.nn.Module, factory: OptimizerFactory, max_grad_norm: float | None = None, ) -> Optimizer: optimizer, lr_scheduler = factory.create_instances(module) if lr_scheduler is not None: self.lr_schedulers.append(lr_scheduler) optim = self.Optimizer(optimizer, module, max_grad_norm=max_grad_norm) self._optimizers.append(optim) return optim def state_dict(self, *args, destination=None, prefix="", keep_vars=False): # type: ignore d = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) # add optimizer states opt_key = prefix + self._STATE_DICT_KEY_OPTIMIZERS assert opt_key not in d d[opt_key] = [o.state_dict() for o in self._optimizers] return d def load_state_dict( self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False ) -> _IncompatibleKeys: # don't override type in annotation since it's is declared as Mapping in nn.Module state_dict = cast(dict[str, Any], state_dict) # restore optimizer states optimizers_state_dict = state_dict.pop(self._STATE_DICT_KEY_OPTIMIZERS) for optim, optim_state in zip(self._optimizers, optimizers_state_dict, strict=True): optim.load_state_dict(optim_state) return super().load_state_dict(state_dict, strict=strict, assign=assign) def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: 
np.ndarray, ) -> RolloutBatchProtocol: """Pre-process the data from the provided replay buffer. Meant to be overridden by subclasses. Typical usage is to add new keys to the batch, e.g., to add the value function of the next state. Used in :meth:`update`, which is usually called repeatedly during training. For modifying the replay buffer only once at the beginning (e.g., for offline learning) see :meth:`process_buffer`. """ return batch def _postprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: """Post-process the data from the provided replay buffer. This will only have an effect if the buffer has the method `update_weight` and the batch has the attribute `weight`. Typical usage is to update the sampling weight in prioritized experience replay. Used in :meth:`update`. """ if hasattr(buffer, "update_weight"): if hasattr(batch, "weight"): buffer.update_weight(indices, batch.weight) else: logger.warning( "batch has no attribute 'weight', but buffer has an " "update_weight method. This is probably a mistake." "Prioritized replay is disabled for this batch.", ) def _update( self, sample_size: int | None, buffer: ReplayBuffer | None, update_with_batch_fn: Callable[[RolloutBatchProtocol], TrainingStats], ) -> TrainingStats: """Orchestrates an update step. An update involves three algorithm-specific sub-steps: * pre-processing of the batch, * performing the actual network update with the batch, and * post-processing of the batch. The return value is that of the network update call, augmented with the training time within update. :param sample_size: 0 means it will extract all the data from the buffer, otherwise it will sample a batch with given sample_size. None also means it will extract all the data from the buffer, but it will be shuffled first. :param buffer: the corresponding replay buffer. 
:param update_with_batch_fn: the function to call for the actual update step, which is algorithm-specific and thus provided by the subclass. :return: A dataclass object containing data to be logged (e.g., loss) """ if not self.policy.is_within_training_step: raise RuntimeError( f"update() was called outside of a training step as signalled by {self.policy.is_within_training_step=} " f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned " f"flag yourself. You can to this e.g., by using the contextmanager {policy_within_training_step.__name__}.", ) if buffer is None: return TrainingStats() start_time = time.time() batch, indices = buffer.sample(sample_size) TraceLogger.log(logger, lambda: f"Updating with batch: indices={pickle_hash(indices)}") batch = self._preprocess_batch(batch, buffer, indices) with torch_train_mode(self): training_stat = update_with_batch_fn(batch) self._postprocess_batch(batch, buffer, indices) for lr_scheduler in self.lr_schedulers: lr_scheduler.step() training_stat.train_time = time.time() - start_time return training_stat @staticmethod def value_mask(buffer: ReplayBuffer, indices: np.ndarray) -> np.ndarray: """Value mask determines whether the obs_next of buffer[indices] is valid. For instance, usually "obs_next" after "done" flag is considered to be invalid, and its q/advantage value can provide meaningless (even misleading) information, and should be set to 0 by hand. But if "done" flag is generated because timelimit of game length (info["TimeLimit.truncated"] is set to True in gym's settings), "obs_next" will instead be valid. Value mask is typically used for assisting in calculating the correct q/advantage value. :param buffer: the corresponding replay buffer. :param numpy.ndarray indices: indices of replay buffer whose "obs_next" will be judged. :return: A bool type numpy.ndarray in the same shape with indices. "True" means "obs_next" of that buffer[indices] is valid. 
""" return ~buffer.terminated[indices] @staticmethod def compute_episodic_return( batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, v_s_: np.ndarray | torch.Tensor | None = None, v_s: np.ndarray | torch.Tensor | None = None, gamma: float = 0.99, gae_lambda: float = 0.95, ) -> tuple[np.ndarray, np.ndarray]: r"""Compute returns over given batch. Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) to calculate q/advantage value of given batch. Returns are calculated as advantage + value, which is exactly equivalent to using :math:`TD(\lambda)` for estimating returns. Setting `v_s_` and `v_s` to None (or all zeros) and `gae_lambda` to 1.0 calculates the discounted return-to-go/ Monte-Carlo return. :param batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be recognized by buffer.unfinished_index(). :param buffer: the corresponding replay buffer. :param indices: tells the batch's location in buffer, batch is equal to buffer[indices]. :param v_s_: the value function of all next states :math:`V(s')`. If None, it will be set to an array of 0. :param v_s: the value function of all current states :math:`V(s)`. If None, it is set based upon `v_s_` rolled by 1. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. :return: two numpy arrays (returns, advantage) with each shape (bsz, ). """ rew = batch.rew if v_s_ is None: assert np.isclose(gae_lambda, 1.0) v_s_ = np.zeros_like(rew) else: v_s_ = to_numpy(v_s_.flatten()) v_s_ = v_s_ * Algorithm.value_mask(buffer, indices) v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten()) end_flag = np.logical_or(batch.terminated, batch.truncated) end_flag[np.isin(indices, buffer.unfinished_index())] = True advantage = _gae(v_s, v_s_, rew, end_flag, gamma, gae_lambda) returns = advantage + v_s # normalization varies from each policy, so we don't do it here return returns, advantage @staticmethod def compute_nstep_return( batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor], gamma: float = 0.99, n_step: int = 1, ) -> BatchWithReturnsProtocol: r""" Computes the n-step return for Q-learning targets, adds it to the batch and returns the resulting batch. .. 
math:: G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n}) where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step :math:`t`. :param batch: a data batch, which is equal to buffer[indices]. :param buffer: the data buffer. :param indices: tell batch's location in buffer :param target_q_fn: a function which computes the target Q value of "obs_next" given data buffer and wanted indices (`n_step` steps ahead). :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step: the number of estimation step, should be an int greater than 0. :return: a Batch. The result will be stored in `batch.returns` as a torch.Tensor with the same shape as target_q_fn's return tensor. """ if len(indices) != len(batch): raise ValueError(f"Batch size {len(batch)} and indices size {len(indices)} mismatch.") # naming convention # I = number of indices # B = size of the replay buffer # N = n_step # A = the output dimension of target_q_fn for a single index. Presumably # this is the number of actions in the discrete case, or something like that. # 1 = 1 extra dimension # TODO: it's very weird that this is not always one! # We set the n-step-return for a single index to be the same shape as the target_q_fn. 
# I don't understand how a non-scalar value would make sense there, but such cases are covered by tests # support in following naming convention I = len(indices) N = n_step _indices_to_stack = [indices] for _ in range(N - 1): next_indices = buffer.next(_indices_to_stack[-1]) _indices_to_stack.append(next_indices) stacked_indices_NI = np.stack(_indices_to_stack) """The stacked indices represent a 2d array of shape `IxN` of the type [ [i_1, i_2,...], [i_(next(1)), i_(next(2)), ...], [i_(next(next(1)), ... ... ] where `next` is the subsequent transition in the buffer. """ indices_after_n_steps_I = stacked_indices_NI[-1] """Indicates indexes of transitions in buffer that occur N steps after the user provided 'indices'; they are truncated at the end of each episode""" with torch.no_grad(): target_q_torch_IA = target_q_fn(buffer, indices_after_n_steps_I) target_q_IA = to_numpy(target_q_torch_IA.reshape(I, -1)) """Represents the Q-values (one for each action) of the transition after N steps.""" target_q_IA *= Algorithm.value_mask(buffer, indices_after_n_steps_I).reshape(-1, 1) end_flag_B = buffer.done.copy() end_flag_B[buffer.unfinished_index()] = True n_step_return_IA = _nstep_return( buffer.rew, end_flag_B, target_q_IA, stacked_indices_NI, gamma, n_step, ) """The n-step return plus the last Q-values, see method's docstring""" batch.returns = to_torch_as(n_step_return_IA, target_q_torch_IA) # TODO: this is simply converting to a certain type. Why is this necessary, and why is it happening here? 
if hasattr(batch, "weight"): batch.weight = to_torch_as(batch.weight, target_q_torch_IA) return cast(BatchWithReturnsProtocol, batch) @abstractmethod def create_trainer(self, params: TTrainerParams) -> "Trainer": pass def run_training(self, params: TTrainerParams) -> "InfoStats": trainer = self.create_trainer(params) return trainer.run() class OnPolicyAlgorithm( Algorithm[TPolicy, "OnPolicyTrainerParams"], Generic[TPolicy], ABC, ): """Base class for on-policy RL algorithms.""" def create_trainer(self, params: "OnPolicyTrainerParams") -> "OnPolicyTrainer": from tianshou.trainer import OnPolicyTrainer return OnPolicyTrainer(self, params) @abstractmethod def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> TrainingStats: """Performs an update step based on the given batch of data, updating the network parameters. :param batch: the batch of data :param batch_size: the minibatch size for gradient updates :param repeat: the number of times to repeat the update over the whole batch :return: a dataclas object containing statistics on the learning process, including the data needed to be logged (e.g. loss values). """ def update( self, buffer: ReplayBuffer, batch_size: int | None, repeat: int, ) -> TrainingStats: update_with_batch_fn = lambda batch: self._update_with_batch( batch=batch, batch_size=batch_size, repeat=repeat ) return super()._update( sample_size=0, buffer=buffer, update_with_batch_fn=update_with_batch_fn ) class OffPolicyAlgorithm( Algorithm[TPolicy, "OffPolicyTrainerParams"], Generic[TPolicy], ABC, ): """Base class for off-policy RL algorithms.""" def create_trainer(self, params: "OffPolicyTrainerParams") -> "OffPolicyTrainer": from tianshou.trainer import OffPolicyTrainer return OffPolicyTrainer(self, params) @abstractmethod def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> TrainingStats: """Performs an update step based on the given batch of data, updating the network parameters. 
:param batch: the batch of data :return: a dataclas object containing statistics on the learning process, including the data needed to be logged (e.g. loss values). """ def update( self, buffer: ReplayBuffer, sample_size: int | None, ) -> TrainingStats: update_with_batch_fn = lambda batch: self._update_with_batch(batch) return super()._update( sample_size=sample_size, buffer=buffer, update_with_batch_fn=update_with_batch_fn, ) class OfflineAlgorithm( Algorithm[TPolicy, "OfflineTrainerParams"], Generic[TPolicy], ABC, ): """Base class for offline RL algorithms.""" def process_buffer(self, buffer: TBuffer) -> TBuffer: """Pre-process the replay buffer to prepare for offline learning, e.g. to add new keys.""" return buffer def run_training(self, params: "OfflineTrainerParams") -> "InfoStats": # NOTE: This override is required for correct typing when converting # an algorithm to an offline algorithm using diamond inheritance # (e.g. DiscreteCQL) in order to make it match first in the MRO return super().run_training(params) def create_trainer(self, params: "OfflineTrainerParams") -> "OfflineTrainer": from tianshou.trainer import OfflineTrainer return OfflineTrainer(self, params) @abstractmethod def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> TrainingStats: """Performs an update step based on the given batch of data, updating the network parameters. :param batch: the batch of data :return: a dataclas object containing statistics on the learning process, including the data needed to be logged (e.g. loss values). """ def update( self, buffer: ReplayBuffer, sample_size: int | None, ) -> TrainingStats: update_with_batch_fn = lambda batch: self._update_with_batch(batch) return super()._update( sample_size=sample_size, buffer=buffer, update_with_batch_fn=update_with_batch_fn, ) class OnPolicyWrapperAlgorithm( OnPolicyAlgorithm[TPolicy], Generic[TPolicy], ABC, ): """ Base class for an on-policy algorithm that is a wrapper around another algorithm. 
It applies the wrapped algorithm's pre-processing and post-processing methods and chains the update method of the wrapped algorithm with the wrapper's own update method. """ def __init__( self, wrapped_algorithm: OnPolicyAlgorithm[TPolicy], ): super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: """Performs the pre-processing as defined by the wrapped algorithm.""" return self.wrapped_algorithm._preprocess_batch(batch, buffer, indices) def _postprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: """Performs the batch post-processing as defined by the wrapped algorithm.""" self.wrapped_algorithm._postprocess_batch(batch, buffer, indices) def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> TrainingStats: """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update.""" original_stats = self.wrapped_algorithm._update_with_batch( batch, batch_size=batch_size, repeat=repeat ) return self._wrapper_update_with_batch(batch, batch_size, repeat, original_stats) @abstractmethod def _wrapper_update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, original_stats: TrainingStats, ) -> TrainingStats: pass class OffPolicyWrapperAlgorithm( OffPolicyAlgorithm[TPolicy], Generic[TPolicy], ABC, ): """ Base class for an off-policy algorithm that is a wrapper around another algorithm. It applies the wrapped algorithm's pre-processing and post-processing methods and chains the update method of the wrapped algorithm with the wrapper's own update method. 
""" def __init__( self, wrapped_algorithm: OffPolicyAlgorithm[TPolicy], ): super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: """Performs the pre-processing as defined by the wrapped algorithm.""" return self.wrapped_algorithm._preprocess_batch(batch, buffer, indices) def _postprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: """Performs the batch post-processing as defined by the wrapped algorithm.""" self.wrapped_algorithm._postprocess_batch(batch, buffer, indices) def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> TrainingStats: """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update .""" original_stats = self.wrapped_algorithm._update_with_batch(batch) return self._wrapper_update_with_batch(batch, original_stats) @abstractmethod def _wrapper_update_with_batch( self, batch: RolloutBatchProtocol, original_stats: TrainingStats ) -> TrainingStats: pass class RandomActionPolicy(Policy): def __init__( self, action_space: gym.Space, ) -> None: super().__init__(action_space=action_space) if not isinstance(action_space, gym.spaces.Discrete | gym.spaces.Box): raise NotImplementedError( f"RandomActionPolicy currently only supports Discrete and Box action spaces, but got {action_space}.", ) self.actor = RandomActor(action_space) def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ActStateBatchProtocol: act, next_state = self.actor.compute_action_batch(batch.obs), state return cast(ActStateBatchProtocol, Batch(act=act, state=next_state)) @njit def _gae( v_s: np.ndarray, v_s_: np.ndarray, rew: np.ndarray, end_flag: np.ndarray, gamma: float, gae_lambda: float, ) -> np.ndarray: r"""Computes advantages with GAE. 
The return is given by the output of this + v_s. Note that the advantages plus v_s is exactly the same as the TD-lambda target, which is computed by the recursive formula: .. math:: G_t^\lambda = r_t + \gamma ( \lambda G_{t+1}^\lambda + (1 - \lambda) V_{t+1} ) The GAE is computed recursively as: .. math:: \delta_t = r_t + \gamma V_{t+1} - V_t \n A_t^\lambda= \delta_t + \gamma \lambda A_{t+1}^\lambda And the following equality holds: .. math:: G_t^\lambda = A_t^\lambda+ V_t :param v_s: values in an episode, i.e. $V_t$ :param v_s_: next values in an episode, i.e. v_s shifted by 1, equivalent to $V_{t+1}$ :param rew: rewards in an episode, i.e. $r_t$ :param end_flag: boolean array indicating whether the episode is done :param gamma: the discount factor in [0, 1] for future rewards. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. 
:return: """ returns = np.zeros(rew.shape) delta = rew + v_s_ * gamma - v_s discount = (1.0 - end_flag) * (gamma * gae_lambda) gae = 0.0 for i in range(len(rew) - 1, -1, -1): gae = delta[i] + discount[i] * gae returns[i] = gae return returns @njit def episode_mc_return_to_go(rewards: np.ndarray, gamma: float = 0.99) -> np.ndarray: """Calculates discounted monte-carlo returns to go from rewards of a single episode. :param rewards: rewards of a single episode. Assumed to be a 1-dim array from reset till the end of the episode. :param gamma: discount factor :return: a numpy array of shape (len(rewards), ). """ len_episode = len(rewards) ret2go = np.zeros(len_episode) ret2go[-1] = rewards[-1] for j in range(len_episode - 2, -1, -1): ret2go[j] = rewards[j] + gamma * ret2go[j + 1] return ret2go @njit def _nstep_return( rew_B: np.ndarray, end_flag_B: np.ndarray, target_q_IA: np.ndarray, stacked_indices_NI: np.ndarray, gamma: float, n_step: int, ) -> np.ndarray: """Computes n-step returns starting at the transitions at the selected indices in the buffer. Importantly, this is not a pure MC n-step return but it also uses the Q-values of the obs-action pair after the n-step transition to compute the return. Thus, it computes `n_step_return + gamma^(n) * Q(s_{t+n}, a_{t+n})` where `n_step_return = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1}`. See the docstring of `compute_nstep_return` for more details. The target_q_B should be the array of `Q(s_{t+n}, a_{t+n})` corresponding to the batch of rewards that started at t=0. Notation: I = number of indices B = size of the replay buffer N = n_step A = the output dimension of target_q_fn for a single index. Presumably, this is the number of actions in the discrete case, or something like that. See comments in the method `compute_nstep_return` for more details. 
1 = 1 extra dimension :param rew_B: rewards of the entire replay buffer :param end_flag_B: end flags (where done=True) of the entire replay buffer :param target_q_IA: Q-values of the transitions after n steps. Passed as a 2d array of shape (I, A) :param stacked_indices_NI: indices of the transitions in the buffer of the structure [ [i_1, i_2,...], [i_(next(1)), i_(next(2)), ...], [i_(next(next(1)), ... ... ] where `next` is the subsequent transition in the buffer. """ N = n_step I, A = target_q_IA.shape gamma_buffer_N = np.ones(N + 1) for i in range(1, N + 1): gamma_buffer_N[i] = gamma_buffer_N[i - 1] * gamma target_q_IA = target_q_IA.reshape(I, -1) """Make sure tarqet_q_I has an empty extra dimension, usually already passed with the right shape, hence the input param name""" n_step_mc_returns_IA = np.zeros(target_q_IA.shape) """Will hold the n_step MC return part of the final n_step + Q-value return. """ gammas_IN = np.full(I, N) for n in range(N - 1, -1, -1): now = stacked_indices_NI[n] gammas_IN[end_flag_B[now] > 0] = n + 1 n_step_mc_returns_IA[end_flag_B[now] > 0] = 0.0 n_step_mc_returns_IA = rew_B[now].reshape(I, 1) + gamma * n_step_mc_returns_IA n_step_return_with_Q_IA = ( target_q_IA * gamma_buffer_N[gammas_IN].reshape(I, 1) + n_step_mc_returns_IA ) return n_step_return_with_Q_IA.reshape((I, A)) ================================================ FILE: tianshou/algorithm/imitation/__init__.py ================================================ ================================================ FILE: tianshou/algorithm/imitation/bcq.py ================================================ import copy from dataclasses import dataclass from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np import torch import torch.nn.functional as F from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, Policy, TrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory from tianshou.data 
import Batch, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.utils.net.continuous import VAE @dataclass(kw_only=True) class BCQTrainingStats(TrainingStats): actor_loss: float critic1_loss: float critic2_loss: float vae_loss: float TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) class BCQPolicy(Policy): def __init__( self, *, actor_perturbation: torch.nn.Module, action_space: gym.Space, critic: torch.nn.Module, vae: VAE, forward_sampled_times: int = 100, observation_space: gym.Space | None = None, action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", ) -> None: """ :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` :param critic: the first critic network. :param vae: the VAE network, generating actions similar to those in batch. :param forward_sampled_times: the number of sampled actions in forward function. The policy samples many actions and takes the action with the max value. :param observation_space: the environment's observation space :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. 
:param action_bound_method: the method used for bounding actions in continuous action spaces to the range [-1, 1] before scaling them to the environment's action space (provided that `action_scaling` is enabled). This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None for discrete spaces. When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly constrains outputs to [-1, 1] while preserving gradients. The choice of bounding method affects both training dynamics and exploration behavior. Clipping provides hard boundaries but may create plateau regions in the gradient landscape, while tanh provides smoother transitions but can compress sensitivity near the boundaries. Should be set to None if the actor model inherently produces bounded outputs. Typically used together with `action_scaling=True`. """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, action_bound_method=action_bound_method, ) self.actor_perturbation = actor_perturbation self.critic = critic self.vae = vae self.forward_sampled_times = forward_sampled_times def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ActBatchProtocol: """Compute action over the given batch data.""" # There is "obs" in the Batch # obs_group: several groups. Each group has a state. 
device = next(self.parameters()).device obs_group: torch.Tensor = to_torch(batch.obs, device=device) act_group = [] for obs_orig in obs_group: # now obs is (state_dim) obs = (obs_orig.reshape(1, -1)).repeat(self.forward_sampled_times, 1) # now obs is (forward_sampled_times, state_dim) # decode(obs) generates action and actor perturbs it act = self.actor_perturbation(obs, self.vae.decode(obs)) # now action is (forward_sampled_times, action_dim) q1 = self.critic(obs, act) # q1 is (forward_sampled_times, 1) max_indice = q1.argmax(0) act_group.append(act[max_indice].cpu().data.numpy().flatten()) act_group = np.array(act_group) return cast(ActBatchProtocol, Batch(act=act_group)) class BCQ( OfflineAlgorithm[BCQPolicy], LaggedNetworkPolyakUpdateAlgorithmMixin, ): """Implementation of Batch-Constrained Deep Q-learning (BCQ) algorithm. arXiv:1812.02900.""" def __init__( self, *, policy: BCQPolicy, actor_perturbation_optim: OptimizerFactory, critic_optim: OptimizerFactory, vae_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, critic2_optim: OptimizerFactory | None = None, gamma: float = 0.99, tau: float = 0.005, lmbda: float = 0.75, num_sampled_action: int = 10, ) -> None: """ :param policy: the policy :param actor_perturbation_optim: the optimizer factory for the policy's actor perturbation network. :param critic_optim: the optimizer factory for the policy's critic network. :param critic2: the second critic network; if None, clone the critic from the policy :param critic2_optim: the optimizer factory for the second critic network; if None, use optimizer factory of first critic :param vae_optim: the optimizer factory for the VAE network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param lmbda: param for Clipped Double Q-learning. :param num_sampled_action: the number of sampled actions in calculating target Q. The algorithm samples several actions using VAE, and perturbs each action to get the target Q. """ # actor is Perturbation! 
super().__init__( policy=policy, ) LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) self.actor_perturbation_target = self._add_lagged_network(self.policy.actor_perturbation) self.actor_perturbation_optim = self._create_optimizer( self.policy.actor_perturbation, actor_perturbation_optim ) self.critic_target = self._add_lagged_network(self.policy.critic) self.critic_optim = self._create_optimizer(self.policy.critic, critic_optim) self.critic2 = critic2 or copy.deepcopy(self.policy.critic) self.critic2_target = self._add_lagged_network(self.critic2) self.critic2_optim = self._create_optimizer(self.critic2, critic2_optim or critic_optim) self.vae_optim = self._create_optimizer(self.policy.vae, vae_optim) self.gamma = gamma self.lmbda = lmbda self.num_sampled_action = num_sampled_action def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> BCQTrainingStats: # batch: obs, act, rew, done, obs_next. (numpy array) # (batch_size, state_dim) # TODO: This does not use policy.forward but computes things directly, which seems odd device = next(self.parameters()).device batch: Batch = to_torch(batch, dtype=torch.float, device=device) obs, act = batch.obs, batch.act batch_size = obs.shape[0] # mean, std: (state.shape[0], latent_dim) recon, mean, std = self.policy.vae(obs, act) recon_loss = F.mse_loss(act, recon) # (....) 
is D_KL( N(mu, sigma) || N(0,1) ) KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean() vae_loss = recon_loss + KL_loss / 2 self.vae_optim.step(vae_loss) # critic training: with torch.no_grad(): # repeat num_sampled_action times obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0) # now obs_next: (num_sampled_action * batch_size, state_dim) # perturbed action generated by VAE act_next = self.policy.vae.decode(obs_next) # now obs_next: (num_sampled_action * batch_size, action_dim) target_Q1 = self.critic_target(obs_next, act_next) target_Q2 = self.critic2_target(obs_next, act_next) # Clipped Double Q-learning target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1 - self.lmbda) * torch.max( target_Q1, target_Q2, ) # now target_Q: (num_sampled_action * batch_size, 1) # the max value of Q target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1) # now target_Q: (batch_size, 1) target_Q = ( batch.rew.reshape(-1, 1) + torch.logical_not(batch.done).reshape(-1, 1) * self.gamma * target_Q ) target_Q = target_Q.float() current_Q1 = self.policy.critic(obs, act) current_Q2 = self.critic2(obs, act) critic1_loss = F.mse_loss(current_Q1, target_Q) critic2_loss = F.mse_loss(current_Q2, target_Q) self.critic_optim.step(critic1_loss) self.critic2_optim.step(critic2_loss) sampled_act = self.policy.vae.decode(obs) perturbed_act = self.policy.actor_perturbation(obs, sampled_act) # max actor_loss = -self.policy.critic(obs, perturbed_act).mean() self.actor_perturbation_optim.step(actor_loss) # update target networks self._update_lagged_network_weights() return BCQTrainingStats( actor_loss=actor_loss.item(), critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), vae_loss=vae_loss.item(), ) ================================================ FILE: tianshou/algorithm/imitation/cql.py ================================================ from copy import deepcopy from dataclasses import dataclass from typing import cast 
import numpy as np import torch import torch.nn.functional as F from overrides import override from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, ) from tianshou.algorithm.modelfree.sac import Alpha, SACPolicy, SACTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.buffer.buffer_base import TBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.conversion import to_optional_float from tianshou.utils.torch_utils import torch_device @dataclass(kw_only=True) class CQLTrainingStats(SACTrainingStats): """A data structure for storing loss statistics of the CQL learn step.""" cql_alpha: float | None = None cql_alpha_loss: float | None = None # TODO: Perhaps SACPolicy should get a more generic name class CQL(OfflineAlgorithm[SACPolicy], LaggedNetworkPolyakUpdateAlgorithmMixin): """Implementation of the conservative Q-learning (CQL) algorithm. arXiv:2006.04779.""" def __init__( self, *, policy: SACPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, critic2_optim: OptimizerFactory | None = None, cql_alpha_lr: float = 1e-4, cql_weight: float = 1.0, tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, temperature: float = 1.0, with_lagrange: bool = True, lagrange_threshold: float = 10.0, min_action: float = -1.0, max_action: float = 1.0, num_repeat_actions: int = 10, alpha_min: float = 0.0, alpha_max: float = 1e6, max_grad_norm: float = 1.0, calibrated: bool = True, ) -> None: """ :param actor: the actor network following the rules (s -> a) :param policy_optim: the optimizer factory for the policy/its actor network. :param critic: the first critic network. :param critic_optim: the optimizer factory for the first critic network. :param action_space: the environment's action space. 
:param critic2: the second critic network. (s, a -> Q(s, a)). If None, use the same network as critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, clone the first critic's optimizer factory. :param cql_alpha_lr: the learning rate for the Lagrange multiplier optimization. Controls how quickly the CQL regularization coefficient (alpha) adapts during training. Higher values allow faster adaptation but may cause instability in the training process. Lower values provide more stable but slower adaptation of the regularization strength. Only relevant when with_lagrange=True. :param cql_weight: the coefficient that scales the conservative regularization term in the Q-function loss. Controls the strength of the conservative Q-learning component relative to standard TD learning. Higher values enforce more conservative value estimates by penalizing overestimation more strongly. Lower values allow the algorithm to behave more like standard Q-learning. Increasing this weight typically improves performance in purely offline settings where overestimation bias can lead to poor policy extraction. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). :param temperature: the temperature parameter used in the LogSumExp calculation of the CQL loss. Controls the sharpness of the softmax distribution when computing the expected Q-values. Lower values make the LogSumExp operation more selective, focusing on the highest Q-values. Higher values make the operation closer to an average, giving more weight to all Q-values. The temperature affects how conservatively the algorithm penalizes out-of-distribution actions. :param with_lagrange: a flag indicating whether to automatically tune the CQL regularization strength. If True, uses Lagrangian dual gradient descent to dynamically adjust the CQL alpha parameter. This formulation maintains the CQL regularization loss near the lagrange_threshold value. Adaptive tuning helps balance conservative learning against excessive pessimism. If False, the conservative loss is scaled by a fixed cql_weight throughout training. The original CQL paper recommends setting this to True for most offline RL tasks. :param lagrange_threshold: the target value for the CQL regularization loss when using Lagrangian optimization. When with_lagrange=True, the algorithm dynamically adjusts the CQL alpha parameter to maintain the regularization loss close to this threshold. Lower values result in more conservative behavior by enforcing stronger penalties on out-of-distribution actions. Higher values allow more optimistic Q-value estimates similar to standard Q-learning. 
This threshold effectively controls the level of conservatism in CQL's value estimation. :param min_action: the lower bound for each dimension of the action space. Used when sampling random actions for the CQL regularization term. Should match the environment's action space minimum values. These random actions help penalize Q-values for out-of-distribution actions. Typically set to -1.0 for normalized continuous action spaces. :param max_action: the upper bound for each dimension of the action space. Used when sampling random actions for the CQL regularization term. Should match the environment's action space maximum values. These random actions help penalize Q-values for out-of-distribution actions. Typically set to 1.0 for normalized continuous action spaces. :param num_repeat_actions: the number of action samples generated per state when computing the CQL regularization term. Controls how many random and policy actions are sampled for each state in the batch when estimating expected Q-values. Higher values provide more accurate approximation of the expected Q-values but increase computational cost. Lower values reduce computation but may provide less stable or less accurate regularization. The original CQL paper typically uses values around 10. :param alpha_min: the minimum value allowed for the adaptive CQL regularization coefficient. When using Lagrangian optimization (with_lagrange=True), constrains the automatically tuned cql_alpha parameter to be at least this value. Prevents the regularization strength from becoming too small during training. Setting a positive value ensures the algorithm maintains at least some degree of conservatism. Only relevant when with_lagrange=True. :param alpha_max: the maximum value allowed for the adaptive CQL regularization coefficient. When using Lagrangian optimization (with_lagrange=True), constrains the automatically tuned cql_alpha parameter to be at most this value. 
Prevents the regularization strength from becoming too large during training. Setting an appropriate upper limit helps avoid overly conservative behavior that might hinder learning useful value functions. Only relevant when with_lagrange=True. :param max_grad_norm: the maximum L2 norm threshold for gradient clipping when updating critic networks. Gradients with norm exceeding this value will be rescaled to have norm equal to this value. Helps stabilize training by preventing excessively large parameter updates from outlier samples. Higher values allow larger updates but may lead to training instability. Lower values enforce more conservative updates but may slow down learning. Setting to a large value effectively disables gradient clipping. :param calibrated: a flag indicating whether to use the calibrated version of CQL (CalQL). If True, calibrates Q-values by taking the maximum of computed Q-values and Monte Carlo returns. This modification helps address the excessive pessimism problem in standard CQL. Particularly useful for offline pre-training followed by online fine-tuning scenarios. Experimental results suggest this approach often achieves better performance than vanilla CQL. Based on techniques from the CalQL paper (arXiv:2303.05479). 
""" super().__init__( policy=policy, ) LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) device = torch_device(policy) self.policy_optim = self._create_optimizer(self.policy, policy_optim) self.critic = critic self.critic_optim = self._create_optimizer( self.critic, critic_optim, max_grad_norm=max_grad_norm ) self.critic2 = critic2 or deepcopy(critic) self.critic2_optim = self._create_optimizer( self.critic2, critic2_optim or critic_optim, max_grad_norm=max_grad_norm ) self.critic_old = self._add_lagged_network(self.critic) self.critic2_old = self._add_lagged_network(self.critic2) self.gamma = gamma self.alpha = Alpha.from_float_or_instance(alpha) self.temperature = temperature self.with_lagrange = with_lagrange self.lagrange_threshold = lagrange_threshold self.cql_weight = cql_weight self.cql_log_alpha = torch.tensor([0.0], requires_grad=True) # TODO: Use an OptimizerFactory? self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr) self.cql_log_alpha = self.cql_log_alpha.to(device) self.min_action = min_action self.max_action = max_action self.num_repeat_actions = num_repeat_actions self.alpha_min = alpha_min self.alpha_max = alpha_max self.calibrated = calibrated def _policy_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch = Batch(obs=obs, info=[None] * len(obs)) obs_result = self.policy(batch) return obs_result.act, obs_result.log_prob def _calc_policy_loss(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: act_pred, log_pi = self._policy_pred(obs) q1 = self.critic(obs, act_pred) q2 = self.critic2(obs, act_pred) min_Q = torch.min(q1, q2) # self.alpha: float | torch.Tensor actor_loss = (self.alpha.value * log_pi - min_Q).mean() # actor_loss.shape: (), log_pi.shape: (batch_size, 1) return actor_loss, log_pi def _calc_pi_values( self, obs_pi: torch.Tensor, obs_to_pred: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: act_pred, log_pi = self._policy_pred(obs_pi) q1 = 
self.critic(obs_to_pred, act_pred) q2 = self.critic2(obs_to_pred, act_pred) return q1 - log_pi.detach(), q2 - log_pi.detach() def _calc_random_values( self, obs: torch.Tensor, act: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: random_value1 = self.critic(obs, act) random_log_prob1 = np.log(0.5 ** act.shape[-1]) random_value2 = self.critic2(obs, act) random_log_prob2 = np.log(0.5 ** act.shape[-1]) return random_value1 - random_log_prob1, random_value2 - random_log_prob2 @override def process_buffer(self, buffer: TBuffer) -> TBuffer: """If `self.calibrated = True`, adds `calibration_returns` to buffer._meta. :param buffer: :return: """ if self.calibrated: # otherwise _meta hack cannot work assert isinstance(buffer, ReplayBuffer) batch, indices = buffer.sample(0) returns, _ = self.compute_episodic_return( batch=batch, buffer=buffer, indices=indices, gamma=self.gamma, gae_lambda=1.0, ) # TODO: don't access _meta directly buffer._meta = cast( RolloutBatchProtocol, Batch(**buffer._meta.__dict__, calibration_returns=returns), ) return buffer def _update_with_batch(self, batch: RolloutBatchProtocol) -> CQLTrainingStats: device = torch_device(self.policy) batch: Batch = to_torch(batch, dtype=torch.float, device=device) obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next batch_size = obs.shape[0] # compute actor loss and update actor actor_loss, log_pi = self._calc_policy_loss(obs) self.policy_optim.step(actor_loss) entropy = -log_pi.detach() alpha_loss = self.alpha.update(entropy) # compute target_Q with torch.no_grad(): act_next, new_log_pi = self._policy_pred(obs_next) target_Q1 = self.critic_old(obs_next, act_next) target_Q2 = self.critic2_old(obs_next, act_next) target_Q = torch.min(target_Q1, target_Q2) - self.alpha.value * new_log_pi target_Q = rew + torch.logical_not(batch.done) * self.gamma * target_Q.flatten() target_Q = target_Q.float() # shape: (batch_size) # compute critic loss current_Q1 = self.critic(obs, act).flatten() current_Q2 
= self.critic2(obs, act).flatten() # shape: (batch_size) critic1_loss = F.mse_loss(current_Q1, target_Q) critic2_loss = F.mse_loss(current_Q2, target_Q) # CQL random_actions = ( torch.FloatTensor(batch_size * self.num_repeat_actions, act.shape[-1]) .uniform_(-self.min_action, self.max_action) .to(device) ) obs_len = len(obs.shape) repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1) view_size = [batch_size * self.num_repeat_actions, *list(obs.shape[1:])] tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size) tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size) # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim) current_pi_value1, current_pi_value2 = self._calc_pi_values(tmp_obs, tmp_obs) next_pi_value1, next_pi_value2 = self._calc_pi_values(tmp_obs_next, tmp_obs) random_value1, random_value2 = self._calc_random_values(tmp_obs, random_actions) for value in [ current_pi_value1, current_pi_value2, next_pi_value1, next_pi_value2, random_value1, random_value2, ]: value.reshape(batch_size, self.num_repeat_actions, 1) if self.calibrated: returns = ( batch.calibration_returns.unsqueeze(1) .repeat( (1, self.num_repeat_actions), ) .view(-1, 1) ) random_value1 = torch.max(random_value1, returns) random_value2 = torch.max(random_value2, returns) current_pi_value1 = torch.max(current_pi_value1, returns) current_pi_value2 = torch.max(current_pi_value2, returns) next_pi_value1 = torch.max(next_pi_value1, returns) next_pi_value2 = torch.max(next_pi_value2, returns) # cat q values cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1) cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1) # shape: (batch_size, 3 * num_repeat, 1) cql1_scaled_loss = ( torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() * self.cql_weight * self.temperature - current_Q1.mean() * self.cql_weight ) cql2_scaled_loss = ( torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() * self.cql_weight * 
self.temperature - current_Q2.mean() * self.cql_weight ) # shape: (1) cql_alpha_loss = None cql_alpha = None if self.with_lagrange: cql_alpha = torch.clamp( self.cql_log_alpha.exp(), self.alpha_min, self.alpha_max, ) cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.lagrange_threshold) cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.lagrange_threshold) self.cql_alpha_optim.zero_grad() cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5 cql_alpha_loss.backward(retain_graph=True) self.cql_alpha_optim.step() critic1_loss = critic1_loss + cql1_scaled_loss critic2_loss = critic2_loss + cql2_scaled_loss # update critics self.critic_optim.step(critic1_loss, retain_graph=True) self.critic2_optim.step(critic2_loss) self._update_lagged_network_weights() return CQLTrainingStats( actor_loss=to_optional_float(actor_loss), critic1_loss=to_optional_float(critic1_loss), critic2_loss=to_optional_float(critic2_loss), alpha=to_optional_float(self.alpha.value), alpha_loss=to_optional_float(alpha_loss), cql_alpha_loss=to_optional_float(cql_alpha_loss), cql_alpha=to_optional_float(cql_alpha), ) ================================================ FILE: tianshou/algorithm/imitation/discrete_bcq.py ================================================ import math from dataclasses import dataclass from typing import Any, cast import gymnasium as gym import numpy as np import torch import torch.nn.functional as F from torch import nn from tianshou.algorithm.algorithm_base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, ) from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.types import ( BatchWithReturnsProtocol, ImitationBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) float_info = torch.finfo(torch.float32) INF = float_info.max 
@dataclass(kw_only=True) class DiscreteBCQTrainingStats(SimpleLossTrainingStats): q_loss: float i_loss: float reg_loss: float class DiscreteBCQPolicy(DiscreteQLearningPolicy): def __init__( self, *, model: torch.nn.Module, imitator: torch.nn.Module, target_update_freq: int = 8000, unlikely_action_threshold: float = 0.3, action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, eps_inference: float = 0.0, ) -> None: """ :param model: a model following the rules (s_B -> action_values_BA) :param imitator: a model following the rules (s -> imitation_logits) :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. :param unlikely_action_threshold: the threshold (tau) for unlikely actions, as shown in Equ. (17) in the paper. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. 
Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. :param action_space: the environment's action space. :param observation_space: the environment's observation space. :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, i.e. non-training cases (such as evaluation during test steps). The epsilon value is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). """ super().__init__( model=model, action_space=action_space, observation_space=observation_space, eps_training=0.0, # no training data collection (offline) eps_inference=eps_inference, ) self.imitator = imitator assert target_update_freq > 0, ( f"BCQ needs target_update_freq>0 but got: {target_update_freq}." ) assert 0.0 <= unlikely_action_threshold < 1.0, ( f"unlikely_action_threshold should be in [0, 1) but got: {unlikely_action_threshold}" ) if unlikely_action_threshold > 0: self._log_tau = math.log(unlikely_action_threshold) else: self._log_tau = -np.inf def forward( self, batch: ObsBatchProtocol, state: Any | None = None, model: nn.Module | None = None, ) -> ImitationBatchProtocol: if model is None: model = self.model q_value, state = model(batch.obs, state=state, info=batch.info) imitation_logits, _ = self.imitator(batch.obs, state=state, info=batch.info) # mask actions for argmax ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values mask = (ratio < self._log_tau).float() act = (q_value - INF * mask).argmax(dim=-1) result = Batch( act=act, state=state, q_value=q_value, imitation_logits=imitation_logits, logits=imitation_logits, ) return cast(ImitationBatchProtocol, result) class DiscreteBCQ( OfflineAlgorithm[DiscreteBCQPolicy], LaggedNetworkFullUpdateAlgorithmMixin, ): """Implementation of the discrete batch-constrained deep Q-learning (BCQ) algorithm. 
arXiv:1910.01708.""" def __init__( self, *, policy: DiscreteBCQPolicy, optim: OptimizerFactory, gamma: float = 0.99, n_step_return_horizon: int = 1, target_update_freq: int = 8000, imitation_logits_penalty: float = 1e-2, ) -> None: """ :param policy: the policy :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. 
Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. :param imitation_logits_penalty: regularization weight for imitation logits. :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. 
""" super().__init__( policy=policy, ) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.optim = self._create_optimizer(self.policy, optim) assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" self.gamma = gamma assert n_step_return_horizon > 0, ( f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" ) self.n_step = n_step_return_horizon self._target = target_update_freq > 0 self.freq = target_update_freq self._iter = 0 if self._target: self.model_old = self._add_lagged_network(self.policy.model) self._weight_reg = imitation_logits_penalty def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithReturnsProtocol: return self.compute_nstep_return( batch=batch, buffer=buffer, indices=indices, target_q_fn=self._target_q, gamma=self.gamma, n_step=self.n_step, ) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs_next: s_{t+n} next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) # target_Q = Q_old(s_, argmax(Q_new(s_, *))) act = self.policy(next_obs_batch).act target_q, _ = self.model_old(batch.obs_next) return target_q[np.arange(len(act)), act] def _update_with_batch( # type: ignore[override] self, batch: BatchWithReturnsProtocol, ) -> DiscreteBCQTrainingStats: if self._iter % self.freq == 0: self._update_lagged_network_weights() self._iter += 1 target_q = batch.returns.flatten() result = self.policy(batch) imitation_logits = result.imitation_logits current_q = result.q_value[np.arange(len(target_q)), batch.act] act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act) reg_loss = imitation_logits.pow(2).mean() loss = q_loss + i_loss + self._weight_reg * reg_loss self.optim.step(loss) return DiscreteBCQTrainingStats( loss=loss.item(), 
q_loss=q_loss.item(), i_loss=i_loss.item(), reg_loss=reg_loss.item(), ) ================================================ FILE: tianshou/algorithm/imitation/discrete_cql.py ================================================ from dataclasses import dataclass import numpy as np import torch import torch.nn.functional as F from tianshou.algorithm import QRDQN from tianshou.algorithm.algorithm_base import OfflineAlgorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import to_torch from tianshou.data.types import RolloutBatchProtocol @dataclass(kw_only=True) class DiscreteCQLTrainingStats(SimpleLossTrainingStats): cql_loss: float qr_loss: float # NOTE: This uses diamond inheritance to convert from off-policy to offline class DiscreteCQL(OfflineAlgorithm[QRDQNPolicy], QRDQN[QRDQNPolicy]): # type: ignore[misc] """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.""" def __init__( self, *, policy: QRDQNPolicy, optim: OptimizerFactory, min_q_weight: float = 10.0, gamma: float = 0.99, num_quantiles: int = 200, n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ :param policy: the policy :param optim: the optimizer factory for the policy's model. :param min_q_weight: the weight for the cql loss. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. 
""" QRDQN.__init__( self, policy=policy, optim=optim, gamma=gamma, num_quantiles=num_quantiles, n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.min_q_weight = min_q_weight def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> DiscreteCQLTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) all_dist = self.policy(batch).logits act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( (dist_diff * (self.tau_hat - (target_dist - curr_dist).detach().le(0.0).float()).abs()) .sum(-1) .mean(1) ) qr_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer # add CQL loss q = self.policy.compute_q_value(all_dist, None) dataset_expec = q.gather(1, act.unsqueeze(1)).mean() negative_sampling = q.logsumexp(1).mean() min_q_loss = negative_sampling - dataset_expec loss = qr_loss + min_q_loss * self.min_q_weight self.optim.step(loss) return DiscreteCQLTrainingStats( loss=loss.item(), qr_loss=qr_loss.item(), cql_loss=min_q_loss.item(), ) ================================================ FILE: tianshou/algorithm/imitation/discrete_crr.py ================================================ from dataclasses import dataclass from typing import Literal import numpy as np import torch import torch.nn.functional as F from torch.distributions import Categorical from torch.nn import ModuleList from tianshou.algorithm.algorithm_base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, ) from tianshou.algorithm.modelfree.reinforce import ( 
DiscountedReturnComputation, DiscreteActorPolicy, SimpleLossTrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, to_torch, to_torch_as from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.discrete import DiscreteCritic @dataclass class DiscreteCRRTrainingStats(SimpleLossTrainingStats): actor_loss: float critic_loss: float cql_loss: float class DiscreteCRR( OfflineAlgorithm[DiscreteActorPolicy], LaggedNetworkFullUpdateAlgorithmMixin, ): r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.""" def __init__( self, *, policy: DiscreteActorPolicy, critic: torch.nn.Module | DiscreteCritic, optim: OptimizerFactory, gamma: float = 0.99, policy_improvement_mode: Literal["exp", "binary", "all"] = "exp", ratio_upper_bound: float = 20.0, beta: float = 1.0, min_q_weight: float = 10.0, target_update_freq: int = 0, return_standardization: bool = False, ) -> None: r""" :param policy: the policy :param critic: the action-value critic (i.e., Q function) network. (s -> Q(s, \*)) :param optim: the optimizer factory for the policy's actor network and the critic networks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param str policy_improvement_mode: type of the weight function f. Possible values: "binary"/"exp"/"all". 
:param ratio_upper_bound: when policy_improvement_mode is "exp", the value of the exp function is upper-bounded by this parameter. :param beta: when policy_improvement_mode is "exp", this is the denominator of the exp function. :param min_q_weight: weight for CQL loss/regularizer. Default to 10. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. :param return_standardization: whether to standardize episode returns by subtracting the running mean and dividing by the running standard deviation. Note that this is known to be detrimental to performance in many cases! 
""" super().__init__( policy=policy, ) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.discounted_return_computation = DiscountedReturnComputation( gamma=gamma, return_standardization=return_standardization, ) self.critic = critic self.optim = self._create_optimizer(ModuleList([self.policy, self.critic]), optim) self._target = target_update_freq > 0 self._freq = target_update_freq self._iter = 0 self.actor_old: torch.nn.Module | EvalModeModuleWrapper self.critic_old: torch.nn.Module | EvalModeModuleWrapper if self._target: self.actor_old = self._add_lagged_network(self.policy.actor) self.critic_old = self._add_lagged_network(self.critic) else: self.actor_old = self.policy.actor self.critic_old = self.critic self._policy_improvement_mode = policy_improvement_mode self._ratio_upper_bound = ratio_upper_bound self._beta = beta self._min_q_weight = min_q_weight def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithReturnsProtocol: return self.discounted_return_computation.add_discounted_returns( batch, buffer, indices, ) def _update_with_batch( # type: ignore[override] self, batch: BatchWithReturnsProtocol, ) -> DiscreteCRRTrainingStats: if self._target and self._iter % self._freq == 0: self._update_lagged_network_weights() q_t = self.critic(batch.obs) act = to_torch(batch.act, dtype=torch.long, device=q_t.device) qa_t = q_t.gather(1, act.unsqueeze(1)) # Critic loss with torch.no_grad(): target_a_t, _ = self.actor_old(batch.obs_next) target_m = Categorical(logits=target_a_t) q_t_target = self.critic_old(batch.obs_next) rew = to_torch_as(batch.rew, q_t_target) expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) expected_target_q[batch.done > 0] = 0.0 target = rew.unsqueeze(1) + self.discounted_return_computation.gamma * expected_target_q critic_loss = 0.5 * F.mse_loss(qa_t, target) # Actor loss act_target, _ = self.policy.actor(batch.obs) dist = Categorical(logits=act_target) 
expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True) advantage = qa_t - expected_policy_q if self._policy_improvement_mode == "binary": actor_loss_coef = (advantage > 0).float() elif self._policy_improvement_mode == "exp": actor_loss_coef = (advantage / self._beta).exp().clamp(0, self._ratio_upper_bound) else: actor_loss_coef = 1.0 # effectively behavior cloning actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean() # CQL loss/regularizer min_q_loss = (q_t.logsumexp(1) - qa_t).mean() loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss self.optim.step(loss) self._iter += 1 return DiscreteCRRTrainingStats( loss=loss.item(), actor_loss=actor_loss.item(), critic_loss=critic_loss.item(), cql_loss=min_q_loss.item(), ) ================================================ FILE: tianshou/algorithm/imitation/gail.py ================================================ from dataclasses import dataclass import numpy as np import torch import torch.nn.functional as F from tianshou.algorithm.modelfree.a2c import A2CTrainingStats from tianshou.algorithm.modelfree.ppo import PPO from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ( ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch, ) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.utils.net.common import ModuleWithVectorOutput from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic from tianshou.utils.torch_utils import torch_device @dataclass(kw_only=True) class GailTrainingStats(A2CTrainingStats): disc_loss: SequenceSummaryStats acc_pi: SequenceSummaryStats acc_exp: SequenceSummaryStats class GAIL(PPO): """Implementation of Generative Adversarial Imitation Learning. 
arXiv:1606.03476.""" def __init__( self, *, policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, expert_buffer: ReplayBuffer, disc_net: torch.nn.Module, disc_optim: OptimizerFactory, disc_update_num: int = 4, eps_clip: float = 0.2, dual_clip: float | None = None, value_clip: bool = False, advantage_normalization: bool = True, recompute_advantage: bool = False, vf_coef: float = 0.5, ent_coef: float = 0.01, max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, return_scaling: bool = False, ) -> None: """ :param policy: the policy (which must use an actor with known output dimension, i.e. any Tianshou `Actor` implementation or other subclass of `ModuleWithVectorOutput`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the actor and critic networks. :param expert_buffer: the replay buffer containing expert experience. :param disc_net: the discriminator neural network that distinguishes between expert and policy behaviors. Takes concatenated state-action pairs [obs, act] as input and outputs an unbounded logit value. The raw output is transformed in the algorithm using sigmoid functions: o(output) for expert probability and -log(1-o(-output)) for policy rewards. Positive output values indicate the discriminator believes the behavior is from an expert. Negative output values indicate the discriminator believes the behavior is from the policy. The network architecture should end with a linear layer of output size 1 without any activation function, as sigmoid operations are applied separately. :param disc_optim: the optimizer factory for the discriminator network. :param disc_update_num: the number of discriminator update steps performed for each policy update step. Controls the learning dynamics between the policy and the discriminator. 
Higher values strengthen the discriminator relative to the policy, potentially improving the quality of the reward signal but slowing down training. Lower values allow faster policy updates but may result in a weaker discriminator that fails to properly distinguish between expert and policy behaviors. Typical values range from 1 to 10, with the original GAIL paper using multiple discriminator updates per policy update. :param eps_clip: determines the range of allowed change in the policy during a policy update: The ratio of action probabilities indicated by the new and old policy is constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. Small values thus force the new policy to stay close to the old policy. Typical values range between 0.1 and 0.3, the value of 0.2 is recommended in the original PPO paper. The optimal value depends on the environment; more stochastic environments may need larger values. :param dual_clip: a clipping parameter (denoted as c in the literature) that prevents excessive pessimism in policy updates for negative-advantage actions. Excessive pessimism occurs when the policy update too strongly reduces the probability of selecting actions that led to negative advantages, potentially eliminating useful actions based on limited negative experiences. When enabled (c > 1), the objective for negative advantages becomes: max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) is the original single-clipping objective determined by `eps_clip`. This creates a floor on negative policy gradients, maintaining some probability of exploring actions despite initial negative outcomes. Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer to 1.0 provide less protection against pessimistic updates. Set to None to disable dual clipping. :param value_clip: flag indicating whether to enable clipping for value function updates. 
When enabled, restricts how much the value function estimate can change from its previous prediction, using the same clipping range as the policy updates (eps_clip). This stabilizes training by preventing large fluctuations in value estimates, particularly useful in environments with high reward variance. The clipped value loss uses a pessimistic approach, taking the maximum of the original and clipped value errors: max((returns - value)², (returns - v_clipped)²) Setting to True often improves training stability but may slow convergence. Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. :param advantage_normalization: whether to do per mini-batch advantage normalization. :param recompute_advantage: whether to recompute advantage every update repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. :param vf_coef: coefficient that weights the value loss relative to the actor loss in the overall loss function. Higher values prioritize accurate value function estimation over policy improvement. Controls the trade-off between policy optimization and value function fitting. Typically set between 0.5 and 1.0 for most actor-critic implementations. :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. Controls the exploration-exploitation trade-off by encouraging policy entropy. Higher values promote more exploration by encouraging a more uniform action distribution. Lower values focus more on exploitation of the current policy's knowledge. Typically set between 0.01 and 0.05 for most actor-critic implementations. :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. This prevents exploding gradients and stabilizes training by limiting the size of parameter updates. Set to None to disable gradient clipping. 
:param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. :param max_batchsize: the maximum number of samples to process at once when computing generalized advantage estimation (GAE) and value function predictions. Controls memory usage by breaking large batches into smaller chunks processed sequentially. Higher values may increase speed but require more GPU/CPU memory; lower values reduce memory requirements but may increase computation time. Should be adjusted based on available hardware resources and total batch size of your training data. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param return_scaling: flag indicating whether to enable scaling of estimated returns by dividing them by their running standard deviation without centering the mean. This reduces the magnitude variation of advantages across different episodes while preserving their signs and relative ordering. The use of running statistics (rather than batch-specific scaling) means that early training experiences may be scaled differently than later ones as the statistics evolve. When enabled, this improves training stability in environments with highly variable reward scales and makes the algorithm less sensitive to learning rate settings. However, it may reduce the algorithm's ability to distinguish between episodes with different absolute return magnitudes. Best used in environments where the relative ordering of actions is more important than the absolute scale of returns. """ super().__init__( policy=policy, critic=critic, optim=optim, eps_clip=eps_clip, dual_clip=dual_clip, value_clip=value_clip, advantage_normalization=advantage_normalization, recompute_advantage=recompute_advantage, vf_coef=vf_coef, ent_coef=ent_coef, max_grad_norm=max_grad_norm, gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, return_scaling=return_scaling, ) self.disc_net = disc_net self.disc_optim = self._create_optimizer(self.disc_net, disc_optim) self.disc_update_num = disc_update_num self.expert_buffer = expert_buffer actor = self.policy.actor if not isinstance(actor, ModuleWithVectorOutput): raise TypeError("GAIL requires the policy to use an actor with known output dimension.") self.action_dim = actor.get_output_dim() def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> LogpOldProtocol: """Pre-process the data from the provided replay buffer. Used in :meth:`update`. Check out :ref:`process_fn` for more information. 
""" # update reward with torch.no_grad(): batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten()) return super()._preprocess_batch(batch, buffer, indices) def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: device = torch_device(self.disc_net) obs = to_torch(batch.obs, device=device) act = to_torch(batch.act, device=device) return self.disc_net(torch.cat([obs, act], dim=1)) def _update_with_batch( # type: ignore[override] self, batch: LogpOldProtocol, batch_size: int | None, repeat: int, ) -> GailTrainingStats: # update discriminator losses = [] acc_pis = [] acc_exps = [] bsz = len(batch) // self.disc_update_num for b in batch.split(bsz, merge_last=True): logits_pi = self.disc(b) exp_b = self.expert_buffer.sample(bsz)[0] logits_exp = self.disc(exp_b) loss_pi = -F.logsigmoid(-logits_pi).mean() loss_exp = -F.logsigmoid(logits_exp).mean() loss_disc = loss_pi + loss_exp self.disc_optim.step(loss_disc) losses.append(loss_disc.item()) acc_pis.append((logits_pi < 0).float().mean().item()) acc_exps.append((logits_exp > 0).float().mean().item()) # update policy ppo_loss_stat = super()._update_with_batch(batch, batch_size, repeat) disc_losses_summary = SequenceSummaryStats.from_sequence(losses) acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) acc_exps_summary = SequenceSummaryStats.from_sequence(acc_exps) return GailTrainingStats( **ppo_loss_stat.__dict__, disc_loss=disc_losses_summary, acc_pi=acc_pi_summary, acc_exp=acc_exps_summary, ) ================================================ FILE: tianshou/algorithm/imitation/imitation_base.py ================================================ from dataclasses import dataclass from typing import Any, Literal, cast import gymnasium as gym import numpy as np import torch import torch.nn.functional as F from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( OfflineAlgorithm, OffPolicyAlgorithm, Policy, TrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory 
from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( ModelOutputBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) # Dimension Naming Convention # B - Batch Size # A - Action # D - Dist input (usually 2, loc and scale) # H - Dimension of hidden, can be None @dataclass(kw_only=True) class ImitationTrainingStats(TrainingStats): loss: float = 0.0 class ImitationPolicy(Policy): def __init__( self, *, actor: torch.nn.Module, action_space: gym.Space, observation_space: gym.Space | None = None, action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", ): """ :param actor: a model following the rules (s -> a) :param action_space: the environment's action_space. :param observation_space: the environment's observation space :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. :param action_bound_method: the method used for bounding actions in continuous action spaces to the range [-1, 1] before scaling them to the environment's action space (provided that `action_scaling` is enabled). This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None for discrete spaces. 
When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly constrains outputs to [-1, 1] while preserving gradients. The choice of bounding method affects both training dynamics and exploration behavior. Clipping provides hard boundaries but may create plateau regions in the gradient landscape, while tanh provides smoother transitions but can compress sensitivity near the boundaries. Should be set to None if the actor model inherently produces bounded outputs. Typically used together with `action_scaling=True`. """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, action_bound_method=action_bound_method, ) self.actor = actor def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ModelOutputBatchProtocol: # TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced if self.action_type == "discrete": # If it's discrete, the "actor" is usually a critic that maps obs to action_values # which then could be turned into logits or a Categorigal action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) act_B = action_values_BA.argmax(dim=1) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) elif self.action_type == "continuous": # If it's continuous, the actor would usually deliver something like loc, scale determining a # Gaussian dist dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH) else: raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!") return cast(ModelOutputBatchProtocol, result) class ImitationLearningAlgorithmMixin: def _imitation_update( self, batch: RolloutBatchProtocol, policy: ImitationPolicy, optim: Algorithm.Optimizer, ) -> 
ImitationTrainingStats: if policy.action_type == "continuous": # regression act = policy(batch).act act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) loss = F.mse_loss(act, act_target) elif policy.action_type == "discrete": # classification act = F.log_softmax(policy(batch).logits, dim=-1) act_target = to_torch(batch.act, dtype=torch.long, device=act.device) loss = F.nll_loss(act, act_target) else: raise ValueError(policy.action_type) optim.step(loss) return ImitationTrainingStats(loss=loss.item()) class OffPolicyImitationLearning( OffPolicyAlgorithm[ImitationPolicy], ImitationLearningAlgorithmMixin, ): """Implementation of off-policy vanilla imitation learning.""" def __init__( self, *, policy: ImitationPolicy, optim: OptimizerFactory, ) -> None: """ :param policy: the policy :param optim: the optimizer factory """ super().__init__( policy=policy, ) self.optim = self._create_optimizer(self.policy, optim) def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> ImitationTrainingStats: return self._imitation_update(batch, self.policy, self.optim) class OfflineImitationLearning( OfflineAlgorithm[ImitationPolicy], ImitationLearningAlgorithmMixin, ): """Implementation of offline vanilla imitation learning.""" def __init__( self, *, policy: ImitationPolicy, optim: OptimizerFactory, ) -> None: """ :param policy: the policy :param optim: the optimizer factory """ super().__init__( policy=policy, ) self.optim = self._create_optimizer(self.policy, optim) def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> ImitationTrainingStats: return self._imitation_update(batch, self.policy, self.optim) ================================================ FILE: tianshou/algorithm/imitation/td3_bc.py ================================================ import torch import torch.nn.functional as F from tianshou.algorithm import TD3 from tianshou.algorithm.algorithm_base import OfflineAlgorithm from tianshou.algorithm.modelfree.ddpg import 
ContinuousDeterministicPolicy from tianshou.algorithm.modelfree.td3 import TD3TrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import to_torch_as from tianshou.data.types import RolloutBatchProtocol # NOTE: This uses diamond inheritance to convert from off-policy to offline class TD3BC(OfflineAlgorithm[ContinuousDeterministicPolicy], TD3): # type: ignore """Implementation of TD3+BC. arXiv:2106.06860.""" def __init__( self, *, policy: ContinuousDeterministicPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, policy_noise: float = 0.2, update_actor_freq: int = 2, noise_clip: float = 0.5, alpha: float = 2.5, n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network. (s, a -> Q(s, a)). If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. 
This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param exploration_noise: add noise to action for exploration. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). :param policy_noise: scaling factor for the Gaussian noise added to target policy actions. This parameter implements target policy smoothing, a regularization technique described in the TD3 paper. The noise is sampled from a normal distribution and multiplied by this value before being added to actions. Higher values increase exploration in the target policy, helping to address function approximation error. The added noise is optionally clipped to a range determined by the noise_clip parameter. Typically set between 0.1 and 0.5 relative to the action scale of the environment. :param update_actor_freq: the frequency of actor network updates relative to critic network updates (the actor network is only updated once for every `update_actor_freq` critic updates). This implements the "delayed" policy updates from the TD3 algorithm, where the actor is updated less frequently than the critics. Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more accurate before updating the policy. The default value of 2 follows the original TD3 paper's recommendation of updating the policy at half the rate of the Q-functions. :param noise_clip: defines the maximum absolute value of the noise added to target policy actions, i.e. 
noise values are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise via `policy_noise`). This parameter implements bounded target policy smoothing as described in the TD3 paper. It prevents extreme noise values from causing unrealistic target values during training. Setting it 0.0 (or a negative value) disables clipping entirely. It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2). :param alpha: the value of alpha, which controls the weight for TD3 learning relative to behavior cloning. """ TD3.__init__( self, policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, critic2=critic2, critic2_optim=critic2_optim, tau=tau, gamma=gamma, policy_noise=policy_noise, noise_clip=noise_clip, update_actor_freq=update_actor_freq, n_step_return_horizon=n_step_return_horizon, ) self.alpha = alpha def _update_with_batch(self, batch: RolloutBatchProtocol) -> TD3TrainingStats: # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( batch, self.critic, self.critic_optim ) td2, critic2_loss = self._minimize_critic_squared_loss( batch, self.critic2, self.critic2_optim ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor if self._cnt % self.update_actor_freq == 0: act = self.policy(batch, eps=0.0).act q_value = self.critic(batch.obs, act) lmbda = self.alpha / q_value.abs().mean().detach() actor_loss = -lmbda * q_value.mean() + F.mse_loss(act, to_torch_as(batch.act, act)) self._last = actor_loss.item() self.policy_optim.step(actor_loss) self._update_lagged_network_weights() self._cnt += 1 return TD3TrainingStats( actor_loss=self._last, critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), ) ================================================ FILE: tianshou/algorithm/modelbased/__init__.py ================================================ ================================================ FILE: tianshou/algorithm/modelbased/icm.py 
================================================ import numpy as np import torch import torch.nn.functional as F from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( OffPolicyAlgorithm, OffPolicyWrapperAlgorithm, OnPolicyAlgorithm, OnPolicyWrapperAlgorithm, TPolicy, TrainingStats, TrainingStatsWrapper, ) from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.net.discrete import IntrinsicCuriosityModule class ICMTrainingStats(TrainingStatsWrapper): def __init__( self, wrapped_stats: TrainingStats, *, icm_loss: float, icm_forward_loss: float, icm_inverse_loss: float, ) -> None: self.icm_loss = icm_loss self.icm_forward_loss = icm_forward_loss self.icm_inverse_loss = icm_inverse_loss super().__init__(wrapped_stats) class _ICMMixin: """Implementation of the Intrinsic Curiosity Module (ICM) algorithm. arXiv:1705.05363.""" def __init__( self, *, model: IntrinsicCuriosityModule, optim: Algorithm.Optimizer, lr_scale: float, reward_scale: float, forward_loss_weight: float, ) -> None: """ :param model: the ICM model. :param optim: the optimizer factory. :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. Higher values increase the step size during optimization of the intrinsic curiosity module. Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. :param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided by the environment). Scales the prediction error (curiosity signal) before adding it to the environment rewards. 
Higher values increase the agent's motivation to explore novel states. Lower values decrease the influence of curiosity relative to task-specific rewards. Setting to zero effectively disables intrinsic motivation while still learning the ICM model. :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to the inverse model loss. Controls the trade-off between state prediction and action prediction in the ICM algorithm. Higher values (> 0.5) prioritize learning to predict next states given current states and actions. Lower values (< 0.5) prioritize learning to predict actions given current and next states. The total loss combines both components: (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. """ self.model = model self.optim = optim self.lr_scale = lr_scale self.reward_scale = reward_scale self.forward_loss_weight = forward_loss_weight def _icm_preprocess_batch( self, batch: RolloutBatchProtocol, ) -> None: mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) batch.rew += to_numpy(mse_loss * self.reward_scale) @staticmethod def _icm_postprocess_batch(batch: BatchProtocol) -> None: # restore original reward batch.rew = batch.policy.orig_rew def _icm_update( self, batch: RolloutBatchProtocol, original_stats: TrainingStats, ) -> ICMTrainingStats: act_hat = batch.policy.act_hat act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) inverse_loss = F.cross_entropy(act_hat, act).mean() forward_loss = batch.policy.mse_loss.mean() loss = ( (1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss ) * self.lr_scale self.optim.step(loss) return ICMTrainingStats( original_stats, icm_loss=loss.item(), icm_forward_loss=forward_loss.item(), icm_inverse_loss=inverse_loss.item(), ) class ICMOffPolicyWrapper(OffPolicyWrapperAlgorithm[TPolicy], _ICMMixin): """Implementation of the 
Intrinsic Curiosity Module (ICM) algorithm for off-policy learning. arXiv:1705.05363.""" def __init__( self, *, wrapped_algorithm: OffPolicyAlgorithm[TPolicy], model: IntrinsicCuriosityModule, optim: OptimizerFactory, lr_scale: float, reward_scale: float, forward_loss_weight: float, ) -> None: """ :param wrapped_algorithm: the base algorithm to which we want to add the ICM. :param model: the ICM model. :param optim: the optimizer factory for the ICM model. :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. Higher values increase the step size during optimization of the intrinsic curiosity module. Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. :param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided by the environment). Scales the prediction error (curiosity signal) before adding it to the environment rewards. Higher values increase the agent's motivation to explore novel states. Lower values decrease the influence of curiosity relative to task-specific rewards. Setting to zero effectively disables intrinsic motivation while still learning the ICM model. :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to the inverse model loss. Controls the trade-off between state prediction and action prediction in the ICM algorithm. Higher values (> 0.5) prioritize learning to predict next states given current states and actions. Lower values (< 0.5) prioritize learning to predict actions given current and next states. The total loss combines both components: (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. 
""" OffPolicyWrapperAlgorithm.__init__( self, wrapped_algorithm=wrapped_algorithm, ) _ICMMixin.__init__( self, model=model, optim=self._create_optimizer(model, optim), lr_scale=lr_scale, reward_scale=reward_scale, forward_loss_weight=forward_loss_weight, ) def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: self._icm_preprocess_batch(batch) return super()._preprocess_batch(batch, buffer, indices) def _postprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: super()._postprocess_batch(batch, buffer, indices) self._icm_postprocess_batch(batch) def _wrapper_update_with_batch( self, batch: RolloutBatchProtocol, original_stats: TrainingStats, ) -> ICMTrainingStats: return self._icm_update(batch, original_stats) class ICMOnPolicyWrapper(OnPolicyWrapperAlgorithm[TPolicy], _ICMMixin): """Implementation of the Intrinsic Curiosity Module (ICM) algorithm for on-policy learning. arXiv:1705.05363.""" def __init__( self, *, wrapped_algorithm: OnPolicyAlgorithm[TPolicy], model: IntrinsicCuriosityModule, optim: OptimizerFactory, lr_scale: float, reward_scale: float, forward_loss_weight: float, ) -> None: """ :param wrapped_algorithm: the base algorithm to which we want to add the ICM. :param model: the ICM model. :param optim: the optimizer factory for the ICM model. :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. Higher values increase the step size during optimization of the intrinsic curiosity module. Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. 
:param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided by the environment). Scales the prediction error (curiosity signal) before adding it to the environment rewards. Higher values increase the agent's motivation to explore novel states. Lower values decrease the influence of curiosity relative to task-specific rewards. Setting to zero effectively disables intrinsic motivation while still learning the ICM model. :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to the inverse model loss. Controls the trade-off between state prediction and action prediction in the ICM algorithm. Higher values (> 0.5) prioritize learning to predict next states given current states and actions. Lower values (< 0.5) prioritize learning to predict actions given current and next states. The total loss combines both components: (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. 
""" OnPolicyWrapperAlgorithm.__init__( self, wrapped_algorithm=wrapped_algorithm, ) _ICMMixin.__init__( self, model=model, optim=self._create_optimizer(model, optim), lr_scale=lr_scale, reward_scale=reward_scale, forward_loss_weight=forward_loss_weight, ) def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: self._icm_preprocess_batch(batch) return super()._preprocess_batch(batch, buffer, indices) def _postprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: super()._postprocess_batch(batch, buffer, indices) self._icm_postprocess_batch(batch) def _wrapper_update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, original_stats: TrainingStats, ) -> ICMTrainingStats: return self._icm_update(batch, original_stats) ================================================ FILE: tianshou/algorithm/modelbased/psrl.py ================================================ from dataclasses import dataclass from typing import Any, cast import gymnasium as gym import numpy as np import torch from tianshou.algorithm.algorithm_base import ( OnPolicyAlgorithm, Policy, TrainingStats, ) from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol @dataclass(kw_only=True) class PSRLTrainingStats(TrainingStats): psrl_rew_mean: float = 0.0 psrl_rew_std: float = 0.0 class PSRLModel: """Implementation of Posterior Sampling Reinforcement Learning Model.""" def __init__( self, trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, gamma: float, epsilon: float, ) -> None: """ :param trans_count_prior: dirichlet prior (alphas), with shape (n_state, n_action, n_state). :param rew_mean_prior: means of the normal priors of rewards, with shape (n_state, n_action). 
:param rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param epsilon: for precision control in value iteration. """ self.trans_count = trans_count_prior self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior self.rew_square_sum = np.zeros_like(rew_mean_prior) self.rew_std_prior = rew_std_prior self.gamma = gamma self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight self.eps = epsilon self.policy: np.ndarray self.value = np.zeros(self.n_state) self.updated = False self.__eps = np.finfo(np.float32).eps.item() def observe( self, trans_count: np.ndarray, rew_sum: np.ndarray, rew_square_sum: np.ndarray, rew_count: np.ndarray, ) -> None: """Add data into memory pool. For rewards, we have a normal prior at first. After we observed a reward for a given state-action pair, we use the mean value of our observations instead of the prior mean as the posterior mean. The standard deviations are in inverse proportion to the number of the corresponding observations. :param trans_count: the number of observations, with shape (n_state, n_action, n_state). :param rew_sum: total rewards, with shape (n_state, n_action). :param rew_square_sum: total rewards' squares, with shape (n_state, n_action). :param rew_count: the number of rewards, with shape (n_state, n_action). 
""" self.updated = False self.trans_count += trans_count sum_count = self.rew_count + rew_count self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_square_sum += rew_square_sum raw_std2 = self.rew_square_sum / sum_count - self.rew_mean**2 self.rew_std = np.sqrt( 1 / (sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior**2), ) self.rew_count = sum_count def sample_trans_prob(self) -> np.ndarray: return torch.distributions.Dirichlet(torch.from_numpy(self.trans_count)).sample().numpy() def sample_reward(self) -> np.ndarray: return np.random.normal(self.rew_mean, self.rew_std) def solve_policy(self) -> None: self.updated = True self.policy, self.value = self.value_iteration( self.sample_trans_prob(), self.sample_reward(), self.gamma, self.eps, self.value, ) @staticmethod def value_iteration( trans_prob: np.ndarray, rew: np.ndarray, gamma: float, eps: float, value: np.ndarray, ) -> tuple[np.ndarray, np.ndarray]: """Value iteration solver for MDPs. :param trans_prob: transition probabilities, with shape (n_state, n_action, n_state). :param rew: rewards, with shape (n_state, n_action). :param eps: for precision control. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param value: the initialize value of value array, with shape (n_state, ). :return: the optimal policy with shape (n_state, ). 
""" Q = rew + gamma * trans_prob.dot(value) new_value = Q.max(axis=1) while not np.allclose(new_value, value, eps): value = new_value Q = rew + gamma * trans_prob.dot(value) new_value = Q.max(axis=1) # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly Q += eps * np.random.randn(*Q.shape) return Q.argmax(axis=1), new_value def __call__( self, obs: np.ndarray, state: Any = None, info: Any = None, ) -> np.ndarray: if not self.updated: self.solve_policy() return self.policy[obs] class PSRLPolicy(Policy): def __init__( self, *, trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, action_space: gym.spaces.Discrete, discount_factor: float = 0.99, epsilon: float = 0.01, observation_space: gym.Space | None = None, ) -> None: """ :param trans_count_prior: dirichlet prior (alphas), with shape (n_state, n_action, n_state). :param rew_mean_prior: means of the normal priors of rewards, with shape (n_state, n_action). :param rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). :param action_space: the environment's action_space. :param epsilon: for precision control in value iteration. :param observation_space: the environment's observation space """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=False, action_bound_method=None, ) self.model = PSRLModel( trans_count_prior, rew_mean_prior, rew_std_prior, discount_factor, epsilon, ) def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ActBatchProtocol: """Compute action over the given batch data with PSRL model. :return: A :class:`~tianshou.data.Batch` with "act" key containing the action. """ assert isinstance(batch.obs, np.ndarray), "only support np.ndarray observation" # TODO: shouldn't the model output a state as well if state is passed (i.e. RNNs are involved)? 
act = self.model(batch.obs, state=state, info=batch.info) return cast(ActBatchProtocol, Batch(act=act)) class PSRL(OnPolicyAlgorithm[PSRLPolicy]): """Implementation of Posterior Sampling Reinforcement Learning (PSRL). Reference: Strens M., A Bayesian Framework for Reinforcement Learning, ICML, 2000. """ def __init__( self, *, policy: PSRLPolicy, add_done_loop: bool = False, ) -> None: """ :param policy: the policy :param add_done_loop: a flag indicating whether to add a self-loop transition for terminal states in the MDP. If True, whenever an episode terminates, an artificial transition from the terminal state back to itself is added to the transition counts for all actions. This modification can help stabilize learning for terminal states that have limited samples. Setting to True can be beneficial in environments where episodes frequently terminate, ensuring that terminal states receive sufficient updates to their value estimates. Default is False, which preserves the standard MDP formulation without artificial self-loops. """ super().__init__( policy=policy, ) self._add_done_loop = add_done_loop def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> PSRLTrainingStats: # NOTE: In contrast to other on-policy algorithms, this algorithm ignores # the batch_size and repeat arguments. # PSRL, being a Bayesian approach, updates its posterior distribution of # the MDP parameters based on the collected transition data as a whole, # rather than performing gradient-based updates that benefit from mini-batching. 
n_s, n_a = self.policy.model.n_state, self.policy.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) rew_square_sum = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a)) for minibatch in batch.split(size=1): obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next obs_next = cast(np.ndarray, obs_next) assert not isinstance(obs, Batch), "Observations cannot be Batches here" obs = cast(np.ndarray, obs) trans_count[obs, act, obs_next] += 1 rew_sum[obs, act] += minibatch.rew rew_square_sum[obs, act] += minibatch.rew**2 rew_count[obs, act] += 1 if self._add_done_loop and minibatch.done: # special operation for terminal states: add a self-loop trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 self.policy.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) return PSRLTrainingStats( psrl_rew_mean=float(self.policy.model.rew_mean.mean()), psrl_rew_std=float(self.policy.model.rew_std.mean()), ) ================================================ FILE: tianshou/algorithm/modelfree/__init__.py ================================================ ================================================ FILE: tianshou/algorithm/modelfree/a2c.py ================================================ from abc import ABC from dataclasses import dataclass from typing import cast import numpy as np import torch import torch.nn.functional as F from tianshou.algorithm.algorithm_base import ( OnPolicyAlgorithm, TrainingStats, ) from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic 
@dataclass(kw_only=True) class A2CTrainingStats(TrainingStats): loss: SequenceSummaryStats actor_loss: SequenceSummaryStats vf_loss: SequenceSummaryStats ent_loss: SequenceSummaryStats gradient_steps: int class ActorCriticOnPolicyAlgorithm(OnPolicyAlgorithm[ProbabilisticActorPolicy], ABC): """Abstract base class for actor-critic algorithms that use generalized advantage estimation (GAE).""" def __init__( self, *, policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_include_actor: bool, max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, return_scaling: bool = False, ) -> None: """ :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory. :param optim_include_actor: whether the optimizer shall include the actor network's parameters. Pass False for algorithms that shall update only the critic via the optimizer. :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. This prevents exploding gradients and stabilizes training by limiting the magnitude of parameter updates. Set to None to disable gradient clipping. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). 
Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. :param max_batchsize: the maximum number of samples to process at once when computing generalized advantage estimation (GAE) and value function predictions. Controls memory usage by breaking large batches into smaller chunks processed sequentially. Higher values may increase speed but require more GPU/CPU memory; lower values reduce memory requirements but may increase computation time. Should be adjusted based on available hardware resources and total batch size of your training data. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param return_scaling: flag indicating whether to enable scaling of estimated returns by dividing them by their running standard deviation without centering the mean. This reduces the magnitude variation of advantages across different episodes while preserving their signs and relative ordering. The use of running statistics (rather than batch-specific scaling) means that early training experiences may be scaled differently than later ones as the statistics evolve. When enabled, this improves training stability in environments with highly variable reward scales and makes the algorithm less sensitive to learning rate settings. However, it may reduce the algorithm's ability to distinguish between episodes with different absolute return magnitudes. 
Best used in environments where the relative ordering of actions is more important than the absolute scale of returns. """ super().__init__( policy=policy, ) self.critic = critic assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}" self.gae_lambda = gae_lambda self.max_batchsize = max_batchsize if optim_include_actor: self.optim = self._create_optimizer( ActorCritic(self.policy.actor, self.critic), optim, max_grad_norm=max_grad_norm, ) else: self.optim = self._create_optimizer(self.critic, optim, max_grad_norm=max_grad_norm) self.gamma = gamma self.return_scaling = return_scaling self.ret_rms = RunningMeanStd() self._eps = 1e-8 def _add_returns_and_advantages( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithAdvantagesProtocol: """Adds the returns and advantages to the given batch.""" v_s, v_s_ = [], [] with torch.no_grad(): for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): v_s.append(self.critic(minibatch.obs)) v_s_.append(self.critic(minibatch.obs_next)) batch.v_s = torch.cat(v_s, dim=0).flatten() # old value v_s = batch.v_s.cpu().numpy() v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy() # when normalizing values, we do not minus self.ret_rms.mean to be numerically # consistent with OPENAI baselines' value normalization pipeline. Empirical # study also shows that "minus mean" will harm performances a tiny little bit # due to unknown reasons (on Mujoco envs, not confident, though). 
if self.return_scaling: # unnormalize v_s & v_s_ v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) unnormalized_returns, advantages = self.compute_episodic_return( batch, buffer, indices, v_s_, v_s, gamma=self.gamma, gae_lambda=self.gae_lambda, ) if self.return_scaling: batch.returns = unnormalized_returns / np.sqrt(self.ret_rms.var + self._eps) self.ret_rms.update(unnormalized_returns) else: batch.returns = unnormalized_returns batch.returns = to_torch_as(batch.returns, batch.v_s) batch.adv = to_torch_as(advantages, batch.v_s) return cast(BatchWithAdvantagesProtocol, batch) class A2C(ActorCriticOnPolicyAlgorithm): """Implementation of (synchronous) Advantage Actor-Critic (A2C). arXiv:1602.01783.""" def __init__( self, *, policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, vf_coef: float = 0.5, ent_coef: float = 0.01, max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, return_scaling: bool = False, ) -> None: """ :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory. :param vf_coef: coefficient that weights the value loss relative to the actor loss in the overall loss function. Higher values prioritize accurate value function estimation over policy improvement. Controls the trade-off between policy optimization and value function fitting. Typically set between 0.5 and 1.0 for most actor-critic implementations. :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. Controls the exploration-exploitation trade-off by encouraging policy entropy. Higher values promote more exploration by encouraging a more uniform action distribution. Lower values focus more on exploitation of the current policy's knowledge. 
Typically set between 0.01 and 0.05 for most actor-critic implementations. :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. This prevents exploding gradients and stabilizes training by limiting the magnitude of parameter updates. Set to None to disable gradient clipping. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. :param max_batchsize: the maximum size of the batch when computing GAE. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param return_scaling: flag indicating whether to enable scaling of estimated returns by dividing them by their running standard deviation without centering the mean. This reduces the magnitude variation of advantages across different episodes while preserving their signs and relative ordering. The use of running statistics (rather than batch-specific scaling) means that early training experiences may be scaled differently than later ones as the statistics evolve. When enabled, this improves training stability in environments with highly variable reward scales and makes the algorithm less sensitive to learning rate settings. However, it may reduce the algorithm's ability to distinguish between episodes with different absolute return magnitudes. Best used in environments where the relative ordering of actions is more important than the absolute scale of returns. """ super().__init__( policy=policy, critic=critic, optim=optim, optim_include_actor=True, max_grad_norm=max_grad_norm, gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, return_scaling=return_scaling, ) self.vf_coef = vf_coef self.ent_coef = ent_coef self.max_grad_norm = max_grad_norm def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithAdvantagesProtocol: batch = self._add_returns_and_advantages(batch, buffer, indices) batch.act = to_torch_as(batch.act, batch.v_s) return batch def _update_with_batch( # type: ignore[override] self, batch: BatchWithAdvantagesProtocol, batch_size: int | None, repeat: int, ) -> A2CTrainingStats: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 gradient_steps = 0 for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): gradient_steps += 1 # calculate loss for actor dist = self.policy(minibatch).dist log_prob = dist.log_prob(minibatch.act) log_prob = 
log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1) actor_loss = -(log_prob * minibatch.adv).mean() # calculate loss for critic value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) # calculate regularization and overall loss ent_loss = dist.entropy().mean() loss = actor_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss self.optim.step(loss) actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) ent_losses.append(ent_loss.item()) losses.append(loss.item()) loss_summary_stat = SequenceSummaryStats.from_sequence(losses) actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) ent_loss_summary_stat = SequenceSummaryStats.from_sequence(ent_losses) return A2CTrainingStats( loss=loss_summary_stat, actor_loss=actor_loss_summary_stat, vf_loss=vf_loss_summary_stat, ent_loss=ent_loss_summary_stat, gradient_steps=gradient_steps, ) ================================================ FILE: tianshou/algorithm/modelfree/bdqn.py ================================================ from typing import cast import gymnasium as gym import numpy as np import torch from sensai.util.helper import mark_used from tianshou.algorithm.algorithm_base import TArrOrActBatch from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( ActBatchProtocol, BatchWithReturnsProtocol, ModelOutputBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.utils.net.common import BranchingNet mark_used(ActBatchProtocol) class BDQNPolicy(DiscreteQLearningPolicy[BranchingNet]): def __init__( self, *, model: BranchingNet, 
action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, eps_training: float = 0.0, eps_inference: float = 0.0, ): """ :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. :param action_space: the environment's action space :param observation_space: the environment's observation space. :param eps_training: the epsilon value for epsilon-greedy exploration during training. When collecting data for training, this is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, i.e. non-training cases (such as evaluation during test steps). The epsilon value is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). 
""" super().__init__( model=model, action_space=action_space, observation_space=observation_space, eps_training=eps_training, eps_inference=eps_inference, ) def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, model: torch.nn.Module | None = None, ) -> ModelOutputBatchProtocol: if model is None: model = self.model assert model is not None obs = batch.obs # TODO: this is very contrived, see also iqn.py obs_next_BO = obs.obs if hasattr(obs, "obs") else obs action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info) act_B = to_numpy(action_values_BA.argmax(dim=-1)) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) def add_exploration_noise( self, act: TArrOrActBatch, batch: ObsBatchProtocol, ) -> TArrOrActBatch: eps = self.eps_training if self.is_within_training_step else self.eps_inference if np.isclose(eps, 0.0): return act if isinstance(act, np.ndarray): bsz = len(act) rand_mask = np.random.rand(bsz) < eps rand_act = np.random.randint( low=0, high=self.model.action_per_branch, size=(bsz, act.shape[-1]), ) if hasattr(batch.obs, "mask"): rand_act += batch.obs.mask act[rand_mask] = rand_act[rand_mask] return act # type: ignore[return-value] else: raise NotImplementedError( f"Currently only numpy arrays are supported, got {type(act)=}." ) class BDQN(QLearningOffPolicyAlgorithm[BDQNPolicy]): """Implementation of the Branching Dueling Q-Network (BDQN) algorithm arXiv:1711.08946.""" def __init__( self, *, policy: BDQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, target_update_freq: int = 0, is_double: bool = True, ) -> None: """ :param policy: policy :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. :param is_double: flag indicating whether to use Double Q-learning for target value calculation. If True, the algorithm uses the online network to select actions and the target network to evaluate their Q-values. This decoupling helps reduce the overestimation bias that standard Q-learning is prone to. If False, the algorithm selects actions by directly taking the maximum Q-value from the target network. Note: This parameter is most effective when used with a target network (target_update_freq > 0). 
""" super().__init__( policy=policy, optim=optim, gamma=gamma, # BDQN implements its own returns computation (below), which supports only 1-step returns n_step_return_horizon=1, target_update_freq=target_update_freq, ) self.is_double = is_double def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} result = self.policy(obs_next_batch) if self.use_target_network: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) target_q = self.policy(obs_next_batch, model=self.model_old).logits else: target_q = result.logits if self.is_double: act = np.expand_dims(self.policy(obs_next_batch).act, -1) act = to_torch(act, dtype=torch.long, device=target_q.device) else: act = target_q.max(-1).indices.unsqueeze(-1) return torch.gather(target_q, -1, act).squeeze() def _compute_return( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indice: np.ndarray, gamma: float = 0.99, ) -> BatchWithReturnsProtocol: rew = batch.rew with torch.no_grad(): target_q_torch = self._target_q(buffer, indice) # (bsz, ?) 
target_q = to_numpy(target_q_torch) end_flag = buffer.done.copy() end_flag[buffer.unfinished_index()] = True end_flag = end_flag[indice] mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q _target_q = rew + gamma * mean_target_q * (1 - end_flag) target_q = np.repeat(_target_q[..., None], self.policy.model.num_branches, axis=-1) target_q = np.repeat(target_q[..., None], self.policy.model.action_per_branch, axis=-1) batch.returns = to_torch_as(target_q, target_q_torch) if hasattr(batch, "weight"): # prio buffer update batch.weight = to_torch_as(batch.weight, target_q_torch) return cast(BatchWithReturnsProtocol, batch) def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithReturnsProtocol: """Compute the 1-step return for BDQ targets.""" return self._compute_return(batch, buffer, indices) def _update_with_batch( # type: ignore[override] self, batch: BatchWithReturnsProtocol, ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) q = self.policy(batch).logits act_mask = torch.zeros_like(q) act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1) act_q = q * act_mask returns = batch.returns returns = returns * act_mask td_error = returns - act_q loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean() batch.weight = td_error.sum(-1).sum(-1) # prio-buffer self.optim.step(loss) return SimpleLossTrainingStats(loss=loss.item()) ================================================ FILE: tianshou/algorithm/modelfree/c51.py ================================================ import gymnasium as gym import numpy as np import torch from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) from tianshou.algorithm.modelfree.reinforce import LossSequenceTrainingStats from tianshou.algorithm.optim import OptimizerFactory from 
tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.net.common import Net class C51Policy(DiscreteQLearningPolicy): def __init__( self, model: torch.nn.Module | Net, action_space: gym.spaces.Space, observation_space: gym.Space | None = None, num_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0, eps_training: float = 0.0, eps_inference: float = 0.0, ): """ :param model: a model following the rules (s_B -> action_values_BA) :param num_atoms: the number of atoms in the support set of the value distribution. Default to 51. :param v_min: the value of the smallest atom in the support set. Default to -10.0. :param v_max: the value of the largest atom in the support set. Default to 10.0. :param eps_training: the epsilon value for epsilon-greedy exploration during training. When collecting data for training, this is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, i.e. non-training cases (such as evaluation during test steps). The epsilon value is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). 
""" assert isinstance(action_space, gym.spaces.Discrete) super().__init__( model=model, action_space=action_space, observation_space=observation_space, eps_training=eps_training, eps_inference=eps_inference, ) assert num_atoms > 1, f"num_atoms should be greater than 1 but got: {num_atoms}" assert v_min < v_max, f"v_max should be larger than v_min, but got {v_min=} and {v_max=}" self.num_atoms = num_atoms self.v_min = v_min self.v_max = v_max self.support = torch.nn.Parameter( torch.linspace(self.v_min, self.v_max, self.num_atoms), requires_grad=False, ) def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: return super().compute_q_value((logits * self.support).sum(2), mask) class C51(QLearningOffPolicyAlgorithm[C51Policy]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887.""" def __init__( self, *, policy: C51Policy, optim: OptimizerFactory, gamma: float = 0.99, n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ :param policy: a policy following the rules (s -> action_values_BA) :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. 
""" super().__init__( policy=policy, optim=optim, gamma=gamma, n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.delta_z = (policy.v_max - policy.v_min) / (policy.num_atoms - 1) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return self.policy.support.repeat(len(indices), 1) # shape: [bsz, num_atoms] def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) if self.use_target_network: act = self.policy(obs_next_batch).act next_dist = self.policy(obs_next_batch, model=self.model_old).logits else: next_batch = self.policy(obs_next_batch) act = next_batch.act next_dist = next_batch.logits next_dist = next_dist[np.arange(len(act)), act, :] target_support = batch.returns.clamp(self.policy.v_min, self.policy.v_max) # An amazing trick for calculating the projection gracefully. # ref: https://github.com/ShangtongZhang/DeepRL target_dist = ( 1 - (target_support.unsqueeze(1) - self.policy.support.view(1, -1, 1)).abs() / self.delta_z ).clamp(0, 1) * next_dist.unsqueeze(1) return target_dist.sum(-1) def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> LossSequenceTrainingStats: self._periodically_update_lagged_network_weights() with torch.no_grad(): target_dist = self._target_dist(batch) weight = batch.pop("weight", 1.0) curr_dist = self.policy(batch).logits act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :] cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1) loss = (cross_entropy * weight).mean() # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 batch.weight = cross_entropy.detach() # prio-buffer self.optim.step(loss) return LossSequenceTrainingStats(loss=loss.item()) ================================================ FILE: tianshou/algorithm/modelfree/ddpg.py ================================================ import warnings from abc import ABC from dataclasses 
import dataclass from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np import torch from sensai.util.helper import mark_used from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OffPolicyAlgorithm, Policy, TArrOrActBatch, TPolicy, TrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( ActBatchProtocol, ActStateBatchProtocol, BatchWithReturnsProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.utils.net.continuous import ( AbstractContinuousActorDeterministic, ContinuousCritic, ) mark_used(ActBatchProtocol) @dataclass(kw_only=True) class DDPGTrainingStats(TrainingStats): actor_loss: float critic_loss: float class ContinuousPolicyWithExplorationNoise(Policy, ABC): def __init__( self, *, exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.Space, observation_space: gym.Space | None = None, action_scaling: bool = True, action_bound_method: Literal["clip"] | None = "clip", ): """ :param exploration_noise: noise model for adding noise to continuous actions for exploration. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). :param action_space: the environment's action_space. :param observation_space: the environment's observation space :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. 
When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. :param action_bound_method: the method used for bounding actions in continuous action spaces to the range [-1, 1] before scaling them to the environment's action space (provided that `action_scaling` is enabled). This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None for discrete spaces. When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly constrains outputs to [-1, 1] while preserving gradients. The choice of bounding method affects both training dynamics and exploration behavior. Clipping provides hard boundaries but may create plateau regions in the gradient landscape, while tanh provides smoother transitions but can compress sensitivity near the boundaries. Should be set to None if the actor model inherently produces bounded outputs. Typically used together with `action_scaling=True`. 
""" super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, action_bound_method=action_bound_method, ) if exploration_noise == "default": exploration_noise = GaussianNoise(sigma=0.1) self.exploration_noise = exploration_noise def set_exploration_noise(self, noise: BaseNoise | None) -> None: """Set the exploration noise.""" self.exploration_noise = noise def add_exploration_noise( self, act: TArrOrActBatch, batch: ObsBatchProtocol, ) -> TArrOrActBatch: if self.exploration_noise is None: return act if isinstance(act, np.ndarray): return act + self.exploration_noise(act.shape) warnings.warn("Cannot add exploration noise to non-numpy_array action.") return act class ContinuousDeterministicPolicy(ContinuousPolicyWithExplorationNoise): """A policy for continuous action spaces that uses an actor which directly maps states to actions.""" def __init__( self, *, actor: AbstractContinuousActorDeterministic, exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.Space, observation_space: gym.Space | None = None, action_scaling: bool = True, action_bound_method: Literal["clip"] | None = "clip", ): """ :param actor: The actor network following the rules (s -> actions) :param exploration_noise: add noise to continuous actions for exploration; set to None for discrete action spaces. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). :param action_space: the environment's action space. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. 
Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param observation_space: the environment's observation space. :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. :param action_bound_method: method to bound action to range [-1, 1]. """ if action_scaling and not np.isclose(actor.max_action, 1.0): warnings.warn( "action_scaling and action_bound_method are only intended to deal" "with unbounded model action space, but find actor model bound" f"action space with max_action={actor.max_action}." "Consider using unbounded=True option of the actor model," "or set action_scaling to False and action_bound_method to None.", ) super().__init__( exploration_noise=exploration_noise, action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, action_bound_method=action_bound_method, ) self.actor = actor def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, model: torch.nn.Module | None = None, **kwargs: Any, ) -> ActStateBatchProtocol: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which has 2 keys: * ``act`` the action. * ``state`` the hidden state. 
""" if model is None: model = self.actor actions, hidden = model(batch.obs, state=state, info=batch.info) return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden)) TActBatchProtocol = TypeVar("TActBatchProtocol", bound=ActBatchProtocol) class ActorCriticOffPolicyAlgorithm( OffPolicyAlgorithm[TPolicy], LaggedNetworkPolyakUpdateAlgorithmMixin, Generic[TPolicy, TActBatchProtocol], ABC, ): """Base class for actor-critic off-policy algorithms that use a lagged critic as a target network. Its implementation of `process_fn` adds the n-step return to the batch, using the Q-values computed by the target network (lagged critic, `critic_old`) in order to compute the reward-to-go. Specializations can override the action computation (`_target_q_compute_action`) or the Q-value computation based on these actions (`_target_q_compute_value`) to customize the target Q-value computation. The default implementation assumes a continuous action space where a critic receives a state/observation and an action; for discrete action spaces, where the critic receives only a state/observation, the method `_target_q_compute_value` must be overridden. """ def __init__( self, *, policy: TPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, tau: float = 0.005, gamma: float = 0.99, n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy :param policy_optim: the optimizer factory for the policy's model. :param critic: the critic network. For continuous action spaces: (s, a -> Q(s, a)). For discrete action spaces: (s -> ). **NOTE**: The default implementation of `_target_q_compute_value` assumes a continuous action space; override this method if using discrete actions. :param critic_optim: the optimizer factory for the critic network. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. 
When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks """ assert 0.0 <= tau <= 1.0, f"tau should be in [0, 1] but got: {tau}" assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}" super().__init__( policy=policy, ) LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) self.policy_optim = self._create_optimizer(policy, policy_optim) self.critic = critic self.critic_old = self._add_lagged_network(self.critic) self.critic_optim = self._create_optimizer(self.critic, critic_optim) self.gamma = gamma self.n_step_return_horizon = n_step_return_horizon @staticmethod def _minimize_critic_squared_loss( batch: RolloutBatchProtocol, critic: torch.nn.Module, optimizer: Algorithm.Optimizer, ) -> tuple[torch.Tensor, torch.Tensor]: """Takes an optimizer step to minimize the squared loss of the critic given a batch of data. :param batch: the batch containing the observations, actions, returns, and (optionally) weights. :param critic: the critic network to minimize the loss for. 
:param optimizer: the optimizer for the critic's parameters. :return: a pair (`td`, `loss`), where `td` is the tensor of errors (current - target) and `loss` is the MSE loss. """ weight = getattr(batch, "weight", 1.0) current_q = critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td = current_q - target_q critic_loss = (td.pow(2) * weight).mean() optimizer.step(critic_loss) return td, critic_loss def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol | BatchWithReturnsProtocol: # add the n-step return to the batch, which the critic (Q-functions) seeks to match, # based the Q-values computed by the target network (lagged critic) return self.compute_nstep_return( batch=batch, buffer=buffer, indices=indices, target_q_fn=self._target_q, gamma=self.gamma, n_step=self.n_step_return_horizon, ) def _target_q_compute_action(self, obs_batch: Batch) -> TActBatchProtocol: """ Computes the action to be taken for the given batch (containing the observations) within the context of Q-value target computation. :param obs_batch: the batch containing the observations. :return: batch containing the actions to be taken. """ return self.policy(obs_batch) def _target_q_compute_value( self, obs_batch: Batch, act_batch: TActBatchProtocol ) -> torch.Tensor: """ Computes the target Q-value given a batch with observations and actions taken. :param obs_batch: the batch containing the observations. :param act_batch: the batch containing the actions taken. :return: a tensor containing the target Q-values. """ # compute the target Q-value using the lagged critic network (target network) return self.critic_old(obs_batch.obs, act_batch.act) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: """ Computes the target Q-value for the given buffer and indices. 
:param buffer: the replay buffer :param indices: the indices within the buffer to compute the target Q-value for """ obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} act_batch = self._target_q_compute_action(obs_next_batch) return self._target_q_compute_value(obs_next_batch, act_batch) class DDPG( ActorCriticOffPolicyAlgorithm[ContinuousDeterministicPolicy, ActBatchProtocol], ): """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.""" def __init__( self, *, policy: ContinuousDeterministicPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module | ContinuousCritic, critic_optim: OptimizerFactory, tau: float = 0.005, gamma: float = 0.99, n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy :param policy_optim: the optimizer factory for the policy's model. :param critic: the critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer factory for the critic network. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. """ super().__init__( policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, tau=tau, gamma=gamma, n_step_return_horizon=n_step_return_horizon, ) self.actor_old = self._add_lagged_network(self.policy.actor) def _target_q_compute_action(self, obs_batch: Batch) -> ActBatchProtocol: # compute the action using the lagged actor network return self.policy(obs_batch, model=self.actor_old) def _update_with_batch(self, batch: RolloutBatchProtocol) -> DDPGTrainingStats: # critic td, critic_loss = self._minimize_critic_squared_loss(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer # actor actor_loss = -self.critic(batch.obs, self.policy(batch).act).mean() self.policy_optim.step(actor_loss) self._update_lagged_network_weights() return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) ================================================ FILE: tianshou/algorithm/modelfree/discrete_sac.py ================================================ from dataclasses import dataclass from typing import Any, TypeVar, cast import gymnasium as gym import numpy as np 
import torch from torch.distributions import Categorical from tianshou.algorithm.algorithm_base import Policy from tianshou.algorithm.modelfree.sac import Alpha, SACTrainingStats from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( DistBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.utils.net.discrete import DiscreteCritic @dataclass class DiscreteSACTrainingStats(SACTrainingStats): pass TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) class DiscreteSACPolicy(Policy): def __init__( self, *, actor: torch.nn.Module, deterministic_eval: bool = True, action_space: gym.Space, observation_space: gym.Space | None = None, ): """ :param actor: the actor network following the rules (s -> dist_input_BD), where the distribution input is for a `Categorical` distribution. :param deterministic_eval: flag indicating whether the policy should use deterministic actions (using the mode of the action distribution) instead of stochastic ones (using random sampling) during evaluation. When enabled, the policy will always select the most probable action according to the learned distribution during evaluation phases, while still using stochastic sampling during training. This creates a clear distinction between exploration (training) and exploitation (evaluation) behaviors. Deterministic actions are generally preferred for final deployment and reproducible evaluation as they provide consistent behavior, reduce variance in performance metrics, and are more interpretable for human observers. Note that this parameter only affects behavior when the policy is not within a training step. When collecting rollouts for training, actions remain stochastic regardless of this setting to maintain proper exploration behaviour. 
:param action_space: the environment's action_space. :param observation_space: the environment's observation space """ assert isinstance(action_space, gym.spaces.Discrete) super().__init__( action_space=action_space, observation_space=observation_space, ) self.actor = actor self.deterministic_eval = deterministic_eval def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> Batch: logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Categorical(logits=logits_BA) act_B = ( dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample() ) return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) class DiscreteSAC(ActorDualCriticsOffPolicyAlgorithm[DiscreteSACPolicy, DistBatchProtocol]): """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.""" def __init__( self, *, policy: DiscreteSACPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module | DiscreteCritic, critic_optim: OptimizerFactory, critic2: torch.nn.Module | DiscreteCritic | None = None, critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. (s -> ). :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network. (s -> ). If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. 
When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. 
""" super().__init__( policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, critic2=critic2, critic2_optim=critic2_optim, tau=tau, gamma=gamma, n_step_return_horizon=n_step_return_horizon, ) self.alpha = Alpha.from_float_or_instance(alpha) def _target_q_compute_value( self, obs_batch: Batch, act_batch: DistBatchProtocol ) -> torch.Tensor: dist = cast(Categorical, act_batch.dist) target_q = dist.probs * torch.min( self.critic_old(obs_batch.obs), self.critic2_old(obs_batch.obs), ) return target_q.sum(dim=-1) + self.alpha.value * dist.entropy() def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDiscreteSACTrainingStats: # type: ignore weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) # critic 1 current_q1 = self.critic(batch.obs).gather(1, act).flatten() td1 = current_q1 - target_q critic1_loss = (td1.pow(2) * weight).mean() self.critic_optim.step(critic1_loss) # critic 2 current_q2 = self.critic2(batch.obs).gather(1, act).flatten() td2 = current_q2 - target_q critic2_loss = (td2.pow(2) * weight).mean() self.critic2_optim.step(critic2_loss) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor dist = self.policy(batch).dist entropy = dist.entropy() with torch.no_grad(): current_q1a = self.critic(batch.obs) current_q2a = self.critic2(batch.obs) q = torch.min(current_q1a, current_q2a) actor_loss = -(self.alpha.value * entropy + (dist.probs * q).sum(dim=-1)).mean() self.policy_optim.step(actor_loss) alpha_loss = self.alpha.update(entropy.detach()) self._update_lagged_network_weights() return DiscreteSACTrainingStats( # type: ignore[return-value] actor_loss=actor_loss.item(), critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), alpha=self.alpha.value, alpha_loss=alpha_loss, ) ================================================ FILE: tianshou/algorithm/modelfree/dqn.py ================================================ 
import logging from abc import ABC, abstractmethod from typing import Any, Generic, TypeVar, cast import gymnasium as gym import numpy as np import torch from gymnasium.spaces.discrete import Discrete from sensai.util.helper import mark_used from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( LaggedNetworkFullUpdateAlgorithmMixin, OffPolicyAlgorithm, Policy, TArrOrActBatch, ) from tianshou.algorithm.modelfree.reinforce import ( SimpleLossTrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.data.types import ( ActBatchProtocol, BatchWithReturnsProtocol, ModelOutputBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.common import Net mark_used(ActBatchProtocol) TModel = TypeVar("TModel", bound=torch.nn.Module | Net) log = logging.getLogger(__name__) class DiscreteQLearningPolicy(Policy, Generic[TModel]): def __init__( self, *, model: TModel, action_space: gym.spaces.Space, observation_space: gym.Space | None = None, eps_training: float = 0.0, eps_inference: float = 0.0, ) -> None: """ :param model: a model mapping (obs, state, info) to action_values_BA. :param action_space: the environment's action space :param observation_space: the environment's observation space. :param eps_training: the epsilon value for epsilon-greedy exploration during training. When collecting data for training, this is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, i.e. non-training cases (such as evaluation during test steps). 
The epsilon value is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=False, action_bound_method=None, ) self.action_space = cast(Discrete, self.action_space) self.model = model self.eps_training = eps_training self.eps_inference = eps_inference def set_eps_training(self, eps: float) -> None: """ Sets the epsilon value for epsilon-greedy exploration during training. :param eps: the epsilon value for epsilon-greedy exploration during training. When collecting data for training, this is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). """ self.eps_training = eps def set_eps_inference(self, eps: float) -> None: """ Sets the epsilon value for epsilon-greedy exploration during inference. :param eps: the epsilon value for epsilon-greedy exploration during inference, i.e. non-training cases (such as evaluation during test steps). The epsilon value is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). """ self.eps_inference = eps def forward( self, batch: ObsBatchProtocol, state: Any | None = None, model: torch.nn.Module | None = None, ) -> ModelOutputBatchProtocol: """Compute action over the given batch data. If you need to mask the action, please add a "mask" into batch.obs, for example, if we have an environment that has "0/1/2" three actions: :: batch == Batch( obs=Batch( obs="original obs, with batch_size=1 for demonstration", mask=np.array([[False, True, False]]), # action 1 is available # action 0 and 2 are unavailable ), ... 
) :param batch: :param state: optional hidden state (for RNNs) :param model: if not passed will use `self.model`. Typically used to pass the lagged target network instead of using the current model. :return: A :class:`~tianshou.data.Batch` which has 3 keys: * ``act`` the action. * ``logits`` the network's raw output. * ``state`` the hidden state. """ if model is None: model = self.model obs = batch.obs mask = getattr(obs, "mask", None) # TODO: this is convoluted! See also other places where this is done. obs_arr = obs.obs if hasattr(obs, "obs") else obs action_values_BA, hidden_BH = model(obs_arr, state=state, info=batch.info) q = self.compute_q_value(action_values_BA, mask) act_B = to_numpy(q.argmax(dim=1)) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: """Compute the q value based on the network's raw output and action mask.""" if mask is not None: # the masked q value should be smaller than logits.min() min_value = logits.min() - logits.max() - 1.0 logits = logits + to_torch_as(1 - mask, logits) * min_value return logits def add_exploration_noise( self, act: TArrOrActBatch, batch: ObsBatchProtocol, ) -> TArrOrActBatch: eps = self.eps_training if self.is_within_training_step else self.eps_inference if np.isclose(eps, 0.0): return act if isinstance(act, np.ndarray): batch_size = len(act) rand_mask = np.random.rand(batch_size) < eps self.action_space = cast(Discrete, self.action_space) # for mypy action_num = int(self.action_space.n) q = np.random.rand(batch_size, action_num) # [0, 1] if hasattr(batch.obs, "mask"): q += batch.obs.mask rand_act = q.argmax(axis=1) act[rand_mask] = rand_act[rand_mask] return act # type: ignore[return-value] raise NotImplementedError( f"Currently only numpy array is supported for action, but got {type(act)}" ) TDQNPolicy = TypeVar("TDQNPolicy", bound=DiscreteQLearningPolicy) class 
QLearningOffPolicyAlgorithm( OffPolicyAlgorithm[TDQNPolicy], LaggedNetworkFullUpdateAlgorithmMixin, ABC ): """ Base class for Q-learning off-policy algorithms that use a Q-function to compute the n-step return. It optionally uses a lagged model, which is used as a target network and which is fully updated periodically. """ def __init__( self, *, policy: TDQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ :param policy: the policy :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. 
A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. """ super().__init__( policy=policy, ) self.optim = self._create_policy_optimizer(optim) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" self.gamma = gamma assert n_step_return_horizon > 0, ( f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" ) self.n_step = n_step_return_horizon self.target_update_freq = target_update_freq # TODO: 1 would be a more reasonable initialization given how it is incremented self._iter = 0 self.model_old: EvalModeModuleWrapper | None = ( self._add_lagged_network(self.policy.model) if self.use_target_network else None ) def _create_policy_optimizer(self, optim: OptimizerFactory) -> Algorithm.Optimizer: return self._create_optimizer(self.policy, optim) @property def use_target_network(self) -> bool: return self.target_update_freq > 0 @abstractmethod def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: pass def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithReturnsProtocol: """Compute the n-step return for Q-learning targets. More details can be found at :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. 
""" return self.compute_nstep_return( batch=batch, buffer=buffer, indices=indices, target_q_fn=self._target_q, gamma=self.gamma, n_step=self.n_step, ) def _periodically_update_lagged_network_weights(self) -> None: """ Periodically updates the parameters of the lagged target network (if any), i.e. every n-th call (where n=`target_update_freq`), the target network's parameters are fully updated with the model's parameters. """ if self.use_target_network and self._iter % self.target_update_freq == 0: self._update_lagged_network_weights() self._iter += 1 class DQN( QLearningOffPolicyAlgorithm[TDQNPolicy], Generic[TDQNPolicy], ): """Implementation of Deep Q Network. arXiv:1312.5602. Implementation of Double Q-Learning. arXiv:1509.06461. Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is implemented in the network side, not here). """ def __init__( self, *, policy: TDQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, n_step_return_horizon: int = 1, target_update_freq: int = 0, is_double: bool = True, huber_loss_delta: float | None = None, ) -> None: """ :param policy: the policy :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. :param is_double: flag indicating whether to use the Double DQN algorithm for target value computation. If True, the algorithm uses the online network to select actions and the target network to evaluate their Q-values. This approach helps reduce the overestimation bias in Q-learning by decoupling action selection from action evaluation. If False, the algorithm follows the vanilla DQN method that directly takes the maximum Q-value from the target network. Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). :param huber_loss_delta: controls whether to use the Huber loss instead of the MSE loss for the TD error and the threshold for the Huber loss. If None, the MSE loss is used. If not None, uses the Huber loss as described in the Nature DQN paper (nature14236) with the given delta, which limits the influence of outliers. 
Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber loss causes the gradients to plateau at a constant value for large errors, providing more stable training. NOTE: The magnitude of delta should depend on the scale of the returns obtained in the environment. """ super().__init__( policy=policy, optim=optim, gamma=gamma, n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.is_double = is_double self.huber_loss_delta = huber_loss_delta def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} result = self.policy(obs_next_batch) if self.use_target_network: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) target_q = self.policy(obs_next_batch, model=self.model_old).logits else: target_q = result.logits if self.is_double: return target_q[np.arange(len(result.act)), result.act] # Nature DQN, over estimate return target_q.max(dim=1)[0] def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) q = self.policy(batch).logits q = q[np.arange(len(q)), batch.act] returns = to_torch_as(batch.returns.flatten(), q) td_error = returns - q if self.huber_loss_delta is not None: y = q.reshape(-1, 1) t = returns.reshape(-1, 1) loss = torch.nn.functional.huber_loss( y, t, delta=self.huber_loss_delta, reduction="mean" ) else: loss = (td_error.pow(2) * weight).mean() batch.weight = td_error # prio-buffer self.optim.step(loss) return SimpleLossTrainingStats(loss=loss.item()) ================================================ FILE: tianshou/algorithm/modelfree/fqf.py ================================================ from dataclasses import dataclass from typing import Any, cast import gymnasium as gym import numpy as np import torch import torch.nn.functional as F from overrides 
import override from tianshou.algorithm import QRDQN, Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @dataclass(kw_only=True) class FQFTrainingStats(SimpleLossTrainingStats): quantile_loss: float fraction_loss: float entropy_loss: float class FQFPolicy(QRDQNPolicy): def __init__( self, *, model: FullQuantileFunction, fraction_model: FractionProposalNetwork, action_space: gym.spaces.Space, observation_space: gym.Space | None = None, eps_training: float = 0.0, eps_inference: float = 0.0, ): """ :param model: a model following the rules (s_B -> action_values_BA) :param fraction_model: a FractionProposalNetwork for proposing fractions/quantiles given state. :param action_space: the environment's action space :param observation_space: the environment's observation space. :param eps_training: the epsilon value for epsilon-greedy exploration during training. When collecting data for training, this is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, i.e. non-training cases (such as evaluation during test steps). The epsilon value is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). 
""" assert isinstance(action_space, gym.spaces.Discrete) super().__init__( model=model, action_space=action_space, observation_space=observation_space, eps_training=eps_training, eps_inference=eps_inference, ) self.fraction_model = fraction_model def forward( # type: ignore self, batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, model: FullQuantileFunction | None = None, fractions: Batch | None = None, **kwargs: Any, ) -> FQFBatchProtocol: if model is None: model = self.model obs = batch.obs # TODO: this is convoluted! See also other places where this is done obs_next = obs.obs if hasattr(obs, "obs") else obs if fractions is None: (logits, fractions, quantiles_tau), hidden = model( obs_next, propose_model=self.fraction_model, state=state, info=batch.info, ) else: (logits, _, quantiles_tau), hidden = model( obs_next, propose_model=self.fraction_model, fractions=fractions, state=state, info=batch.info, ) weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits q = DiscreteQLearningPolicy.compute_q_value( self, weighted_logits.sum(2), getattr(obs, "mask", None) ) act = to_numpy(q.max(dim=1)[1]) result = Batch( logits=logits, act=act, state=hidden, fractions=fractions, quantiles_tau=quantiles_tau, ) return cast(FQFBatchProtocol, result) class FQF(QRDQN[FQFPolicy]): """Implementation of Fully Parameterized Quantile Function for Distributional Reinforcement Learning. arXiv:1911.02140.""" def __init__( self, *, policy: FQFPolicy, optim: OptimizerFactory, fraction_optim: OptimizerFactory, gamma: float = 0.99, num_fractions: int = 32, ent_coef: float = 0.0, n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ :param policy: the policy :param optim: the optimizer factory for the policy's main Q-function model :param fraction_optim: the optimizer factory for the policy's fraction model :param action_space: the environment's action space. 
:param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_fractions: the number of fractions to use. :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. Controls the exploration-exploitation trade-off by encouraging policy entropy. Higher values promote more exploration by encouraging a more uniform action distribution. Lower values focus more on exploitation of the current policy's knowledge. Typically set between 0.01 and 0.05 for most actor-critic implementations. :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. 
Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. """ super().__init__( policy=policy, optim=optim, gamma=gamma, num_quantiles=num_fractions, n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.ent_coef = ent_coef self.fraction_optim = self._create_optimizer(self.policy.fraction_model, fraction_optim) @override def _create_policy_optimizer(self, optim: OptimizerFactory) -> Algorithm.Optimizer: # Override to leave out the fraction model (use main model only), as we want # to use a separate optimizer for the fraction model return self._create_optimizer(self.policy.model, optim) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} if self.use_target_network: result = self.policy(obs_next_batch) act, fractions = result.act, result.fractions next_dist = self.policy( obs_next_batch, model=self.model_old, fractions=fractions ).logits else: next_batch = self.policy(obs_next_batch) act = next_batch.act next_dist = next_batch.logits return next_dist[np.arange(len(act)), act, :] def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> FQFTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) out = self.policy(batch) curr_dist_orig = out.logits taus, tau_hats = out.fractions.taus, out.fractions.tau_hats act = batch.act curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( ( dist_diff * 
(tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() ) .sum(-1) .mean(1) ) quantile_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer # calculate fraction loss with torch.no_grad(): sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :] # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 values_1 = sa_quantiles - sa_quantile_hats[:, :-1] signs_1 = sa_quantiles > torch.cat( [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1, ) values_2 = sa_quantiles - sa_quantile_hats[:, 1:] signs_2 = sa_quantiles < torch.cat( [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1, ) gradient_of_taus = torch.where(signs_1, values_1, -values_1) + torch.where( signs_2, values_2, -values_2, ) fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() # calculate entropy loss entropy_loss = out.fractions.entropies.mean() fraction_entropy_loss = fraction_loss - self.ent_coef * entropy_loss self.fraction_optim.step(fraction_entropy_loss, retain_graph=True) self.optim.step(quantile_loss) return FQFTrainingStats( loss=quantile_loss.item() + fraction_entropy_loss.item(), quantile_loss=quantile_loss.item(), fraction_loss=fraction_loss.item(), entropy_loss=entropy_loss.item(), ) ================================================ FILE: tianshou/algorithm/modelfree/iqn.py ================================================ from typing import Any, cast import gymnasium as gym import numpy as np import torch import torch.nn.functional as F from tianshou.algorithm import QRDQN from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from 
tianshou.data import Batch, to_numpy from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( ObsBatchProtocol, QuantileRegressionBatchProtocol, RolloutBatchProtocol, ) class IQNPolicy(QRDQNPolicy): def __init__( self, *, model: torch.nn.Module, action_space: gym.spaces.Space, sample_size: int = 32, online_sample_size: int = 8, target_sample_size: int = 8, observation_space: gym.Space | None = None, eps_training: float = 0.0, eps_inference: float = 0.0, ) -> None: """ :param model: :param action_space: the environment's action space :param sample_size: :param online_sample_size: :param target_sample_size: :param observation_space: the environment's observation space :param eps_training: the epsilon value for epsilon-greedy exploration during training. When collecting data for training, this is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, i.e. non-training cases (such as evaluation during test steps). The epsilon value is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). 
""" assert isinstance(action_space, gym.spaces.Discrete) assert sample_size > 1, f"sample_size should be greater than 1 but got: {sample_size}" assert online_sample_size > 1, ( f"online_sample_size should be greater than 1 but got: {online_sample_size}" ) assert target_sample_size > 1, ( f"target_sample_size should be greater than 1 but got: {target_sample_size}" ) super().__init__( model=model, action_space=action_space, observation_space=observation_space, eps_training=eps_training, eps_inference=eps_inference, ) self.sample_size = sample_size self.online_sample_size = online_sample_size self.target_sample_size = target_sample_size def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, model: torch.nn.Module | None = None, **kwargs: Any, ) -> QuantileRegressionBatchProtocol: is_model_old = model is not None if is_model_old: sample_size = self.target_sample_size elif self.training: sample_size = self.online_sample_size else: sample_size = self.sample_size if model is None: model = self.model obs = batch.obs # TODO: this seems very contrived! obs_next = obs.obs if hasattr(obs, "obs") else obs (logits, taus), hidden = model( obs_next, sample_size=sample_size, state=state, info=batch.info, ) q = self.compute_q_value(logits, getattr(obs, "mask", None)) act = to_numpy(q.max(dim=1)[1]) result = Batch(logits=logits, act=act, state=hidden, taus=taus) return cast(QuantileRegressionBatchProtocol, result) class IQN(QRDQN[IQNPolicy]): """Implementation of Implicit Quantile Network. arXiv:1806.06923.""" def __init__( self, *, policy: IQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, num_quantiles: int = 200, n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ :param policy: the policy :param optim: the optimizer factory for the policy's model :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. 
""" super().__init__( policy=policy, optim=optim, gamma=gamma, num_quantiles=num_quantiles, n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) action_batch = self.policy(batch) curr_dist, taus = action_batch.logits, action_batch.taus act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( ( dist_diff * (taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() ) .sum(-1) .mean(1) ) loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer self.optim.step(loss) return SimpleLossTrainingStats(loss=loss.item()) ================================================ FILE: tianshou/algorithm/modelfree/npg.py ================================================ from dataclasses import dataclass from typing import Any import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.distributions import kl_divergence from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.algorithm.modelfree.a2c import ActorCriticOnPolicyAlgorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic 
@dataclass(kw_only=True) class NPGTrainingStats(TrainingStats): actor_loss: SequenceSummaryStats vf_loss: SequenceSummaryStats kl: SequenceSummaryStats class NPG(ActorCriticOnPolicyAlgorithm): """Implementation of Natural Policy Gradient. https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf """ def __init__( self, *, policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_critic_iters: int = 5, trust_region_size: float = 0.5, advantage_normalization: bool = True, gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, return_scaling: bool = False, ) -> None: """ :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the critic network. :param optim_critic_iters: the number of optimization steps performed on the critic network for each policy (actor) update. Controls the learning rate balance between critic and actor. Higher values prioritize critic accuracy by training the value function more extensively before each policy update, which can improve stability but slow down training. Lower values maintain a more even learning pace between policy and value function but may lead to less reliable advantage estimates. Typically set between 1 and 10, depending on the complexity of the value function. :param trust_region_size: the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. The mathematical meaning is the trust region size, which is the maximum KL divergence allowed between the old and new policy distributions. Controls how far the policy parameters move in the calculated direction during each update. Higher values allow for faster learning but may cause instability or policy deterioration; lower values provide more stable but slower learning. 
Unlike regular policy gradients, natural gradients already account for the local geometry of the parameter space, making this step size more robust to different parameterizations. Typically set between 0.1 and 1.0 for most reinforcement learning tasks. :param advantage_normalization: whether to do per mini-batch advantage normalization. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. :param max_batchsize: the maximum number of samples to process at once when computing generalized advantage estimation (GAE) and value function predictions. Controls memory usage by breaking large batches into smaller chunks processed sequentially. Higher values may increase speed but require more GPU/CPU memory; lower values reduce memory requirements but may increase computation time. Should be adjusted based on available hardware resources and total batch size of your training data. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param return_scaling: flag indicating whether to enable scaling of estimated returns by dividing them by their running standard deviation without centering the mean. This reduces the magnitude variation of advantages across different episodes while preserving their signs and relative ordering. The use of running statistics (rather than batch-specific scaling) means that early training experiences may be scaled differently than later ones as the statistics evolve. When enabled, this improves training stability in environments with highly variable reward scales and makes the algorithm less sensitive to learning rate settings. However, it may reduce the algorithm's ability to distinguish between episodes with different absolute return magnitudes. Best used in environments where the relative ordering of actions is more important than the absolute scale of returns. 
""" super().__init__( policy=policy, critic=critic, optim=optim, optim_include_actor=False, gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, return_scaling=return_scaling, ) self.advantage_normalization = advantage_normalization self.optim_critic_iters = optim_critic_iters self.trust_region_size = trust_region_size # adjusts Hessian-vector product calculation for numerical stability self._damping = 0.1 def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithAdvantagesProtocol: batch = self._add_returns_and_advantages(batch, buffer, indices) batch.act = to_torch_as(batch.act, batch.v_s) old_log_prob = [] with torch.no_grad(): for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): old_log_prob.append(self.policy(minibatch).dist.log_prob(minibatch.act)) batch.logp_old = torch.cat(old_log_prob, dim=0) if self.advantage_normalization: batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() return batch def _update_with_batch( # type: ignore[override] self, batch: BatchWithAdvantagesProtocol, batch_size: int | None, repeat: int, ) -> NPGTrainingStats: actor_losses, vf_losses, kls = [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient dist = self.policy(minibatch).dist log_prob = dist.log_prob(minibatch.act) log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) actor_loss = -(log_prob * minibatch.adv).mean() flat_grads = self._get_flat_grad( actor_loss, self.policy.actor, retain_graph=True ).detach() # direction: calculate natural gradient with torch.no_grad(): old_dist = self.policy(minibatch).dist kl = kl_divergence(old_dist, dist).mean() # calculate first order gradient of kl with respect to theta flat_kl_grad = self._get_flat_grad(kl, self.policy.actor, create_graph=True) search_direction = 
-self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) # step with torch.no_grad(): flat_params = torch.cat( [param.data.view(-1) for param in self.policy.actor.parameters()], ) new_flat_params = flat_params + self.trust_region_size * search_direction self._set_from_flat_params(self.policy.actor, new_flat_params) new_dist = self.policy(minibatch).dist kl = kl_divergence(old_dist, new_dist).mean() # optimize critic for _ in range(self.optim_critic_iters): value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) self.optim.step(vf_loss) actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) kls.append(kl.item()) return NPGTrainingStats( actor_loss=SequenceSummaryStats.from_sequence(actor_losses), vf_loss=SequenceSummaryStats.from_sequence(vf_losses), kl=SequenceSummaryStats.from_sequence(kls), ) def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: """Matrix vector product.""" # caculate second order gradient of kl with respect to theta kl_v = (flat_kl_grad * v).sum() flat_kl_grad_grad = self._get_flat_grad(kl_v, self.policy.actor, retain_graph=True).detach() return flat_kl_grad_grad + v * self._damping def _conjugate_gradients( self, minibatch: torch.Tensor, flat_kl_grad: torch.Tensor, nsteps: int = 10, residual_tol: float = 1e-10, ) -> torch.Tensor: x = torch.zeros_like(minibatch) r, p = minibatch.clone(), minibatch.clone() # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0. # Change if doing warm start. 
rdotr = r.dot(r) for _ in range(nsteps): z = self._MVP(p, flat_kl_grad) alpha = rdotr / p.dot(z) x += alpha * p r -= alpha * z new_rdotr = r.dot(r) if new_rdotr < residual_tol: break p = r + new_rdotr / rdotr * p rdotr = new_rdotr return x def _get_flat_grad(self, y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor: grads = torch.autograd.grad(y, model.parameters(), **kwargs) # type: ignore return torch.cat([grad.reshape(-1) for grad in grads]) def _set_from_flat_params(self, model: nn.Module, flat_params: torch.Tensor) -> nn.Module: prev_ind = 0 for param in model.parameters(): flat_size = int(np.prod(list(param.size()))) param.data.copy_(flat_params[prev_ind : prev_ind + flat_size].view(param.size())) prev_ind += flat_size return model ================================================ FILE: tianshou/algorithm/modelfree/ppo.py ================================================ from typing import cast import numpy as np import torch from tianshou.algorithm import A2C from tianshou.algorithm.modelfree.a2c import A2CTrainingStats from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic class PPO(A2C): """Implementation of Proximal Policy Optimization. 
arXiv:1707.06347.""" def __init__( self, *, policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, eps_clip: float = 0.2, dual_clip: float | None = None, value_clip: bool = False, advantage_normalization: bool = True, recompute_advantage: bool = False, vf_coef: float = 0.5, ent_coef: float = 0.01, max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, return_scaling: bool = False, ) -> None: r""" :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the policy's actor network and the critic networks. :param eps_clip: determines the range of allowed change in the policy during a policy update: The ratio of action probabilities indicated by the new and old policy is constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. Small values thus force the new policy to stay close to the old policy. Typical values range between 0.1 and 0.3, the value of 0.2 is recommended in the original PPO paper. The optimal value depends on the environment; more stochastic environments may need larger values. :param dual_clip: a clipping parameter (denoted as c in the literature) that prevents excessive pessimism in policy updates for negative-advantage actions. Excessive pessimism occurs when the policy update too strongly reduces the probability of selecting actions that led to negative advantages, potentially eliminating useful actions based on limited negative experiences. When enabled (c > 1), the objective for negative advantages becomes: max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) is the original single-clipping objective determined by `eps_clip`. This creates a floor on negative policy gradients, maintaining some probability of exploring actions despite initial negative outcomes. 
Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer to 1.0 provide less protection against pessimistic updates. Set to None to disable dual clipping. :param value_clip: flag indicating whether to enable clipping for value function updates. When enabled, restricts how much the value function estimate can change from its previous prediction, using the same clipping range as the policy updates (eps_clip). This stabilizes training by preventing large fluctuations in value estimates, particularly useful in environments with high reward variance. The clipped value loss uses a pessimistic approach, taking the maximum of the original and clipped value errors: max((returns - value)², (returns - v_clipped)²) Setting to True often improves training stability but may slow convergence. Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. :param advantage_normalization: whether to do per mini-batch advantage normalization. :param recompute_advantage: whether to recompute advantage every update repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. :param vf_coef: coefficient that weights the value loss relative to the actor loss in the overall loss function. Higher values prioritize accurate value function estimation over policy improvement. Controls the trade-off between policy optimization and value function fitting. Typically set between 0.5 and 1.0 for most actor-critic implementations. :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. Controls the exploration-exploitation trade-off by encouraging policy entropy. Higher values promote more exploration by encouraging a more uniform action distribution. Lower values focus more on exploitation of the current policy's knowledge. Typically set between 0.01 and 0.05 for most actor-critic implementations. :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. 
When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. This prevents exploding gradients and stabilizes training by limiting the magnitude of parameter updates. Set to None to disable gradient clipping. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. :param max_batchsize: the maximum size of the batch when computing GAE. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param return_scaling: flag indicating whether to enable scaling of estimated returns by dividing them by their running standard deviation without centering the mean. 
This reduces the magnitude variation of advantages across different episodes while preserving their signs and relative ordering. The use of running statistics (rather than batch-specific scaling) means that early training experiences may be scaled differently than later ones as the statistics evolve. When enabled, this improves training stability in environments with highly variable reward scales and makes the algorithm less sensitive to learning rate settings. However, it may reduce the algorithm's ability to distinguish between episodes with different absolute return magnitudes. Best used in environments where the relative ordering of actions is more important than the absolute scale of returns. """ assert dual_clip is None or dual_clip > 1.0, ( f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}" ) super().__init__( policy=policy, critic=critic, optim=optim, vf_coef=vf_coef, ent_coef=ent_coef, max_grad_norm=max_grad_norm, gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, return_scaling=return_scaling, ) self.eps_clip = eps_clip self.dual_clip = dual_clip self.value_clip = value_clip self.advantage_normalization = advantage_normalization self.recompute_adv = recompute_advantage def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> LogpOldProtocol: if self.recompute_adv: # buffer input `buffer` and `indices` to be used in `_update_with_batch()`. 
self._buffer, self._indices = buffer, indices batch = self._add_returns_and_advantages(batch, buffer, indices) batch.act = to_torch_as(batch.act, batch.v_s) logp_old = [] with torch.no_grad(): for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): logp_old.append(self.policy(minibatch).dist.log_prob(minibatch.act)) batch.logp_old = torch.cat(logp_old, dim=0).flatten() return cast(LogpOldProtocol, batch) def _update_with_batch( # type: ignore[override] self, batch: LogpOldProtocol, batch_size: int | None, repeat: int, ) -> A2CTrainingStats: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] gradient_steps = 0 split_batch_size = batch_size or -1 for step in range(repeat): if self.recompute_adv and step > 0: batch = cast( LogpOldProtocol, self._add_returns_and_advantages(batch, self._buffer, self._indices), ) for minibatch in batch.split(split_batch_size, merge_last=True): gradient_steps += 1 # calculate loss for actor advantages = minibatch.adv dist = self.policy(minibatch).dist if self.advantage_normalization: mean, std = advantages.mean(), advantages.std() advantages = (advantages - mean) / (std + self._eps) # per-batch norm ratios = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() ratios = ratios.reshape(ratios.size(0), -1).transpose(0, 1) surr1 = ratios * advantages surr2 = ratios.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip) * advantages if self.dual_clip: clip1 = torch.min(surr1, surr2) clip2 = torch.max(clip1, self.dual_clip * advantages) clip_loss = -torch.where(advantages < 0, clip2, clip1).mean() else: clip_loss = -torch.min(surr1, surr2).mean() # calculate loss for critic value = self.critic(minibatch.obs).flatten() if self.value_clip: v_clip = minibatch.v_s + (value - minibatch.v_s).clamp( -self.eps_clip, self.eps_clip, ) vf1 = (minibatch.returns - value).pow(2) vf2 = (minibatch.returns - v_clip).pow(2) vf_loss = torch.max(vf1, vf2).mean() else: vf_loss = (minibatch.returns - value).pow(2).mean() # 
calculate regularization and overall loss ent_loss = dist.entropy().mean() loss = clip_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss self.optim.step(loss) clip_losses.append(clip_loss.item()) vf_losses.append(vf_loss.item()) ent_losses.append(ent_loss.item()) losses.append(loss.item()) return A2CTrainingStats( loss=SequenceSummaryStats.from_sequence(losses), actor_loss=SequenceSummaryStats.from_sequence(clip_losses), vf_loss=SequenceSummaryStats.from_sequence(vf_losses), ent_loss=SequenceSummaryStats.from_sequence(ent_losses), gradient_steps=gradient_steps, ) ================================================ FILE: tianshou/algorithm/modelfree/qrdqn.py ================================================ import warnings from typing import Generic, TypeVar import numpy as np import torch import torch.nn.functional as F from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol class QRDQNPolicy(DiscreteQLearningPolicy): def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: return super().compute_q_value(logits.mean(2), mask) TQRDQNPolicy = TypeVar("TQRDQNPolicy", bound=QRDQNPolicy) class QRDQN( QLearningOffPolicyAlgorithm[TQRDQNPolicy], Generic[TQRDQNPolicy], ): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.""" def __init__( self, *, policy: TQRDQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, num_quantiles: int = 200, n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ :param policy: the policy :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_quantiles: the number of quantiles used to represent the return distribution for each action. Determines the granularity of the approximated distribution function. Higher values provide a more fine-grained approximation of the true return distribution but increase computational and memory requirements. Lower values reduce computational cost but may not capture the distribution accurately enough. The original QRDQN paper used 200 quantiles for Atari environments. Must be greater than 1, as at least two quantiles are needed to represent a distribution. :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. 
Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. """ assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" super().__init__( policy=policy, optim=optim, gamma=gamma, n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.num_quantiles = num_quantiles tau = torch.linspace(0, 1, self.num_quantiles + 1) self.tau_hat = torch.nn.Parameter( ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False, ) warnings.filterwarnings("ignore", message="Using a target size") def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} if self.use_target_network: act = self.policy(obs_next_batch).act next_dist = self.policy(obs_next_batch, model=self.model_old).logits else: next_batch = self.policy(obs_next_batch) act = next_batch.act next_dist = next_batch.logits return next_dist[np.arange(len(act)), act, :] def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) curr_dist = self.policy(batch).logits act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( (dist_diff * (self.tau_hat - (target_dist - curr_dist).detach().le(0.0).float()).abs()) .sum(-1) .mean(1) ) loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer 
self.optim.step(loss) return SimpleLossTrainingStats(loss=loss.item()) ================================================ FILE: tianshou/algorithm/modelfree/rainbow.py ================================================ from dataclasses import dataclass from torch import nn from tianshou.algorithm.modelfree.c51 import C51, C51Policy from tianshou.algorithm.modelfree.reinforce import LossSequenceTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.discrete import NoisyLinear @dataclass(kw_only=True) class RainbowTrainingStats: loss: float class RainbowDQN(C51): """Implementation of Rainbow DQN. arXiv:1710.02298.""" def __init__( self, *, policy: C51Policy, optim: OptimizerFactory, gamma: float = 0.99, n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ :param policy: a policy following the rules (s -> action_values_BA) :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). 
A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. """ super().__init__( policy=policy, optim=optim, gamma=gamma, n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.model_old: nn.Module | None # type: ignore[assignment] # Remove the wrapper that forces eval mode for the target network, # because Rainbow requires it to be set to train mode for sampling noise # in NoisyLinear layers to take effect. # (minor violation of Liskov Substitution Principle) if self.use_target_network: assert isinstance(self.model_old, EvalModeModuleWrapper) self.model_old = self.model_old.module @staticmethod def _sample_noise(model: nn.Module) -> bool: """Sample the random noises of NoisyLinear modules in the model. Returns True if at least one NoisyLinear submodule was found. :param model: a PyTorch module which may have NoisyLinear submodules. :returns: True if model has at least one NoisyLinear submodule; otherwise, False. 
""" sampled_any_noise = False for m in model.modules(): if isinstance(m, NoisyLinear): m.sample() sampled_any_noise = True return sampled_any_noise def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> LossSequenceTrainingStats: self._sample_noise(self.policy.model) if self.use_target_network: assert self.model_old is not None self._sample_noise(self.model_old) return super()._update_with_batch(batch) ================================================ FILE: tianshou/algorithm/modelfree/redq.py ================================================ from dataclasses import dataclass from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np import torch from torch.distributions import Independent, Normal from tianshou.algorithm.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, ContinuousPolicyWithExplorationNoise, DDPGTrainingStats, ) from tianshou.algorithm.modelfree.sac import Alpha from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch from tianshou.data.types import ( DistLogProbBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise from tianshou.utils.net.continuous import ContinuousActorProbabilistic @dataclass class REDQTrainingStats(DDPGTrainingStats): """A data structure for storing loss statistics of the REDQ learn step.""" alpha: float | None = None alpha_loss: float | None = None TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats) class REDQPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, actor: torch.nn.Module | ContinuousActorProbabilistic, exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.spaces.Space, deterministic_eval: bool = True, action_scaling: bool = True, action_bound_method: Literal["clip"] | None = "clip", observation_space: gym.Space | None = None, ): """ :param actor: The actor network following the rules (s -> model_output) :param action_space: the 
environment's action_space. :param deterministic_eval: flag indicating whether the policy should use deterministic actions (using the mode of the action distribution) instead of stochastic ones (using random sampling) during evaluation. When enabled, the policy will always select the most probable action according to the learned distribution during evaluation phases, while still using stochastic sampling during training. This creates a clear distinction between exploration (training) and exploitation (evaluation) behaviors. Deterministic actions are generally preferred for final deployment and reproducible evaluation as they provide consistent behavior, reduce variance in performance metrics, and are more interpretable for human observers. Note that this parameter only affects behavior when the policy is not within a training step. When collecting rollouts for training, actions remain stochastic regardless of this setting to maintain proper exploration behaviour. :param observation_space: the environment's observation space :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. 
:param action_bound_method: the method used for bounding actions in continuous action spaces to the range [-1, 1] before scaling them to the environment's action space (provided that `action_scaling` is enabled). This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None for discrete spaces. When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly constrains outputs to [-1, 1] while preserving gradients. The choice of bounding method affects both training dynamics and exploration behavior. Clipping provides hard boundaries but may create plateau regions in the gradient landscape, while tanh provides smoother transitions but can compress sensitivity near the boundaries. Should be set to None if the actor model inherently produces bounded outputs. Typically used together with `action_scaling=True`. """ super().__init__( exploration_noise=exploration_noise, action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, action_bound_method=action_bound_method, ) self.actor = actor self.deterministic_eval = deterministic_eval self._eps = np.finfo(np.float32).eps.item() def forward( # type: ignore self, batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> DistLogProbBatchProtocol: (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Independent(Normal(loc_B, scale_B), 1) if self.deterministic_eval and not self.is_within_training_step: act_B = dist.mode else: act_B = dist.rsample() log_prob = dist.log_prob(act_B).unsqueeze(-1) # apply correction for Tanh squashing when computing logprob from Gaussian # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. 
squashed_action = torch.tanh(act_B) log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self._eps).sum( -1, keepdim=True, ) result = Batch( logits=(loc_B, scale_B), act=squashed_action, state=h_BH, dist=dist, log_prob=log_prob, ) return cast(DistLogProbBatchProtocol, result) class REDQ(ActorCriticOffPolicyAlgorithm[REDQPolicy, DistLogProbBatchProtocol]): """Implementation of REDQ. arXiv:2101.05982.""" def __init__( self, *, policy: REDQPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, ensemble_size: int = 10, subset_size: int = 2, tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, n_step_return_horizon: int = 1, actor_delay: int = 20, deterministic_eval: bool = True, target_mode: Literal["mean", "min"] = "min", ) -> None: """ :param policy: the policy :param policy_optim: the optimizer factory for the policy's model. :param critic: the critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer factory for the critic network. :param ensemble_size: the total number of critic networks in the ensemble. This parameter implements the randomized ensemble approach described in REDQ. The algorithm maintains `ensemble_size` different critic networks that all share the same architecture. During target value computation, a random subset of these networks (determined by `subset_size`) is used. Larger values increase the diversity of the ensemble but require more memory and computation. The original paper recommends a value of 10 for most tasks, balancing performance and computational efficiency. :param subset_size: the number of critic networks randomly selected from the ensemble for computing target Q-values. During each update, the algorithm samples `subset_size` networks from the ensemble of `ensemble_size` networks without replacement. The target Q-value is then calculated as either the minimum or mean (based on `target_mode`) of the predictions from this subset. 
Smaller values increase randomization and sample efficiency but may introduce more variance. Larger values provide more stable estimates but reduce the benefits of randomization. The REDQ paper recommends a value of 2 for optimal sample efficiency. Must satisfy 0 < subset_size <= ensemble_size. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient, which balances exploration and exploitation. This coefficient controls how much the agent values randomness in its policy versus pursuing higher rewards. Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent for maintaining diverse action choices, even if this means selecting some lower-value actions. Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become more focused on the highest-value actions. 
A value of 0 would completely remove entropy regularization, potentially leading to premature convergence to suboptimal deterministic policies. Can be provided as a fixed float (0.2 is a reasonable default) or as an instance of, in particular, class `AutoAlpha` for automatic tuning during training. :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param actor_delay: the number of critic updates performed before each actor update. The actor network is only updated once for every actor_delay critic updates, implementing a delayed policy update strategy similar to TD3. Larger values stabilize training by allowing critics to become more accurate before policy updates. Smaller values allow the policy to adapt more quickly but may lead to less stable learning. The REDQ paper recommends a value of 20 for most tasks. :param target_mode: the method used to aggregate Q-values from the subset of critic networks. Can be either "min" or "mean". If "min", uses the minimum Q-value across the selected subset of critics for each state-action pair. If "mean", uses the average Q-value across the selected subset of critics. Using "min" helps prevent overestimation bias but may lead to more conservative value estimates. Using "mean" provides more optimistic value estimates but may suffer from overestimation bias. Default is "min" following the conservative value estimation approach common in recent Q-learning algorithms. 
""" if target_mode not in ("min", "mean"): raise ValueError(f"Unsupported target_mode: {target_mode}") if not 0 < subset_size <= ensemble_size: raise ValueError( f"Invalid choice of ensemble size or subset size. " f"Should be 0 < {subset_size=} <= {ensemble_size=}", ) super().__init__( policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, tau=tau, gamma=gamma, n_step_return_horizon=n_step_return_horizon, ) self.ensemble_size = ensemble_size self.subset_size = subset_size self.target_mode = target_mode self.critic_gradient_step = 0 self.actor_delay = actor_delay self.deterministic_eval = deterministic_eval self.__eps = np.finfo(np.float32).eps.item() self._last_actor_loss = 0.0 # only for logging purposes self.alpha = Alpha.from_float_or_instance(alpha) def _target_q_compute_value( self, obs_batch: Batch, act_batch: DistLogProbBatchProtocol ) -> torch.Tensor: a_ = act_batch.act sample_ensemble_idx = np.random.choice(self.ensemble_size, self.subset_size, replace=False) qs = self.critic_old(obs_batch.obs, a_)[sample_ensemble_idx, ...] 
if self.target_mode == "min": target_q, _ = torch.min(qs, dim=0) elif self.target_mode == "mean": target_q = torch.mean(qs, dim=0) else: raise ValueError(f"Invalid target_mode: {self.target_mode}") target_q -= self.alpha.value * act_batch.log_prob return target_q def _update_with_batch(self, batch: RolloutBatchProtocol) -> REDQTrainingStats: # type: ignore # critic ensemble weight = getattr(batch, "weight", 1.0) current_qs = self.critic(batch.obs, batch.act).flatten(1) target_q = batch.returns.flatten() td = current_qs - target_q critic_loss = (td.pow(2) * weight).mean() self.critic_optim.step(critic_loss) batch.weight = torch.mean(td, dim=0) # prio-buffer self.critic_gradient_step += 1 alpha_loss = None # actor if self.critic_gradient_step % self.actor_delay == 0: obs_result = self.policy(batch) a = obs_result.act current_qa = self.critic(batch.obs, a).mean(dim=0).flatten() actor_loss = (self.alpha.value * obs_result.log_prob.flatten() - current_qa).mean() self.policy_optim.step(actor_loss) # The entropy of a Gaussian policy can be expressed as -log_prob + a constant (which we ignore) entropy = -obs_result.log_prob.detach() alpha_loss = self.alpha.update(entropy) self._last_actor_loss = actor_loss.item() self._update_lagged_network_weights() return REDQTrainingStats( actor_loss=self._last_actor_loss, critic_loss=critic_loss.item(), alpha=self.alpha.value, alpha_loss=alpha_loss, ) ================================================ FILE: tianshou/algorithm/modelfree/reinforce.py ================================================ import logging import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Literal, TypeVar, cast import gymnasium as gym import numpy as np import torch from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( OnPolicyAlgorithm, Policy, TrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ( Batch, ReplayBuffer, 
SequenceSummaryStats, to_torch, to_torch_as, ) from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( BatchWithReturnsProtocol, DistBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ( AbstractContinuousActorProbabilistic, AbstractDiscreteActor, ActionReprNet, ) from tianshou.utils.net.discrete import dist_fn_categorical_from_logits log = logging.getLogger(__name__) # Dimension Naming Convention # B - Batch Size # A - Action # D - Dist input (usually 2, loc and scale) # H - Dimension of hidden, can be None TDistFnContinuous = Callable[ [tuple[torch.Tensor, torch.Tensor]], torch.distributions.Distribution, ] TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Distribution] TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete @dataclass(kw_only=True) class LossSequenceTrainingStats(TrainingStats): loss: SequenceSummaryStats @dataclass(kw_only=True) class SimpleLossTrainingStats(TrainingStats): loss: float class ProbabilisticActorPolicy(Policy): """ A policy that outputs (representations of) probability distributions from which actions can be sampled. """ def __init__( self, *, actor: AbstractContinuousActorProbabilistic | AbstractDiscreteActor | ActionReprNet, dist_fn: TDistFnDiscrOrCont, deterministic_eval: bool = False, action_space: gym.Space, observation_space: gym.Space | None = None, action_scaling: bool = True, action_bound_method: Literal["clip", "tanh"] | None = "clip", ) -> None: """ :param actor: the actor network following the rules: If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`). If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param dist_fn: the function/type which creates a distribution from the actor output, i.e. it maps the tensor(s) generated by the actor to a torch distribution. 
For continuous action spaces, the output is typically a pair of tensors (mean, std) and the distribution is a Gaussian distribution. For discrete action spaces, the output is typically a tensor of unnormalized log probabilities ("logits" in PyTorch terminology) or a tensor of probabilities which can serve as the parameters of a Categorical distribution. Note that if the actor uses softmax activation in its final layer, it will produce probabilities, whereas if it uses no activation, it can be considered as producing "logits". As a user, you are responsible for ensuring that the distribution is compatible with the output of the actor model and the action space. :param deterministic_eval: flag indicating whether the policy should use deterministic actions (using the mode of the action distribution) instead of stochastic ones (using random sampling) during evaluation. When enabled, the policy will always select the most probable action according to the learned distribution during evaluation phases, while still using stochastic sampling during training. This creates a clear distinction between exploration (training) and exploitation (evaluation) behaviors. Deterministic actions are generally preferred for final deployment and reproducible evaluation as they provide consistent behavior, reduce variance in performance metrics, and are more interpretable for human observers. Note that this parameter only affects behavior when the policy is not within a training step. When collecting rollouts for training, actions remain stochastic regardless of this setting to maintain proper exploration behaviour. :param action_space: the environment's action space. :param observation_space: the environment's observation space. :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. 
This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. :param action_bound_method: the method used for bounding actions in continuous action spaces to the range [-1, 1] before scaling them to the environment's action space (provided that `action_scaling` is enabled). This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None for discrete spaces. When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly constrains outputs to [-1, 1] while preserving gradients. The choice of bounding method affects both training dynamics and exploration behavior. Clipping provides hard boundaries but may create plateau regions in the gradient landscape, while tanh provides smoother transitions but can compress sensitivity near the boundaries. Should be set to None if the actor model inherently produces bounded outputs. Typically used together with `action_scaling=True`. """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, action_bound_method=action_bound_method, ) if action_scaling: try: max_action = float(actor.max_action) if np.isclose(max_action, 1.0): warnings.warn( "action_scaling and action_bound_method are only intended " "to deal with unbounded model action space, but found actor model " f"bound action space with max_action={actor.max_action}. 
" "Consider using unbounded=True option of the actor model, " "or set action_scaling to False and action_bound_method to None.", ) except BaseException: pass self.actor = actor self.dist_fn = dist_fn self._eps = 1e-8 self.deterministic_eval = deterministic_eval def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, ) -> DistBatchProtocol: """Compute action over the given batch data by applying the actor. Will sample from the dist_fn, if appropriate. Returns a new object representing the processed batch data (contrary to other methods that modify the input batch inplace). """ action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A # therefore action_dist_input_BD is equivalent to logits_BA # If discrete, dist_fn will typically map loc, scale to a distribution (usually a Gaussian) # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked dist = self.dist_fn(action_dist_input_BD) act_B = ( dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample() ) # act is of dimension BA in continuous case and of dimension B in discrete result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) return cast(DistBatchProtocol, result) class DiscreteActorPolicy(ProbabilisticActorPolicy): def __init__( self, *, actor: AbstractDiscreteActor | ActionReprNet, dist_fn: TDistFnDiscrete = dist_fn_categorical_from_logits, deterministic_eval: bool = False, action_space: gym.Space, observation_space: gym.Space | None = None, ) -> None: """ :param actor: the actor network following the rules: (`s_B` -> `dist_input_BD`). :param dist_fn: the function/type which creates a distribution from the actor output, i.e. it maps the tensor(s) generated by the actor to a torch distribution. 
For discrete action spaces, the output is typically a tensor of unnormalized log probabilities ("logits" in PyTorch terminology) or a tensor of probabilities which serve as the parameters of a Categorical distribution. Note that if the actor uses softmax activation in its final layer, it will produce probabilities, whereas if it uses no activation, it can be considered as producing "logits". As a user, you are responsible for ensuring that the distribution is compatible with the output of the actor model and the action space. :param deterministic_eval: flag indicating whether the policy should use deterministic actions (using the mode of the action distribution) instead of stochastic ones (using random sampling) during evaluation. When enabled, the policy will always select the most probable action according to the learned distribution during evaluation phases, while still using stochastic sampling during training. This creates a clear distinction between exploration (training) and exploitation (evaluation) behaviors. Deterministic actions are generally preferred for final deployment and reproducible evaluation as they provide consistent behavior, reduce variance in performance metrics, and are more interpretable for human observers. Note that this parameter only affects behavior when the policy is not within a training step. When collecting rollouts for training, actions remain stochastic regardless of this setting to maintain proper exploration behaviour. :param action_space: the environment's (discrete) action space. :param observation_space: the environment's observation space. 
""" if not isinstance(action_space, gym.spaces.Discrete): raise ValueError(f"Action space must be an instance of Discrete; got {action_space}") super().__init__( actor=actor, dist_fn=dist_fn, deterministic_eval=deterministic_eval, action_space=action_space, observation_space=observation_space, action_scaling=False, action_bound_method=None, ) TActorPolicy = TypeVar("TActorPolicy", bound=ProbabilisticActorPolicy) class DiscountedReturnComputation: def __init__( self, gamma: float = 0.99, return_standardization: bool = False, ): """ :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param return_standardization: whether to standardize episode returns by subtracting the running mean and dividing by the running standard deviation. Note that this is known to be detrimental to performance in many cases! """ assert 0.0 <= gamma <= 1.0, "discount factor gamma should be in [0, 1]" self.gamma = gamma self.return_standardization = return_standardization self.ret_rms = RunningMeanStd() self.eps = 1e-8 def add_discounted_returns( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray ) -> BatchWithReturnsProtocol: r"""Compute the discounted returns (Monte Carlo estimates) for each transition. They are added to the batch under the field `returns`. Note: this function will modify the input batch! .. math:: G_t = \sum_{i=t}^T \gamma^{i-t}r_i where :math:`T` is the terminal time step, :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. 
:param batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be recognized by buffer.unfinished_index(). :param buffer: the corresponding replay buffer. :param indices: tell batch's location in buffer, batch is equal to buffer[indices]. """ v_s_ = np.full(indices.shape, self.ret_rms.mean) # gae_lambda = 1.0 means we use Monte Carlo estimate unnormalized_returns, _ = Algorithm.compute_episodic_return( batch, buffer, indices, v_s_=v_s_, gamma=self.gamma, gae_lambda=1.0, ) if self.return_standardization: batch.returns = (unnormalized_returns - self.ret_rms.mean) / np.sqrt( self.ret_rms.var + self.eps, ) self.ret_rms.update(unnormalized_returns) else: batch.returns = unnormalized_returns return cast(BatchWithReturnsProtocol, batch) class Reinforce(OnPolicyAlgorithm[ProbabilisticActorPolicy]): """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm.""" def __init__( self, *, policy: ProbabilisticActorPolicy, gamma: float = 0.99, return_standardization: bool = False, optim: OptimizerFactory, ) -> None: """ :param policy: the policy :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param return_standardization: if True, will scale/standardize returns by subtracting the running mean and dividing by the running standard deviation. 
Can be detrimental to performance! """ super().__init__( policy=policy, ) self.discounted_return_computation = DiscountedReturnComputation( gamma=gamma, return_standardization=return_standardization, ) self.optim = self._create_optimizer(self.policy, optim) def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithReturnsProtocol: return self.discounted_return_computation.add_discounted_returns( batch, buffer, indices, ) # Needs BatchWithReturnsProtocol, which violates the substitution principle. But not a problem since it's a private method and # the remainder of the class was adjusted to provide the correct batch def _update_with_batch( # type: ignore[override] self, batch: BatchWithReturnsProtocol, batch_size: int | None, repeat: int, ) -> LossSequenceTrainingStats: losses = [] split_batch_size = batch_size or -1 for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): result = self.policy(minibatch) dist = result.dist act = to_torch_as(minibatch.act, result.act) ret = to_torch(minibatch.returns, torch.float, result.act.device) log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1) loss = -(log_prob * ret).mean() self.optim.step(loss) losses.append(loss.item()) return LossSequenceTrainingStats(loss=SequenceSummaryStats.from_sequence(losses)) ================================================ FILE: tianshou/algorithm/modelfree/sac.py ================================================ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Generic, Literal, TypeVar, Union, cast import gymnasium as gym import numpy as np import torch from torch.distributions import Independent, Normal from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.algorithm.modelfree.ddpg import ContinuousPolicyWithExplorationNoise from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from 
tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch from tianshou.data.types import ( DistLogProbBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ContinuousActorProbabilistic def correct_log_prob_gaussian_tanh( log_prob: torch.Tensor, tanh_squashed_action: torch.Tensor, eps: float = np.finfo(np.float32).eps.item(), ) -> torch.Tensor: """Apply correction for Tanh squashing when computing `log_prob` from Gaussian. See equation 21 in the original `SAC paper `_. :param log_prob: log probability of the action :param tanh_squashed_action: action squashed to values in (-1, 1) range by tanh :param eps: epsilon for numerical stability """ log_prob_correction = torch.log(1 - tanh_squashed_action.pow(2) + eps).sum(-1, keepdim=True) return log_prob - log_prob_correction @dataclass(kw_only=True) class SACTrainingStats(TrainingStats): actor_loss: float critic1_loss: float critic2_loss: float alpha: float | None = None alpha_loss: float | None = None TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) class SACPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, actor: torch.nn.Module | ContinuousActorProbabilistic, exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, action_scaling: bool = True, action_space: gym.Space, observation_space: gym.Space | None = None, ): """ :param actor: the actor network following the rules (s -> dist_input_BD) :param exploration_noise: add noise to action for exploration. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). :param deterministic_eval: flag indicating whether the policy should use deterministic actions (using the mode of the action distribution) instead of stochastic ones (using random sampling) during evaluation. 
When enabled, the policy will always select the most probable action according to the learned distribution during evaluation phases, while still using stochastic sampling during training. This creates a clear distinction between exploration (training) and exploitation (evaluation) behaviors. Deterministic actions are generally preferred for final deployment and reproducible evaluation as they provide consistent behavior, reduce variance in performance metrics, and are more interpretable for human observers. Note that this parameter only affects behavior when the policy is not within a training step. When collecting rollouts for training, actions remain stochastic regardless of this setting to maintain proper exploration behaviour. :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. :param action_space: the environment's action_space. 
:param observation_space: the environment's observation space """ super().__init__( exploration_noise=exploration_noise, action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, # actions already squashed by tanh action_bound_method=None, ) self.actor = actor self.deterministic_eval = deterministic_eval def forward( # type: ignore self, batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> DistLogProbBatchProtocol: (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) if self.deterministic_eval and not self.is_within_training_step: act_B = dist.mode else: act_B = dist.rsample() log_prob = dist.log_prob(act_B).unsqueeze(-1) squashed_action = torch.tanh(act_B) log_prob = correct_log_prob_gaussian_tanh(log_prob, squashed_action) result = Batch( logits=(loc_B, scale_B), act=squashed_action, state=hidden_BH, dist=dist, log_prob=log_prob, ) return cast(DistLogProbBatchProtocol, result) class Alpha(ABC): """Defines the interface for the entropy regularization coefficient alpha.""" @staticmethod def from_float_or_instance(alpha: Union[float, "Alpha"]) -> "Alpha": if isinstance(alpha, float): return FixedAlpha(alpha) elif isinstance(alpha, Alpha): return alpha else: raise ValueError(f"Expected float or Alpha instance, but got {alpha=}") @property @abstractmethod def value(self) -> float: """Retrieves the current value of alpha.""" @abstractmethod def update(self, entropy: torch.Tensor) -> float | None: """ Updates the alpha value based on the entropy. :param entropy: the entropy of the policy. :return: the loss value if alpha is auto-tuned, otherwise None. 
""" return None class FixedAlpha(Alpha): """Represents a fixed entropy regularization coefficient alpha.""" def __init__(self, alpha: float): self._value = alpha @property def value(self) -> float: return self._value def update(self, entropy: torch.Tensor) -> float | None: return None class AutoAlpha(torch.nn.Module, Alpha): """Represents an entropy regularization coefficient alpha that is automatically tuned.""" def __init__(self, target_entropy: float, log_alpha: float, optim: OptimizerFactory): """ :param target_entropy: the target entropy value. For discrete action spaces, it is usually `-log(|A|)` for a balance between stochasticity and determinism or `-log(1/|A|)=log(|A|)` for maximum stochasticity or, more generally, `lambda*log(|A|)`, e.g. with `lambda` close to 1 (e.g. 0.98) for pronounced stochasticity. For continuous action spaces, it is usually `-dim(A)` for a balance between stochasticity and determinism, with similar generalizations as for discrete action spaces. :param log_alpha: the (initial) value of the log of the entropy regularization coefficient alpha. :param optim: the factory with which to create the optimizer for `log_alpha`. """ super().__init__() self._target_entropy = target_entropy self._log_alpha = torch.nn.Parameter(torch.tensor(log_alpha)) self._optim, lr_scheduler = optim.create_instances(self) if lr_scheduler is not None: raise ValueError( f"Learning rate schedulers are not supported by {self.__class__.__name__}" ) @property def value(self) -> float: return self._log_alpha.detach().exp().item() def update(self, entropy: torch.Tensor) -> float: entropy_deficit = self._target_entropy - entropy alpha_loss = -(self._log_alpha * entropy_deficit).mean() self._optim.zero_grad() alpha_loss.backward() self._optim.step() return alpha_loss.item() class SAC( ActorDualCriticsOffPolicyAlgorithm[SACPolicy, DistLogProbBatchProtocol], Generic[TSACTrainingStats], ): """Implementation of Soft Actor-Critic. 
arXiv:1812.05905.""" def __init__( self, *, policy: SACPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, n_step_return_horizon: int = 1, deterministic_eval: bool = True, ) -> None: """ :param policy: the policy :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network. (s, a -> Q(s, a)). If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient, which balances exploration and exploitation. This coefficient controls how much the agent values randomness in its policy versus pursuing higher rewards. Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent for maintaining diverse action choices, even if this means selecting some lower-value actions. Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become more focused on the highest-value actions. A value of 0 would completely remove entropy regularization, potentially leading to premature convergence to suboptimal deterministic policies. Can be provided as a fixed float (0.2 is a reasonable default) or as an instance of, in particular, class `AutoAlpha` for automatic tuning during training. :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. """ super().__init__( policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, critic2=critic2, critic2_optim=critic2_optim, tau=tau, gamma=gamma, n_step_return_horizon=n_step_return_horizon, ) self.deterministic_eval = deterministic_eval self.alpha = Alpha.from_float_or_instance(alpha) self._check_field_validity() def _check_field_validity(self) -> None: if not isinstance(self.policy.action_space, gym.spaces.Box): raise ValueError( f"SACPolicy only supports gym.spaces.Box, but got {self.action_space=}." 
f"Please use DiscreteSACPolicy for discrete action spaces.", ) def _target_q_compute_value( self, obs_batch: Batch, act_batch: DistLogProbBatchProtocol ) -> torch.Tensor: min_q_value = super()._target_q_compute_value(obs_batch, act_batch) return min_q_value - self.alpha.value * act_batch.log_prob def _update_with_batch(self, batch: RolloutBatchProtocol) -> TSACTrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( batch, self.critic, self.critic_optim ) td2, critic2_loss = self._minimize_critic_squared_loss( batch, self.critic2, self.critic2_optim ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor obs_result = self.policy(batch) act = obs_result.act current_q1a = self.critic(batch.obs, act).flatten() current_q2a = self.critic2(batch.obs, act).flatten() actor_loss = ( self.alpha.value * obs_result.log_prob.flatten() - torch.min(current_q1a, current_q2a) ).mean() self.policy_optim.step(actor_loss) # The entropy of a Gaussian policy can be expressed as -log_prob + a constant (which we ignore) entropy = -obs_result.log_prob.detach() alpha_loss = self.alpha.update(entropy) self._update_lagged_network_weights() return SACTrainingStats( # type: ignore[return-value] actor_loss=actor_loss.item(), critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), alpha=to_optional_float(self.alpha.value), alpha_loss=to_optional_float(alpha_loss), ) ================================================ FILE: tianshou/algorithm/modelfree/td3.py ================================================ from abc import ABC from copy import deepcopy from dataclasses import dataclass from typing import Any import torch from tianshou.algorithm.algorithm_base import ( TPolicy, TrainingStats, ) from tianshou.algorithm.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, ContinuousDeterministicPolicy, TActBatchProtocol, ) from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch from tianshou.data.types import ( 
ActStateBatchProtocol, RolloutBatchProtocol, ) @dataclass(kw_only=True) class TD3TrainingStats(TrainingStats): actor_loss: float critic1_loss: float critic2_loss: float class ActorDualCriticsOffPolicyAlgorithm( ActorCriticOffPolicyAlgorithm[TPolicy, TActBatchProtocol], ABC, ): """A base class for off-policy algorithms with two critics, where the target Q-value is computed as the minimum of the two lagged critics' values. """ def __init__( self, *, policy: Any, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. For continuous action spaces: (s, a -> Q(s, a)). **NOTE**: The default implementation of `_target_q_compute_value` assumes a continuous action space; override this method if using discrete actions. :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network (analogous functionality to the first). If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. 
:param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks """ super().__init__( policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, tau=tau, gamma=gamma, n_step_return_horizon=n_step_return_horizon, ) self.critic2 = critic2 or deepcopy(critic) self.critic2_old = self._add_lagged_network(self.critic2) self.critic2_optim = self._create_optimizer(self.critic2, critic2_optim or critic_optim) def _target_q_compute_value( self, obs_batch: Batch, act_batch: TActBatchProtocol ) -> torch.Tensor: # compute the Q-value as the minimum of the two lagged critics act = act_batch.act return torch.min( self.critic_old(obs_batch.obs, act), self.critic2_old(obs_batch.obs, act), ) class TD3( ActorDualCriticsOffPolicyAlgorithm[ContinuousDeterministicPolicy, ActStateBatchProtocol], ): """Implementation of TD3, arXiv:1802.09477.""" def __init__( self, *, policy: ContinuousDeterministicPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, policy_noise: float = 0.2, update_actor_freq: int = 2, noise_clip: float = 0.5, n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer factory for the first critic network. 
:param critic2: the second critic network. (s, a -> Q(s, a)). If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param policy_noise: scaling factor for the Gaussian noise added to target policy actions. This parameter implements target policy smoothing, a regularization technique described in the TD3 paper. The noise is sampled from a normal distribution and multiplied by this value before being added to actions. Higher values increase exploration in the target policy, helping to address function approximation error. The added noise is optionally clipped to a range determined by the noise_clip parameter. Typically set between 0.1 and 0.5 relative to the action scale of the environment. 
:param update_actor_freq: the frequency of actor network updates relative to critic network updates (the actor network is only updated once for every `update_actor_freq` critic updates). This implements the "delayed" policy updates from the TD3 algorithm, where the actor is updated less frequently than the critics. Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more accurate before updating the policy. The default value of 2 follows the original TD3 paper's recommendation of updating the policy at half the rate of the Q-functions. :param noise_clip: defines the maximum absolute value of the noise added to target policy actions, i.e. noise values are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise via `policy_noise`). This parameter implements bounded target policy smoothing as described in the TD3 paper. It prevents extreme noise values from causing unrealistic target values during training. Setting it 0.0 (or a negative value) disables clipping entirely. It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2). 
""" super().__init__( policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, critic2=critic2, critic2_optim=critic2_optim, tau=tau, gamma=gamma, n_step_return_horizon=n_step_return_horizon, ) self.actor_old = self._add_lagged_network(self.policy.actor) self.policy_noise = policy_noise self.update_actor_freq = update_actor_freq self.noise_clip = noise_clip self._cnt = 0 self._last = 0 def _target_q_compute_action(self, obs_batch: Batch) -> ActStateBatchProtocol: # compute action using lagged actor act_batch = self.policy(obs_batch, model=self.actor_old) act_ = act_batch.act # add noise noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise if self.noise_clip > 0.0: noise = noise.clamp(-self.noise_clip, self.noise_clip) act_ += noise act_batch.act = act_ return act_batch def _update_with_batch(self, batch: RolloutBatchProtocol) -> TD3TrainingStats: # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( batch, self.critic, self.critic_optim ) td2, critic2_loss = self._minimize_critic_squared_loss( batch, self.critic2, self.critic2_optim ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor if self._cnt % self.update_actor_freq == 0: actor_loss = -self.critic(batch.obs, self.policy(batch, eps=0.0).act).mean() self._last = actor_loss.item() self.policy_optim.step(actor_loss) self._update_lagged_network_weights() self._cnt += 1 return TD3TrainingStats( actor_loss=self._last, critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), ) ================================================ FILE: tianshou/algorithm/modelfree/trpo.py ================================================ import warnings from dataclasses import dataclass import torch import torch.nn.functional as F from torch.distributions import kl_divergence from tianshou.algorithm import NPG from tianshou.algorithm.modelfree.npg import NPGTrainingStats from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from 
tianshou.algorithm.optim import OptimizerFactory from tianshou.data import SequenceSummaryStats from tianshou.data.types import BatchWithAdvantagesProtocol from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic @dataclass(kw_only=True) class TRPOTrainingStats(NPGTrainingStats): step_size: SequenceSummaryStats class TRPO(NPG): """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.""" def __init__( self, *, policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, max_kl: float = 0.01, backtrack_coeff: float = 0.8, max_backtracks: int = 10, optim_critic_iters: int = 5, trust_region_size: float = 0.5, advantage_normalization: bool = True, gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, return_scaling: bool = False, ) -> None: """ :param policy: the policy :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the critic network. :param max_kl: max kl-divergence used to constrain each actor network update. :param backtrack_coeff: Coefficient to be multiplied by step size when constraints are not met. :param max_backtracks: Max number of backtracking times in linesearch. :param optim_critic_iters: the number of optimization steps performed on the critic network for each policy (actor) update. Controls the learning rate balance between critic and actor. Higher values prioritize critic accuracy by training the value function more extensively before each policy update, which can improve stability but slow down training. Lower values maintain a more even learning pace between policy and value function but may lead to less reliable advantage estimates. Typically set between 1 and 10, depending on the complexity of the value function. :param trust_region_size: the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. 
The mathematical meaning is the trust region size, which is the maximum KL divergence allowed between the old and new policy distributions. Controls how far the policy parameters move in the calculated direction during each update. Higher values allow for faster learning but may cause instability or policy deterioration; lower values provide more stable but slower learning. Unlike regular policy gradients, natural gradients already account for the local geometry of the parameter space, making this step size more robust to different parameterizations. Typically set between 0.1 and 1.0 for most reinforcement learning tasks. :param advantage_normalization: whether to do per mini-batch advantage normalization. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. :param max_batchsize: the maximum number of samples to process at once when computing generalized advantage estimation (GAE) and value function predictions. Controls memory usage by breaking large batches into smaller chunks processed sequentially. Higher values may increase speed but require more GPU/CPU memory; lower values reduce memory requirements but may increase computation time. 
Should be adjusted based on available hardware resources and total batch size of your training data. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param return_scaling: flag indicating whether to enable scaling of estimated returns by dividing them by their running standard deviation without centering the mean. This reduces the magnitude variation of advantages across different episodes while preserving their signs and relative ordering. The use of running statistics (rather than batch-specific scaling) means that early training experiences may be scaled differently than later ones as the statistics evolve. When enabled, this improves training stability in environments with highly variable reward scales and makes the algorithm less sensitive to learning rate settings. However, it may reduce the algorithm's ability to distinguish between episodes with different absolute return magnitudes. Best used in environments where the relative ordering of actions is more important than the absolute scale of returns. 
""" super().__init__( policy=policy, critic=critic, optim=optim, optim_critic_iters=optim_critic_iters, trust_region_size=trust_region_size, advantage_normalization=advantage_normalization, gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, return_scaling=return_scaling, ) self.max_backtracks = max_backtracks self.max_kl = max_kl self.backtrack_coeff = backtrack_coeff def _update_with_batch( # type: ignore[override] self, batch: BatchWithAdvantagesProtocol, batch_size: int | None, repeat: int, ) -> TRPOTrainingStats: actor_losses, vf_losses, step_sizes, kls = [], [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient dist = self.policy(minibatch).dist ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) actor_loss = -(ratio * minibatch.adv).mean() flat_grads = self._get_flat_grad( actor_loss, self.policy.actor, retain_graph=True ).detach() # direction: calculate natural gradient with torch.no_grad(): old_dist = self.policy(minibatch).dist kl = kl_divergence(old_dist, dist).mean() # calculate first order gradient of kl with respect to theta flat_kl_grad = self._get_flat_grad(kl, self.policy.actor, create_graph=True) search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) # stepsize: calculate max stepsize constrained by kl bound step_size = torch.sqrt( 2 * self.max_kl / (search_direction * self._MVP(search_direction, flat_kl_grad)).sum( 0, keepdim=True, ), ) # stepsize: linesearch stepsize with torch.no_grad(): flat_params = torch.cat( [param.data.view(-1) for param in self.policy.actor.parameters()], ) for i in range(self.max_backtracks): new_flat_params = flat_params + step_size * search_direction self._set_from_flat_params(self.policy.actor, new_flat_params) # calculate kl and if in bound, loss actually down 
new_dist = self.policy(minibatch).dist new_dratio = ( (new_dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() ) new_dratio = new_dratio.reshape(new_dratio.size(0), -1).transpose(0, 1) new_actor_loss = -(new_dratio * minibatch.adv).mean() kl = kl_divergence(old_dist, new_dist).mean() if kl < self.max_kl and new_actor_loss < actor_loss: if i > 0: warnings.warn(f"Backtracking to step {i}.") break if i < self.max_backtracks - 1: step_size = step_size * self.backtrack_coeff else: self._set_from_flat_params(self.policy.actor, new_flat_params) step_size = torch.tensor([0.0]) warnings.warn( "Line search failed! It seems hyperparamters" " are poor and need to be changed.", ) # optimize critic for _ in range(self.optim_critic_iters): value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) self.optim.step(vf_loss) actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) step_sizes.append(step_size.item()) kls.append(kl.item()) actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) kl_summary_stat = SequenceSummaryStats.from_sequence(kls) step_size_stat = SequenceSummaryStats.from_sequence(step_sizes) return TRPOTrainingStats( actor_loss=actor_loss_summary_stat, vf_loss=vf_loss_summary_stat, kl=kl_summary_stat, step_size=step_size_stat, ) ================================================ FILE: tianshou/algorithm/multiagent/__init__.py ================================================ ================================================ FILE: tianshou/algorithm/multiagent/marl.py ================================================ from collections.abc import Callable from typing import Any, Generic, Literal, Protocol, Self, TypeVar, cast, overload import numpy as np from overrides import override from sensai.util.helper import mark_used from torch.nn import ModuleList from tianshou.algorithm import Algorithm from 
tianshou.algorithm.algorithm_base import ( OffPolicyAlgorithm, OnPolicyAlgorithm, Policy, TrainingStats, ) from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol try: from tianshou.env.pettingzoo_env import PettingZooEnv except ImportError: PettingZooEnv = None # type: ignore mark_used(ActBatchProtocol) class MapTrainingStats(TrainingStats): def __init__( self, agent_id_to_stats: dict[str | int, TrainingStats], train_time_aggregator: Literal["min", "max", "mean"] = "max", ) -> None: self._agent_id_to_stats = agent_id_to_stats train_times = [agent_stats.train_time for agent_stats in agent_id_to_stats.values()] match train_time_aggregator: case "max": aggr_function = max case "min": aggr_function = min case "mean": aggr_function = np.mean # type: ignore case _: raise ValueError( f"Unknown {train_time_aggregator=}", ) self.train_time = aggr_function(train_times) self.smoothed_loss = {} @override def get_loss_stats_dict(self) -> dict[str, float]: """Collects loss_stats_dicts from all agents, prepends agent_id to all keys, and joins results.""" result_dict = {} for agent_id, stats in self._agent_id_to_stats.items(): agent_loss_stats_dict = stats.get_loss_stats_dict() for k, v in agent_loss_stats_dict.items(): result_dict[f"{agent_id}/" + k] = v return result_dict class MAPRolloutBatchProtocol(RolloutBatchProtocol, Protocol): # TODO: this might not be entirely correct. # The whole MAP data processing pipeline needs more documentation and possibly some refactoring @overload def __getitem__(self, index: str) -> RolloutBatchProtocol: ... @overload def __getitem__(self, index: IndexType) -> Self: ... def __getitem__(self, index: str | IndexType) -> Any: ... 
class MultiAgentPolicy(Policy): def __init__(self, policies: dict[str | int, Policy]): p0 = next(iter(policies.values())) super().__init__( action_space=p0.action_space, observation_space=p0.observation_space, action_scaling=False, action_bound_method=None, ) self.policies = policies self._submodules = ModuleList(policies.values()) _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") def add_exploration_noise( self, act: _TArrOrActBatch, batch: ObsBatchProtocol, ) -> _TArrOrActBatch: """Add exploration noise from sub-policy onto act.""" if not isinstance(batch.obs, Batch): raise TypeError( f"here only observations of type Batch are permitted, but got {type(batch.obs)}", ) for agent_id, policy in self.policies.items(): agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: continue act[agent_index] = policy.add_exploration_noise(act[agent_index], batch[agent_index]) return act def forward( # type: ignore self, batch: Batch, state: dict | Batch | None = None, **kwargs: Any, ) -> Batch: """Dispatch batch data from obs.agent_id to every policy's forward. :param batch: TODO: document what is expected at input and make a BatchProtocol for it :param state: if None, it means all agents have no state. If not None, it should contain keys of "agent_1", "agent_2", ... :return: a Batch with the following contents: TODO: establish a BatcProtocol for this :: { "act": actions corresponding to the input "state": { "agent_1": output state of agent_1's policy for the state "agent_2": xxx ... "agent_n": xxx} "out": { "agent_1": output of agent_1's policy for the input "agent_2": xxx ... "agent_n": xxx} } """ results: list[tuple[bool, np.ndarray, Batch, np.ndarray | Batch, Batch]] = [] for agent_id, policy in self.policies.items(): # This part of code is difficult to understand. 
# Let's follow an example with two agents # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6) # each agent plays for three transitions # agent_index for agent 1 is [0, 2, 4] # agent_index for agent 2 is [1, 3, 5] # we separate the transition of each agent according to agent_id agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: # (has_data, agent_index, out, act, state) results.append((False, np.array([-1]), Batch(), Batch(), Batch())) continue tmp_batch = batch[agent_index] if "rew" in tmp_batch.get_keys() and isinstance(tmp_batch.rew, np.ndarray): # reward can be empty Batch (after initial reset) or nparray. tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]] if not hasattr(tmp_batch.obs, "mask"): if hasattr(tmp_batch.obs, "obs"): tmp_batch.obs = tmp_batch.obs.obs if hasattr(tmp_batch.obs_next, "obs"): tmp_batch.obs_next = tmp_batch.obs_next.obs out = policy( batch=tmp_batch, state=None if state is None else state[agent_id], **kwargs, ) act = out.act each_state = out.state if (hasattr(out, "state") and out.state is not None) else Batch() results.append((True, agent_index, out, act, each_state)) holder: Batch = Batch.cat( [{"act": act} for (has_data, agent_index, out, act, each_state) in results if has_data], ) state_dict, out_dict = {}, {} for (agent_id, _), (has_data, agent_index, out, act, state) in zip( self.policies.items(), results, strict=True, ): if has_data: holder.act[agent_index] = act state_dict[agent_id] = state out_dict[agent_id] = out holder["out"] = out_dict holder["state"] = state_dict return holder TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) class MARLDispatcher(Generic[TAlgorithm]): """ Supports multi-agent learning by dispatching calls to the corresponding algorithm for each agent. """ def __init__(self, algorithms: list[TAlgorithm], env: PettingZooEnv): agent_ids = env.agents assert len(algorithms) == len(agent_ids), "One policy must be assigned for each agent." 
self.algorithms: dict[str | int, TAlgorithm] = dict(zip(agent_ids, algorithms, strict=True)) """maps agent_id to the corresponding algorithm.""" self.agent_idx = env.agent_idx """maps agent_id to 0-based index.""" def create_policy(self) -> MultiAgentPolicy: return MultiAgentPolicy({agent_id: a.policy for agent_id, a in self.algorithms.items()}) def dispatch_process_fn( self, batch: MAPRolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> MAPRolloutBatchProtocol: """Dispatch batch data from `obs.agent_id` to every algorithm's processing function. Save original multi-dimensional rew in "save_rew", set rew to the reward of each agent during their "process_fn", and restore the original reward afterwards. """ # TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol results: dict[str | int, RolloutBatchProtocol] = {} assert isinstance( batch.obs, BatchProtocol, ), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" # reward can be empty Batch (after initial reset) or nparray. has_rew = isinstance(buffer.rew, np.ndarray) if has_rew: # save the original reward in save_rew # Since we do not override buffer.__setattr__, here we use _meta to # change buffer.rew, otherwise buffer.rew = Batch() has no effect. 
save_rew, buffer._meta.rew = buffer.rew, Batch() # type: ignore for agent, algorithm in self.algorithms.items(): agent_index = np.nonzero(batch.obs.agent_id == agent)[0] if len(agent_index) == 0: results[agent] = cast(RolloutBatchProtocol, Batch()) continue tmp_batch, tmp_indice = batch[agent_index], indices[agent_index] if has_rew: tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]] buffer._meta.rew = save_rew[:, self.agent_idx[agent]] if not hasattr(tmp_batch.obs, "mask"): if hasattr(tmp_batch.obs, "obs"): tmp_batch.obs = tmp_batch.obs.obs if hasattr(tmp_batch.obs_next, "obs"): tmp_batch.obs_next = tmp_batch.obs_next.obs results[agent] = algorithm._preprocess_batch(tmp_batch, buffer, tmp_indice) if has_rew: # restore from save_rew buffer._meta.rew = save_rew return cast(MAPRolloutBatchProtocol, Batch(results)) def dispatch_update_with_batch( self, batch: MAPRolloutBatchProtocol, algorithm_update_with_batch_fn: Callable[[TAlgorithm, RolloutBatchProtocol], TrainingStats], ) -> MapTrainingStats: """Dispatch the respective subset of the batch data to each algorithm. :param batch: must map agent_ids to rollout batches :param algorithm_update_with_batch_fn: a function that performs the algorithm-specific update with the given agent-specific batch data """ agent_id_to_stats = {} for agent_id, algorithm in self.algorithms.items(): data = batch[agent_id] if len(data.get_keys()) != 0: train_stats = algorithm_update_with_batch_fn(algorithm, data) agent_id_to_stats[agent_id] = train_stats return MapTrainingStats(agent_id_to_stats) class MultiAgentOffPolicyAlgorithm(OffPolicyAlgorithm[MultiAgentPolicy]): """Multi-agent reinforcement learning where each agent uses off-policy learning.""" def __init__( self, *, algorithms: list[OffPolicyAlgorithm], env: PettingZooEnv, ) -> None: """ :param algorithms: a list of off-policy algorithms. 
:param env: the multi-agent RL environment """ self._dispatcher: MARLDispatcher[OffPolicyAlgorithm] = MARLDispatcher(algorithms, env) super().__init__( policy=self._dispatcher.create_policy(), ) self._submodules = ModuleList(algorithms) def get_algorithm(self, agent_id: str | int) -> OffPolicyAlgorithm: return self._dispatcher.algorithms[agent_id] def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: batch = cast(MAPRolloutBatchProtocol, batch) return self._dispatcher.dispatch_process_fn(batch, buffer, indices) def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> MapTrainingStats: batch = cast(MAPRolloutBatchProtocol, batch) def update(algorithm: OffPolicyAlgorithm, data: RolloutBatchProtocol) -> TrainingStats: return algorithm._update_with_batch(data) return self._dispatcher.dispatch_update_with_batch(batch, update) class MultiAgentOnPolicyAlgorithm(OnPolicyAlgorithm[MultiAgentPolicy]): """Multi-agent reinforcement learning where each agent uses on-policy learning.""" def __init__( self, *, algorithms: list[OnPolicyAlgorithm], env: PettingZooEnv, ) -> None: """ :param algorithms: a list of off-policy algorithms. 
:param env: the multi-agent RL environment """ self._dispatcher: MARLDispatcher[OnPolicyAlgorithm] = MARLDispatcher(algorithms, env) super().__init__( policy=self._dispatcher.create_policy(), ) self._submodules = ModuleList(algorithms) def get_algorithm(self, agent_id: str | int) -> OnPolicyAlgorithm: return self._dispatcher.algorithms[agent_id] def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: batch = cast(MAPRolloutBatchProtocol, batch) return self._dispatcher.dispatch_process_fn(batch, buffer, indices) def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> MapTrainingStats: batch = cast(MAPRolloutBatchProtocol, batch) def update(algorithm: OnPolicyAlgorithm, data: RolloutBatchProtocol) -> TrainingStats: return algorithm._update_with_batch(data, batch_size, repeat) return self._dispatcher.dispatch_update_with_batch(batch, update) ================================================ FILE: tianshou/algorithm/optim.py ================================================ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable from typing import Any, Self, TypeAlias import numpy as np import torch from sensai.util.string import ToStringMixin from torch.optim import Adam, RMSprop from torch.optim.lr_scheduler import LambdaLR, LRScheduler ParamsType: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] class LRSchedulerFactory(ToStringMixin, ABC): """Factory for the creation of a learning rate scheduler.""" @abstractmethod def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: pass class LRSchedulerFactoryLinear(LRSchedulerFactory): """ Factory for a learning rate scheduler where the learning rate linearly decays towards zero for the given trainer parameters. 
""" def __init__(self, max_epochs: int, epoch_num_steps: int, collection_step_num_env_steps: int): self.num_epochs = max_epochs self.epoch_num_steps = epoch_num_steps self.collection_step_num_env_steps = collection_step_num_env_steps def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: return LambdaLR(optim, lr_lambda=self._LRLambda(self).compute) class _LRLambda: def __init__(self, parent: "LRSchedulerFactoryLinear"): self.max_update_num = ( np.ceil(parent.epoch_num_steps / parent.collection_step_num_env_steps) * parent.num_epochs ) def compute(self, epoch: int) -> float: return 1.0 - epoch / self.max_update_num class OptimizerFactory(ABC, ToStringMixin): def __init__(self) -> None: self.lr_scheduler_factory: LRSchedulerFactory | None = None def with_lr_scheduler_factory(self, lr_scheduler_factory: LRSchedulerFactory) -> Self: self.lr_scheduler_factory = lr_scheduler_factory return self def create_instances( self, module: torch.nn.Module, ) -> tuple[torch.optim.Optimizer, LRScheduler | None]: optimizer = self._create_optimizer_for_params(module.parameters()) lr_scheduler = None if self.lr_scheduler_factory is not None: lr_scheduler = self.lr_scheduler_factory.create_scheduler(optimizer) return optimizer, lr_scheduler @abstractmethod def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: pass class TorchOptimizerFactory(OptimizerFactory): """General factory for arbitrary torch optimizers.""" def __init__(self, optim_class: Callable[..., torch.optim.Optimizer], **kwargs: Any): """ :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), which will be passed the module parameters, the learning rate as `lr` and the kwargs provided. 
:param kwargs: keyword arguments to provide at optimizer construction """ super().__init__() self.optim_class = optim_class self.kwargs = kwargs def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: return self.optim_class(params, **self.kwargs) class AdamOptimizerFactory(OptimizerFactory): def __init__( self, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, ): super().__init__() self.lr = lr self.weight_decay = weight_decay self.eps = eps self.betas = betas def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: return Adam( params, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, ) class RMSpropOptimizerFactory(OptimizerFactory): def __init__( self, lr: float = 1e-2, alpha: float = 0.99, eps: float = 1e-08, weight_decay: float = 0, momentum: float = 0, centered: bool = False, ): super().__init__() self.lr = lr self.alpha = alpha self.momentum = momentum self.centered = centered self.weight_decay = weight_decay self.eps = eps def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: return RMSprop( params, lr=self.lr, alpha=self.alpha, eps=self.eps, weight_decay=self.weight_decay, momentum=self.momentum, centered=self.centered, ) ================================================ FILE: tianshou/algorithm/random.py ================================================ from typing import cast import gymnasium as gym import numpy as np from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, TrainingStats from tianshou.algorithm.algorithm_base import Policy as BasePolicy from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol class MARLRandomTrainingStats(TrainingStats): pass class MARLRandomDiscreteMaskedOffPolicyAlgorithm(OffPolicyAlgorithm): """A random agent used in multi-agent 
learning. It randomly chooses an action from the legal actions (according to the given mask). """ class Policy(BasePolicy): """A random agent used in multi-agent learning. It randomly chooses an action from the legal actions. """ def __init__(self, action_space: gym.spaces.Space) -> None: super().__init__(action_space=action_space) def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: dict, ) -> ActBatchProtocol: """Compute the random action over the given batch data. The input should contain a mask in batch.obs, with "True" to be available and "False" to be unavailable. For example, ``batch.obs.mask == np.array([[False, True, False]])`` means with batch size 1, action "1" is available but action "0" and "2" are unavailable. :return: A :class:`~tianshou.data.Batch` with "act" key, containing the random action. """ mask = batch.obs.mask # type: ignore logits = np.random.rand(*mask.shape) logits[~mask] = -np.inf result = Batch(act=logits.argmax(axis=-1)) return cast(ActBatchProtocol, result) def __init__(self, action_space: gym.spaces.Space) -> None: """:param action_space: the environment's action space.""" super().__init__(policy=self.Policy(action_space)) def _update_with_batch(self, batch: RolloutBatchProtocol) -> MARLRandomTrainingStats: # type: ignore """Since a random agent learns nothing, it returns an empty dict.""" return MARLRandomTrainingStats() ================================================ FILE: tianshou/config.py ================================================ ENABLE_VALIDATION = False """Validation can help catching bugs and issues but it slows down training and collection. 
Enable it only if needed.""" ================================================ FILE: tianshou/data/__init__.py ================================================ """Data package.""" # isort:skip_file from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer.buffer_base import ReplayBuffer from tianshou.data.buffer.prio import PrioritizedReplayBuffer from tianshou.data.buffer.her import HERReplayBuffer from tianshou.data.buffer.manager import ( ReplayBufferManager, PrioritizedReplayBufferManager, HERReplayBufferManager, ) from tianshou.data.buffer.vecbuf import ( HERVectorReplayBuffer, PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) from tianshou.data.buffer.cached import CachedReplayBuffer from tianshou.data.stats import ( EpochStats, InfoStats, SequenceSummaryStats, TimingStats, ) from tianshou.data.collector import ( Collector, AsyncCollector, CollectStats, CollectStatsBase, BaseCollector, ) __all__ = [ "AsyncCollector", "BaseCollector", "Batch", "CachedReplayBuffer", "CollectStats", "CollectStatsBase", "Collector", "EpochStats", "HERReplayBuffer", "HERReplayBufferManager", "HERVectorReplayBuffer", "InfoStats", "PrioritizedReplayBuffer", "PrioritizedReplayBufferManager", "PrioritizedVectorReplayBuffer", "ReplayBuffer", "ReplayBufferManager", "SegmentTree", "SequenceSummaryStats", "TimingStats", "VectorReplayBuffer", "to_numpy", "to_torch", "to_torch_as", ] ================================================ FILE: tianshou/data/batch.py ================================================ """This module implements :class:`Batch`, a flexible data structure for handling heterogeneous data in reinforcement learning algorithms. Such a data structure is needed since RL algorithms differ widely in the conceptual fields that they need. `Batch` is the main data carrier in Tianshou. 
It bears some similarities to `TensorDict <https://github.com/pytorch/tensordict>`_ that is used for a similar purpose in `pytorch-rl <https://github.com/pytorch/rl>`_. The main differences between the two are that `Batch` can hold arbitrary objects (and not just torch tensors), and that Tianshou implements `BatchProtocol` for enabling type checking and autocompletion (more on that below). The `Batch` class is designed to store and manipulate collections of data with varying types and structures. It strikes a balance between flexibility and type safety, the latter mainly achieved through the use of protocols. One can think of it as a mixture of a dictionary and an array, as it has both key-value pairs and nesting, while also having a shape and being indexable and sliceable. Key features of the `Batch` class include: 1. Flexible data storage: Can hold numpy arrays, torch tensors, scalars, and nested Batch objects. 2. Dynamic attribute access: Allows setting and accessing data using attribute notation (e.g., `batch.observation`). This allows for type-safe and readable code and enables IDE autocompletion. See comments on `BatchProtocol` below. 3. Indexing and slicing: Supports numpy-like indexing and slicing operations. The slicing is extended to nested Batch objects and torch Distributions. 4. Batch operations: Provides methods for splitting, shuffling, concatenating and stacking multiple Batch objects. 5. Data type conversion: Offers methods to convert data between numpy arrays and torch tensors. 6. Value transformations: Allows applying functions to all values in the Batch recursively. 7. Analysis utilities: Provides methods for checking for missing values, dropping entries with missing values, and others. Since we want to keep `Batch` flexible and not fix a specific set of fields or their types, we don't have fixed interfaces for actual `Batch` objects that are used throughout tianshou (such interfaces could be dataclasses, for example). However, we still want to enable IDE autocompletion and type checking for `Batch` objects.
To achieve this, we rely on dynamic duck typing by using `Protocol`. The :class:`BatchProtocol` defines the interface that all `Batch` objects should adhere to, and its various implementations (like :class:`~.types.ActBatchProtocol` or :class:`~.types.RolloutBatchProtocol`) define the specific fields that are expected in the respective `Batch` objects. The protocols are then used as type hints throughout the codebase. Protocols can't be instantiated, but we can cast to them. For example, we "instantiate" an `ActBatchProtocol` with something like: >>> act_batch = cast(ActBatchProtocol, Batch(act=my_action)) The users can decide for themselves how to structure their `Batch` objects, and can opt in to the `BatchProtocol` style to enable type checking and autocompletion. Opting out will have no effect on the functionality. """ import pprint import warnings from collections.abc import Callable, Collection, Iterable, Iterator, KeysView, Sequence from copy import deepcopy from numbers import Number from types import EllipsisType from typing import ( Any, Literal, Protocol, Self, TypeVar, Union, cast, overload, runtime_checkable, ) import numpy as np import pandas as pd import torch from deepdiff import DeepDiff from sensai.util import logging from torch.distributions import Categorical, Distribution, Independent, Normal _SingleIndexType = slice | int | EllipsisType IndexType = np.ndarray | _SingleIndexType | Sequence[_SingleIndexType] TBatch = TypeVar("TBatch", bound="BatchProtocol") TDistribution = TypeVar("TDistribution", bound=Distribution) T = TypeVar("T") TArr = torch.Tensor | np.ndarray TObsArr = torch.Tensor | np.ndarray log = logging.getLogger(__name__) def _is_batch_set(obj: Any) -> bool: # Batch set is a list/tuple of dict/Batch objects, # or 1-D np.ndarray with object type, # where each element is a dict/Batch object if isinstance(obj, np.ndarray): # most often case # "for element in obj" will just unpack the first dimension, # but obj.tolist() will flatten 
ndarray of objects # so do not use obj.tolist() if obj.shape == (): return False return obj.dtype == object and all(isinstance(element, dict | Batch) for element in obj) return ( isinstance(obj, list | tuple) and len(obj) > 0 and all(isinstance(element, dict | Batch) for element in obj) ) def _is_scalar(value: Any) -> bool: # check if the value is a scalar # 1. python bool object, number object: isinstance(value, Number) # 2. numpy scalar: isinstance(value, np.generic) # 3. python object rather than dict / Batch / tensor # the check of dict / Batch is omitted because this only checks a value. # a dict / Batch will eventually check their values if isinstance(value, torch.Tensor): return value.numel() == 1 and not value.shape # np.asanyarray will cause dead loop in some cases return np.isscalar(value) def _is_number(value: Any) -> bool: # isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc. # isinstance(value, np.nummber) checks np.int32(1), np.float64(1.0), etc. # isinstance(value, np.bool_) checks np.bool_(True), etc. 
# similar to np.isscalar but np.isscalar('st') returns True return isinstance(value, Number | np.number | np.bool_) def _to_array_with_correct_type(obj: Any) -> np.ndarray: if isinstance(obj, np.ndarray) and issubclass(obj.dtype.type, np.bool_ | np.number): return obj # most often case # convert the value to np.ndarray # convert to object obj type if neither bool nor number # raises an exception if array's elements are tensors themselves try: obj_array = np.asanyarray(obj) except ValueError: obj_array = np.asanyarray(obj, dtype=object) if not issubclass(obj_array.dtype.type, np.bool_ | np.number): obj_array = obj_array.astype(object) if obj_array.dtype == object: # scalar ndarray with object obj type is very annoying # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) # a is not array([{}, {}], dtype=object), and a[0]={} results in # something very strange: # array([{}, array({}, dtype=object)], dtype=object) if not obj_array.shape: obj_array = obj_array.item(0) elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)): return obj_array # various length, np.array([[1], [2, 3], [4, 5, 6]]) elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)): raise ValueError("Numpy arrays of tensors are not supported yet.") return obj_array def create_value( inst: Any, size: int, stack: bool = True, ) -> Union["Batch", np.ndarray, torch.Tensor]: """Create empty place-holders according to inst's shape. :param stack: whether to stack or to concatenate. E.g. 
if inst has shape of (3, 5), size = 10, stack=True returns an np.array with shape of (10, 3, 5), otherwise (10, 5) """ has_shape = isinstance(inst, np.ndarray | torch.Tensor) is_scalar = _is_scalar(inst) if not stack and is_scalar: # should never hit since it has already checked in Batch.cat_ , here we do not # consider scalar types, following the behavior of numpy which does not support # concatenation of zero-dimensional arrays (scalars) raise TypeError(f"cannot concatenate with {inst} which is scalar") if has_shape: shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) if isinstance(inst, np.ndarray): target_type = ( inst.dtype.type if issubclass(inst.dtype.type, np.bool_ | np.number) else object ) return np.full(shape, fill_value=None if target_type is object else 0, dtype=target_type) if isinstance(inst, torch.Tensor): return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) if isinstance(inst, dict | Batch): zero_batch = Batch() for key, val in inst.items(): zero_batch.__dict__[key] = create_value(val, size, stack=stack) return zero_batch if is_scalar: return create_value(np.asarray(inst), size, stack=stack) # fall back to object return np.array([None for _ in range(size)], object) def _assert_type_keys(keys: Iterable[str]) -> None: assert all(isinstance(key, str) for key in keys), f"keys should all be string, but got {keys}" def _parse_value(obj: Any) -> Union["Batch", np.ndarray, torch.Tensor] | None: if isinstance(obj, Batch): # most often case return obj if ( (isinstance(obj, np.ndarray) and issubclass(obj.dtype.type, np.bool_ | np.number)) or isinstance(obj, torch.Tensor) or obj is None ): # third often case return obj if _is_number(obj): # second often case, but it is more time-consuming return np.asanyarray(obj) if isinstance(obj, dict): return Batch(obj) if ( not isinstance(obj, np.ndarray) and isinstance(obj, Collection) and len(obj) > 0 and all(isinstance(element, torch.Tensor) for element in obj) ): try: obj = 
cast(list[torch.Tensor], obj) return torch.stack(obj) except RuntimeError as exception: raise TypeError( "Batch does not support non-stackable iterable" " of torch.Tensor as unique value yet.", ) from exception if _is_batch_set(obj): obj = Batch(obj) # list of dict / Batch else: # None, scalar, normal obj list (main case) # or an actual list of objects try: obj = _to_array_with_correct_type(obj) except ValueError as exception: raise TypeError( "Batch does not support heterogeneous list/tuple of tensors as unique value yet.", ) from exception return obj def alloc_by_keys_diff( meta: "BatchProtocol", batch: "BatchProtocol", size: int, stack: bool = True, ) -> None: """Creates place-holders inside meta for keys that are in batch but not in meta. This mainly is an internal method, use it only if you know what you are doing. """ for key in batch.get_keys(): if key in meta.get_keys(): if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): alloc_by_keys_diff(meta[key], batch[key], size, stack) elif isinstance(meta[key], Batch) and len(meta[key].get_keys()) == 0: meta[key] = create_value(batch[key], size, stack) else: meta[key] = create_value(batch[key], size, stack) class ProtocolCalledException(Exception): """The methods of a Protocol should never be called. Currently, no static type checker actually verifies that a class that inherits from a Protocol does in fact provide the correct interface. Thus, it may happen that a method of the protocol is called accidentally (this is an implementation error). The normal error for that is a somewhat cryptic AttributeError, wherefore we instead raise this custom exception in the BatchProtocol. Finally and importantly: using this in BatchProtocol makes mypy verify the fields in the various sub-protocols and thus renders is MUCH more useful! 
""" def get_sliced_dist(dist: TDistribution, index: IndexType) -> TDistribution: """Slice a distribution object by the given index.""" if isinstance(dist, Categorical): return Categorical(probs=dist.probs[index]) # type: ignore[return-value] if isinstance(dist, Normal): return Normal(loc=dist.loc[index], scale=dist.scale[index]) # type: ignore[return-value] if isinstance(dist, Independent): return Independent( get_sliced_dist(dist.base_dist, index), dist.reinterpreted_batch_ndims, ) # type: ignore[return-value] else: raise NotImplementedError(f"Unsupported distribution for slicing: {dist}") def get_len_of_dist(dist: Distribution) -> int: """Return the length (typically batch size) of a distribution object.""" if len(dist.batch_shape) == 0: raise TypeError(f"scalar Distribution has no length: {dist=}") return dist.batch_shape[0] def dist_to_atleast_2d(dist: TDistribution) -> TDistribution: """Convert a distribution to at least 2D, such that the `batch_shape` attribute has a len of at least 1.""" if len(dist.batch_shape) > 0: return dist if isinstance(dist, Categorical): return Categorical(probs=dist.probs.unsqueeze(0)) # type: ignore[return-value] elif isinstance(dist, Normal): return Normal(loc=dist.loc.unsqueeze(0), scale=dist.scale.unsqueeze(0)) # type: ignore[return-value] elif isinstance(dist, Independent): return Independent( dist_to_atleast_2d(dist.base_dist), dist.reinterpreted_batch_ndims, ) # type: ignore[return-value] else: raise NotImplementedError(f"Unsupported distribution for conversion to 2D: {type(dist)}") # Note: This is implemented as a protocol because the interface # of Batch is always extended by adding new fields. Having a hierarchy of # protocols building off this one allows for type safety and IDE support despite # the dynamic nature of Batch @runtime_checkable class BatchProtocol(Protocol): """The internal data structure in Tianshou. 
Batch is a kind of supercharged array (of temporal data) stored individually in a (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or batches themselves. It is designed to make it extremely easily to access, manipulate and set partial view of the heterogeneous data conveniently. """ @property def shape(self) -> list[int]: raise ProtocolCalledException # NOTE: even though setattr and getattr are defined for any object, we need # to explicitly define them for the BatchProtocol, since otherwise mypy will # complain about new fields being added dynamically. For example, things like # `batch.new_field = ...` followed by using `batch.new_field` become type errors # if getattr and setattr are missing in the BatchProtocol. # # For the moment, tianshou relies on this kind of dynamic-field-addition # in many, many places. In principle, it would be better to construct new # objects with new combinations of fields instead of mutating existing ones - the # latter is error-prone and can't properly be expressed with types. May be in a # future, rather different version of tianshou it would be feasible to have stricter # typing. 
Then the need for Protocols would in fact disappear def __setattr__(self, key: str, value: Any) -> None: raise ProtocolCalledException def __getattr__(self, key: str) -> Any: raise ProtocolCalledException def __iter__(self) -> Iterator[Self]: raise ProtocolCalledException @overload def __getitem__(self, index: str) -> Any: raise ProtocolCalledException @overload def __getitem__(self, index: IndexType) -> Self: raise ProtocolCalledException def __getitem__(self, index: str | IndexType) -> Any: raise ProtocolCalledException def __setitem__(self, index: str | IndexType, value: Any) -> None: raise ProtocolCalledException def __iadd__(self, other: Self | Number | np.number) -> Self: raise ProtocolCalledException def __add__(self, other: Self | Number | np.number) -> Self: raise ProtocolCalledException def __imul__(self, value: Number | np.number) -> Self: raise ProtocolCalledException def __mul__(self, value: Number | np.number) -> Self: raise ProtocolCalledException def __itruediv__(self, value: Number | np.number) -> Self: raise ProtocolCalledException def __truediv__(self, value: Number | np.number) -> Self: raise ProtocolCalledException def __repr__(self) -> str: raise ProtocolCalledException def __eq__(self, other: Any) -> bool: raise ProtocolCalledException def to_numpy(self: Self) -> Self: """Change all torch.Tensor to numpy.ndarray and return a new Batch.""" raise ProtocolCalledException def to_numpy_(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" raise ProtocolCalledException def to_torch( self: Self, dtype: torch.dtype | None = None, device: str | int | torch.device = "cpu", ) -> Self: """Change all numpy.ndarray to torch.Tensor and return a new Batch.""" raise ProtocolCalledException def to_torch_( self, dtype: torch.dtype | None = None, device: str | int | torch.device = "cpu", ) -> None: """Change all numpy.ndarray to torch.Tensor in-place.""" raise ProtocolCalledException def cat_(self, batches: Self | Sequence[dict | Self]) -> 
None: """Concatenate a list of (or one) Batch objects into current batch.""" raise ProtocolCalledException @staticmethod def cat(batches: Sequence[dict | TBatch]) -> TBatch: """Concatenate a list of Batch object into a single new batch. For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros with appropriate shapes. E.g. :: >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5]))) >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5]))) >>> c = Batch.cat([a, b]) >>> c.a.shape (7, 4) >>> c.b.shape (7, 3) >>> c.common.c.shape (7, 5) """ raise ProtocolCalledException def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None: """Stack a list of Batch object into current batch.""" raise ProtocolCalledException @staticmethod def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch: """Stack a list of Batch object into a single new batch. For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros. E.g. :: >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5]))) >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5]))) >>> c = Batch.stack([a, b]) >>> c.a.shape (2, 4, 4) >>> c.b.shape (2, 4, 6) >>> c.common.c.shape (2, 4, 5) .. note:: If there are keys that are not shared across all batches, ``stack`` with ``axis != 0`` is undefined, and will cause an exception. """ raise ProtocolCalledException def empty_(self, index: slice | IndexType | None = None) -> Self: """Return an empty Batch object with 0 or None filled. If "index" is specified, it will only reset the specific indexed-data. 
:: >>> data.empty_() >>> print(data) Batch( a: array([[0., 0.], [0., 0.]]), b: array([None, None], dtype=object), ) >>> b={'c': [2., 'st'], 'd': [1., 0.]} >>> data = Batch(a=[False, True], b=b) >>> data[0] = Batch.empty(data[1]) >>> data Batch( a: array([False, True]), b: Batch( c: array([None, 'st']), d: array([0., 0.]), ), ) """ raise ProtocolCalledException @staticmethod def empty(batch: TBatch, index: IndexType | None = None) -> TBatch: """Return an empty Batch object with 0 or None filled. The shape is the same as the given Batch. """ raise ProtocolCalledException def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: """Update this batch from another dict/Batch.""" raise ProtocolCalledException def __len__(self) -> int: raise ProtocolCalledException def split( self, size: int, shuffle: bool = True, merge_last: bool = False, ) -> Iterator[Self]: """Split whole data into multiple small batches. :param size: divide the data batch with the given size, but one batch if the length of the batch is smaller than "size". Size of -1 means the whole batch. :param shuffle: randomly shuffle the entire data batch if it is True, otherwise remain in the same. Default to True. :param merge_last: merge the last batch into the previous one. Default to False. """ raise ProtocolCalledException def to_dict(self, recurse: bool = True) -> dict[str, Any]: raise ProtocolCalledException def to_list_of_dicts(self) -> list[dict[str, Any]]: raise ProtocolCalledException def get_keys(self) -> KeysView: raise ProtocolCalledException def set_array_at_key( self, seq: np.ndarray, key: str, index: IndexType | None = None, default_value: float | None = None, ) -> None: """Set a sequence of values at a given key. If `index` is not passed, the sequence must have the same length as the batch. :param seq: the array of values to set. :param key: the key to set the sequence at. :param index: the indices to set the sequence at. 
If None, the sequence must have the same length as the batch and will be set at all indices. :param default_value: this only applies if `index` is passed and the key does not exist yet in the batch. In that case, entries outside the passed index will be filled with this default value. Note that the array at the key will be of the same dtype as the passed sequence, so `default_value` should be such that numpy can cast it to this dtype. """ raise ProtocolCalledException def isnull(self) -> Self: """Return a boolean mask of the same shape, indicating missing values.""" raise ProtocolCalledException def hasnull(self) -> bool: """Return whether the batch has missing values.""" raise ProtocolCalledException def dropnull(self) -> Self: """Return a batch where all items in which any value is null are dropped. Note that it is not the same as just dropping the entries of the sequence. For example, with >>> b = Batch(a=[None, 2, 3, 4], b=[4, 5, None, 7]) >>> b.dropnull() will result in >>> Batch(a=[2, 4], b=[5, 7]) This logic is applied recursively to all nested batches. The result is the same as if the batch was flattened, entries were dropped, and then the batch was reshaped back to the original nested structure. """ ... @overload def apply_values_transform( self, values_transform: Callable[[np.ndarray | torch.Tensor], Any], ) -> Self: ... @overload def apply_values_transform( self, values_transform: Callable, inplace: Literal[True], ) -> None: ... @overload def apply_values_transform( self, values_transform: Callable[[np.ndarray | torch.Tensor], Any], inplace: Literal[False], ) -> Self: ... def apply_values_transform( self, values_transform: Callable[[np.ndarray | torch.Tensor], Any], inplace: bool = False, ) -> None | Self: """Apply a function to all arrays in the batch, including nested ones. :param values_transform: the function to apply to the arrays. :param inplace: whether to apply the function in-place. 
If False, a new batch is returned, otherwise the batch is modified in-place and None is returned. """ raise ProtocolCalledException def get(self, key: str, default: Any | None = None) -> Any: raise ProtocolCalledException def pop(self, key: str, default: Any | None = None) -> Any: raise ProtocolCalledException def to_at_least_2d(self) -> Self: """Ensures that all arrays and dists in the batch have at least 2 dimensions. This is useful for ensuring that all arrays in the batch can be concatenated along a new axis. """ raise ProtocolCalledException class Batch(BatchProtocol): """See :class:`~tianshou.data.batch.BatchProtocol`.""" __doc__ = BatchProtocol.__doc__ def __init__( self, batch_dict: dict | BatchProtocol | Sequence[dict | BatchProtocol] | np.ndarray | None = None, copy: bool = False, **kwargs: Any, ) -> None: if copy: batch_dict = deepcopy(batch_dict) if batch_dict is not None: if isinstance(batch_dict, dict | BatchProtocol): _assert_type_keys(batch_dict.keys()) for batch_key, obj in batch_dict.items(): self.__dict__[batch_key] = _parse_value(obj) elif _is_batch_set(batch_dict): batch_dict = cast(Sequence[dict | BatchProtocol], batch_dict) self.stack_(batch_dict) if len(kwargs) > 0: # TODO: that's a rather weird pattern, is it really needed? 
# Feels like kwargs could be just merged into batch_dict in the beginning self.__init__(kwargs, copy=copy) # type: ignore def to_dict(self, recursive: bool = True) -> dict[str, Any]: result = {} for k, v in self.__dict__.items(): if recursive and isinstance(v, Batch): v = v.to_dict(recursive=recursive) result[k] = v return result def get_keys(self) -> KeysView: return self.__dict__.keys() def get(self, key: str, default: Any | None = None) -> Any: return self.__dict__.get(key, default) def pop(self, key: str, default: Any | None = None) -> Any: return self.__dict__.pop(key, default) def to_list_of_dicts(self) -> list[dict[str, Any]]: return [entry.to_dict() for entry in self] def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" self.__dict__[key] = _parse_value(value) def __getattr__(self, key: str) -> Any: """Return self.key. The "Any" return type is needed for mypy.""" return getattr(self.__dict__, key) def __contains__(self, key: str) -> bool: """Return key in self.""" return key in self.__dict__ def __getstate__(self) -> dict[str, Any]: """Pickling interface. Only the actual data are serialized for both efficiency and simplicity. """ state = {} for batch_key, obj in self.items(): if isinstance(obj, Batch): state[batch_key] = obj.__getstate__() else: state[batch_key] = obj return state def __setstate__(self, state: dict[str, Any]) -> None: """Unpickling interface. At this point, self is an empty Batch instance that has not been initialized, so it can safely be initialized by the pickle state. """ self.__init__(**state) # type: ignore @overload def __getitem__(self, index: str) -> Any: ... @overload def __getitem__(self, index: IndexType) -> Self: ... 
def __getitem__(self, index: str | IndexType) -> Any:
    """Return the value stored under a string key, or a new Batch sliced by ``index``.

    For a non-string index, every value is sliced with that same index, except:
    ``None`` values stay ``None``, empty nested Batches become a fresh empty Batch,
    and torch distributions are sliced via ``get_sliced_dist``.

    :raises IndexError: if a non-string index is used on a Batch without keys.
    """
    if isinstance(index, str):
        return self.__dict__[index]
    batch_items = self.items()
    if len(batch_items) > 0:
        new_batch = Batch()

        sliced_obj: Any
        for batch_key, obj in batch_items:
            # None and empty Batches as values are added to any slice
            if obj is None:
                sliced_obj = None
            elif isinstance(obj, Batch) and len(obj.get_keys()) == 0:
                sliced_obj = Batch()
            # We attempt slicing of a distribution. This is hacky, but presents an important special case
            elif isinstance(obj, Distribution):
                sliced_obj = get_sliced_dist(obj, index)
            # All other objects are either array-like or Batch-like, so hopefully sliceable
            # A batch should have no scalars
            else:
                sliced_obj = obj[index]
            new_batch.__dict__[batch_key] = sliced_obj
        return new_batch
    raise IndexError("Cannot access item from empty Batch object.")

def __eq__(self, other: Any) -> bool:
    """Compare two batches of the same class by content.

    Both sides are converted to numpy (torch tensors included), every leaf value is
    made at least 1-dimensional, and the resulting nested dicts are compared with DeepDiff.
    Objects of a different class are never equal.
    """
    if not isinstance(other, self.__class__):
        return False

    this_batch_no_torch_tensor = self.to_numpy()
    other_batch_no_torch_tensor = other.to_numpy()
    # DeepDiff 7.0.1 cannot compare 0-dimensional arrays
    # so, we ensure with this transform that all array values have at least 1 dim
    this_batch_no_torch_tensor.apply_values_transform(
        values_transform=np.atleast_1d,
        inplace=True,
    )
    other_batch_no_torch_tensor.apply_values_transform(
        values_transform=np.atleast_1d,
        inplace=True,
    )
    this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
    other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)

    return not DeepDiff(this_dict, other_dict)

def __iter__(self) -> Iterator[Self]:
    """Yield ``self[0], ..., self[len(self) - 1]``; a Batch without keys yields nothing."""
    # TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
    if len(self.__dict__) == 0:
        yield from []
    else:
        for i in range(len(self)):
            yield self[i]

def __setitem__(self, index: str | IndexType, value: Any) -> None:
    """Assign value to self[index].

    With a string index the value is stored under that key. With any other index,
    ``value`` must be a Batch (or something ``_parse_value`` turns into one) whose keys
    are a subset of this batch's keys. Keys missing from ``value`` are filled at
    ``index`` with an empty Batch, ``0`` (torch tensors and bool/numeric arrays) or ``None``.

    :raises ValueError: if ``value`` is not a Batch, or if it contains keys this batch lacks.
    """
    value = _parse_value(value)
    if isinstance(index, str):
        self.__dict__[index] = value
        return
    if not isinstance(value, Batch):
        raise ValueError(
            "Batch does not support tensor assignment. "
            "Use a compatible Batch or dict instead.",
        )
    if not set(value.keys()).issubset(self.__dict__.keys()):
        raise ValueError("Creating keys is not supported by item assignment.")
    for key, val in self.items():
        try:
            self.__dict__[key][index] = value[key]
        except KeyError:
            # key is absent in `value`: reset the entry at `index` to a type-appropriate blank
            if isinstance(val, Batch):
                self.__dict__[key][index] = Batch()
            elif isinstance(val, torch.Tensor) or (
                isinstance(val, np.ndarray) and issubclass(val.dtype.type, np.bool_ | np.number)
            ):
                self.__dict__[key][index] = 0
            else:
                self.__dict__[key][index] = None

def __iadd__(self, other: Self | Number | np.number) -> Self:
    """Algebraic addition with another Batch instance in-place.

    When ``other`` is a Batch, values are added key by key, so the insertion order of
    the keys does not matter. Keys whose value in ``self`` is an empty Batch are skipped.
    When ``other`` is a number, it is added to every non-empty value.

    :raises ValueError: if ``other`` is a Batch with a different set of top-level keys.
        No value is modified in that case.
    :raises TypeError: if ``other`` is neither a Batch nor a number.
    """
    if isinstance(other, Batch):
        # Match by key, not by position: pairing the two dicts' values in insertion order
        # would silently add unrelated entries when the key order differs.
        if self.__dict__.keys() != other.__dict__.keys():
            raise ValueError(
                "Addition of batches requires both batches to have the same keys, "
                f"got {list(self.__dict__)} and {list(other.__dict__)}.",
            )
        for batch_key, obj in self.__dict__.items():
            if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
                continue
            self.__dict__[batch_key] += other.__dict__[batch_key]
        return self
    if _is_number(other):
        for batch_key, obj in self.items():
            if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
                continue
            self.__dict__[batch_key] += other
        return self
    raise TypeError("Only addition of Batch or number is supported.")

def __add__(self, other: Self | Number | np.number) -> Self:
    """Algebraic addition with another Batch instance out-of-place (on a deep copy)."""
    return deepcopy(self).__iadd__(other)

def __imul__(self, value: Number | np.number) -> Self:
    """Algebraic multiplication with a scalar value in-place; empty nested Batches are skipped."""
    assert _is_number(value), "Only multiplication by a number is supported."
    for batch_key, obj in self.__dict__.items():
        if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
            continue
        self.__dict__[batch_key] *= value
    return self

def __mul__(self, value: Number | np.number) -> Self:
    """Algebraic multiplication with a scalar value out-of-place (on a deep copy)."""
    return deepcopy(self).__imul__(value)

def __itruediv__(self, value: Number | np.number) -> Self:
    """Algebraic division with a scalar value in-place; empty nested Batches are skipped."""
    assert _is_number(value), "Only division by a number is supported."
    for batch_key, obj in self.__dict__.items():
        if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
            continue
        self.__dict__[batch_key] /= value
    return self

def __truediv__(self, value: Number | np.number) -> Self:
    """Algebraic division with a scalar value out-of-place (on a deep copy)."""
    return deepcopy(self).__itruediv__(value)

def __repr__(self) -> str:
    """Return str(self): one ``key: value`` line per entry, or ``ClassName()`` when empty."""
    self_str = self.__class__.__name__ + "(\n"
    flag = False
    for batch_key, obj in self.__dict__.items():
        # continuation lines of a multi-line value are indented by 4 + len(key) + len(": ")
        # so they line up after the "    key: " prefix below
        rpl = "\n" + " " * (6 + len(batch_key))
        obj_name = pprint.pformat(obj).replace("\n", rpl)
        self_str += f"    {batch_key}: {obj_name},\n"
        flag = True
    if flag:
        self_str += ")"
    else:
        self_str = self.__class__.__name__ + "()"
    return self_str

def to_numpy(self: Self) -> Self:
    """Return a deep copy of the batch in which all torch tensors are numpy arrays."""
    result = deepcopy(self)
    result.to_numpy_()
    return result

def to_numpy_(self) -> None:
    """Convert all torch tensors (recursively) to numpy arrays in-place; other values are untouched."""

    def arr_to_numpy(arr: TArr) -> TArr:
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return arr

    self.apply_values_transform(arr_to_numpy, inplace=True)

def to_torch(
    self: Self,
    dtype: torch.dtype | None = None,
    device: str | int | torch.device = "cpu",
) -> Self:
    """Return a deep copy of the batch with arrays converted to torch tensors.

    :param dtype: if given, every converted or existing tensor is cast to this dtype.
    :param device: target device of the tensors.
    """
    result = deepcopy(self)
    result.to_torch_(dtype=dtype, device=device)
    return result

def to_torch_(
    self,
    dtype: torch.dtype | None = None,
    device: str | int | torch.device = "cpu",
) -> None:
    """Convert numpy arrays (recursively) to torch tensors in-place.

    Numpy arrays become tensors on ``device``, cast to ``dtype`` when it is given.
    Existing tensors are cast and/or moved only if their dtype or device differs.
    All other leaf values (e.g. ``None``, Python/numpy scalars, distributions) are left
    unchanged, mirroring ``to_numpy_``.

    :param dtype: if given, the dtype all resulting tensors are cast to.
    :param device: target device of the tensors.
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)

    def arr_to_torch(arr: TArr) -> TArr:
        if isinstance(arr, np.ndarray):
            tensor = torch.from_numpy(arr)
            # honour the requested dtype for numpy input too, not only for existing tensors
            if dtype is not None:
                tensor = tensor.type(dtype)
            return tensor.to(device)
        if not isinstance(arr, torch.Tensor):
            # values without dtype/device (None, scalars, distributions, ...) cannot be moved
            return arr

        # TODO: simplify
        if (
            (dtype is not None and arr.dtype != dtype)
            or arr.device.type != device.type
            or device.index != arr.device.index
        ):
            if dtype is not None:
                arr = arr.type(dtype)
            return arr.to(device)
        return arr

    self.apply_values_transform(arr_to_torch, inplace=True)

def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
    """Private method for Batch.cat_.

    ::

        >>> a = Batch(a=np.random.randn(3, 4))
        >>> x = Batch(a=a, b=np.random.randn(4, 4))
        >>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4))

    If we want to concatenate x and y, we want to pad y.a.a with zeros.
    Without ``lens`` as a hint, when we concatenate x.a and y.a, we would
    not be able to know how to pad y.a. So ``Batch.cat_`` should compute
    the ``lens`` to give ``Batch.__cat`` a hint.
    ::

        >>> ans = Batch.cat([x, y])
        >>> # this is equivalent to the following line
        >>> ans = Batch(); ans.__cat([x, y], lens=[3, 4])
        >>> # this lens is equal to [len(a), len(b)]
    """
    # partial keys will be padded by zeros
    # with the shape of [len, rest_shape]
    # sum_lens[i] is the offset at which batch i starts in the concatenated result
    sum_lens = [0]
    for len_ in lens:
        sum_lens.append(sum_lens[-1] + len_)
    # collect non-empty keys
    keys_map = [
        {
            batch_key
            for batch_key, obj in batch.items()
            if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
        }
        for batch in batches
    ]
    keys_shared = set.intersection(*keys_map)
    values_shared = [[batch[key] for batch in batches] for key in keys_shared]
    for key, shared_value in zip(keys_shared, values_shared, strict=True):
        if all(isinstance(element, dict | Batch) for element in shared_value):
            batch_holder = Batch()
            batch_holder.__cat(shared_value, lens=lens)
            self.__dict__[key] = batch_holder
        elif all(isinstance(element, torch.Tensor) for element in shared_value):
            self.__dict__[key] = torch.cat(shared_value)
        else:
            # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
            # will fail here
            self.__dict__[key] = _to_array_with_correct_type(np.concatenate(shared_value))
    keys_total = set.union(*[set(batch.keys()) for batch in batches])
    keys_reserve_or_partial = set.difference(keys_total, keys_shared)
    # keys that only ever hold empty Batches ("reserved" keys)
    keys_reserve = set.difference(keys_total, set.union(*keys_map))
    # keys that occur only in some batches, but not all
    keys_partial = keys_reserve_or_partial.difference(keys_reserve)
    for key in keys_reserve:
        # reserved keys
        self.__dict__[key] = Batch()
    for key in keys_partial:
        for i, batch in enumerate(batches):
            if key not in batch.__dict__:
                continue
            value = batch.get(key)
            if isinstance(value, Batch) and len(value.get_keys()) == 0:
                continue
            try:
                self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
            except KeyError:
                # first occurrence of this partial key: allocate a zero-padded holder
                # spanning the full concatenated length, then fill this batch's slice
                self.__dict__[key] = create_value(value, sum_lens[-1], stack=False)
                self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value

def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
    """Concatenate ``batches`` (a single Batch/dict or a sequence of them) onto self in-place.

    Dicts are converted to Batches and Batches without keys are ignored. All remaining
    batches, including self when it has a non-zero length, must have the same key
    structure (empty nested Batches count as ``None``).

    :raises ValueError: on a non-Batch/dict element, on mismatched structures, or when a
        batch contains scalars (has no length).
    """
    if isinstance(batches, Batch | dict):
        batches = [batches]
    # check input format
    batch_list = []

    # A batch with all values removed, just keys left. Can be considered a sort of schema.
    # Will be either the schema of self, or of the first non-empty batch in the sequence.
    original_keys_only_batch = None
    if len(self) > 0:
        original_keys_only_batch = self.apply_values_transform(lambda x: None)
        original_keys_only_batch.replace_empty_batches_by_none()

    for batch in batches:
        if isinstance(batch, dict):
            batch = Batch(batch)
        if not isinstance(batch, Batch):
            raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
        if len(batch.get_keys()) == 0:
            continue
        if original_keys_only_batch is None:
            original_keys_only_batch = batch.apply_values_transform(lambda x: None)
            original_keys_only_batch.replace_empty_batches_by_none()
            batch_list.append(batch)
            continue

        cur_keys_only_batch = batch.apply_values_transform(lambda x: None)
        cur_keys_only_batch.replace_empty_batches_by_none()
        if original_keys_only_batch != cur_keys_only_batch:
            raise ValueError(
                f"Batch.cat_ only supports concatenation of batches with the same structure but got "
                f"structures: \n{original_keys_only_batch}\n and\n{cur_keys_only_batch}.",
            )
        batch_list.append(batch)
    if len(batch_list) == 0:
        return

    batches = batch_list

    # TODO: lots of the remaining logic is devoted to filling up remaining keys with zeros
    # this should be removed, and also the check above should be extended to nested keys
    try:
        # len(batch) == 0 here means batch is a nested empty batch
        # like Batch(a=Batch), and we have to treat it as length zero and
        # keep it.
        lens = [0 if len(batch) == 0 else len(batch) for batch in batches]
    except TypeError as exception:
        raise ValueError(
            "Batch.cat_ meets an exception. Maybe because there is any "
            f"scalar in {batches} but Batch.cat_ does not support the "
            "concatenation of scalar.",
        ) from exception
    if len(self.get_keys()) != 0:
        batches = [self, *list(batches)]
        # len of zero means that that item is Batch() and should be ignored
        lens = [0 if len(self) == 0 else len(self), *lens]
    self.__cat(batches, lens)

@staticmethod
def cat(batches: Sequence[dict | TBatch]) -> TBatch:
    """Return a new Batch holding the concatenation of ``batches`` (see ``cat_``)."""
    batch = Batch()
    batch.cat_(batches)
    return batch  # type: ignore

def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None:
    """Stack ``batches`` (and self, if it has keys) along a new ``axis`` in-place.

    Empty dicts/Batches are ignored. Values shared by all batches are stacked with
    ``torch.stack``, recursively with ``Batch.stack``, or with ``np.stack``. If the
    numpy shapes differ, the values are stored in an object array and a warning is emitted.
    Keys present in only some batches are zero-padded via ``create_value`` (axis 0 only).

    :raises ValueError: on a non-Batch/dict element, or on non-shared keys with ``axis != 0``.
    """
    # check input format
    batch_list = []
    for batch in batches:
        if isinstance(batch, dict):
            if len(batch) > 0:
                batch_list.append(Batch(batch))
        elif isinstance(batch, Batch):
            if len(batch.get_keys()) != 0:
                batch_list.append(batch)
        else:
            raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
    if len(batch_list) == 0:
        return
    batches = batch_list
    if len(self.get_keys()) != 0:
        batches = [self, *batches]
    # collect non-empty keys
    keys_map = [
        {
            batch_key
            for batch_key, obj in batch.items()
            if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
        }
        for batch in batches
    ]
    keys_shared = set.intersection(*keys_map)
    values_shared = [[batch[key] for batch in batches] for key in keys_shared]
    for shared_key, value in zip(keys_shared, values_shared, strict=True):
        # second often
        if all(isinstance(element, torch.Tensor) for element in value):
            self.__dict__[shared_key] = torch.stack(value, axis)
        # third often
        elif all(isinstance(element, Batch | dict) for element in value):
            self.__dict__[shared_key] = Batch.stack(value, axis)
        else:  # most often case is np.ndarray
            try:
                self.__dict__[shared_key] = _to_array_with_correct_type(np.stack(value, axis))
            except ValueError:
                warnings.warn(
                    "You are using tensors with different shape,"
                    " fallback to dtype=object by default.",
                )
                self.__dict__[shared_key] = np.array(value, dtype=object)
    # all the keys
    keys_total = set.union(*[set(batch.keys()) for batch in batches])
    # keys that only ever hold empty Batches ("reserved" keys)
    keys_reserve = set.difference(keys_total, set.union(*keys_map))
    # keys that are either partial or reserved
    keys_reserve_or_partial = set.difference(keys_total, keys_shared)
    # keys that occur only in some batches, but not all
    keys_partial = keys_reserve_or_partial.difference(keys_reserve)
    if keys_partial and axis != 0:
        raise ValueError(
            f"Stack of Batch with non-shared keys {keys_partial} is only "
            f"supported with axis=0, but got axis={axis}!",
        )
    for key in keys_reserve:
        # reserved keys
        self.__dict__[key] = Batch()
    for key in keys_partial:
        for i, batch in enumerate(batches):
            if key not in batch.__dict__:
                continue
            # distinct name from the loop variable above, which holds a list of values
            partial_value = batch.get(key)
            if isinstance(partial_value, Batch) and len(partial_value.get_keys()) == 0:
                continue
            try:
                self.__dict__[key][i] = partial_value
            except KeyError:
                # first occurrence of this partial key: allocate a holder with one
                # (zero-padded) slot per stacked batch, then fill slot i
                self.__dict__[key] = create_value(partial_value, len(batches))
                self.__dict__[key][i] = partial_value

@staticmethod
def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch:
    """Return a new Batch holding ``batches`` stacked along ``axis`` (see ``stack_``)."""
    batch = Batch()
    batch.stack_(batches, axis)
    # can't cast to a generic type, so we have to ignore the type here
    return batch  # type: ignore

def empty_(self, index: slice | IndexType | None = None) -> Self:
    """Reset the entries at ``index`` (all entries if ``None``) in-place and return self.

    Tensors and non-object arrays are set to ``0``, object arrays to ``None``, nested
    Batches are processed recursively, and ``None`` values are skipped. Any other (scalar)
    value triggers a warning and is replaced by ``type(value)(0)`` if it is a number,
    otherwise by ``None``.
    """
    for batch_key, obj in self.items():
        if isinstance(obj, torch.Tensor):  # most often case
            self.__dict__[batch_key][index] = 0
        elif obj is None:
            continue
        elif isinstance(obj, np.ndarray):
            if obj.dtype == object:
                self.__dict__[batch_key][index] = None
            else:
                self.__dict__[batch_key][index] = 0
        elif isinstance(obj, Batch):
            self.__dict__[batch_key].empty_(index=index)
        else:  # scalar value
            warnings.warn(
                "You are calling Batch.empty on a NumPy scalar, "
                "which may cause undefined behaviors.",
            )
            if _is_number(obj):
                self.__dict__[batch_key] = obj.__class__(0)
            else:
                self.__dict__[batch_key] = None
    return self

@staticmethod
def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
    """Return a deep copy of ``batch`` with the entries at ``index`` reset (see ``empty_``)."""
    return deepcopy(batch).empty_(index)

def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
    """Set (overwriting) every key of ``batch`` and then of ``kwargs`` on self, parsing each value."""
    if batch is None:
        self.update(kwargs)
        return
    for batch_key, obj in batch.items():
        self.__dict__[batch_key] = _parse_value(obj)
    if kwargs:
        self.update(kwargs)

def __len__(self) -> int:
    """Return the minimum length over all sized values, or 0 if there is none.

    ``None`` values and nested Batches of length zero are ignored; distributions
    contribute ``get_len_of_dist``.

    :raises TypeError: if any value in the batch has no len(), typically meaning
        it's a batch of scalars.
    """
    lens = []
    for key, obj in self.__dict__.items():
        # TODO: causes inconsistent behavior to batch with empty batches
        # and batch with empty sequences of other type. Remove, but only after
        # Buffer and Collectors have been improved to no longer rely on this
        if isinstance(obj, Batch) and len(obj) == 0:
            continue
        if obj is None:
            continue
        if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
            lens.append(len(obj))
            continue
        if isinstance(obj, Distribution):
            lens.append(get_len_of_dist(obj))
            continue
        raise TypeError(f"Entry for {key} in {self} is {obj} has no len()")
    if not lens:
        return 0
    return min(lens)

@property
def shape(self) -> list[int]:
    """Return self.shape: the element-wise minimum over the shapes of all values.

    Values without a ``shape`` attribute count as ``[]``; a Batch without keys has shape ``[]``.
    """
    if len(self.get_keys()) == 0:
        return []
    data_shape = []
    for obj in self.__dict__.values():
        try:
            data_shape.append(list(obj.shape))
        except AttributeError:
            data_shape.append([])
    return (
        list(map(min, zip(*data_shape, strict=False)))
        if len(data_shape) > 1
        else data_shape[0]
    )

def split(
    self,
    size: int,
    shuffle: bool = True,
    merge_last: bool = False,
) -> Iterator[Self]:
    """Yield sub-batches of ``size`` entries each (``size=-1`` means the whole batch).

    :param size: number of entries per sub-batch; may exceed the batch length.
    :param shuffle: if True, entries are taken in a random order drawn from
        ``np.random.permutation``; otherwise in index order.
    :param merge_last: if True and the length is not divisible by ``size``, the
        trailing shorter chunk is merged into the preceding one.
    """
    length = len(self)
    if size == -1:
        size = length
    assert size >= 1  # size can be greater than length, return whole batch
    indices = np.random.permutation(length) if shuffle else np.arange(length)
    merge_last = merge_last and length % size > 0
    for idx in range(0, length, size):
        # stop one chunk early and emit everything left, so no short chunk remains
        if merge_last and idx + size + size >= length:
            yield self[indices[idx:]]
            break
        yield self[indices[idx : idx + size]]

@overload
def apply_values_transform(
    self,
    values_transform: Callable,
) -> Self: ...
@overload def apply_values_transform( self, values_transform: Callable, inplace: Literal[True], ) -> None: ... @overload def apply_values_transform( self, values_transform: Callable, inplace: Literal[False], ) -> Self: ... def apply_values_transform( self, values_transform: Callable, inplace: bool = False, ) -> None | Self: """Applies a function to all non-batch-values in the batch, including values in nested batches. A batch with keys pointing to either batches or to non-batch values can be thought of as a tree of Batch nodes. This function traverses the tree and applies the function to all leaf nodes (i.e. values that are not batches themselves). The values are usually arrays, but can also be scalar values of an arbitrary type since retrieving a single entry from a Batch a la `batch[0]` will return a batch with scalar values. """ return _apply_batch_values_func_recursively(self, values_transform, inplace=inplace) def set_array_at_key( self, arr: np.ndarray, key: str, index: IndexType | None = None, default_value: float | None = None, ) -> None: if index is not None: if key not in self.get_keys(): log.info( f"Key {key} not found in batch, " f"creating a sequence of len {len(self)} with {default_value=} for it.", ) try: self[key] = np.array([default_value] * len(self), dtype=arr.dtype) except TypeError as exception: raise TypeError( f"Cannot create a sequence of dtype {arr.dtype} with default value {default_value}. " f"You can fix this either by passing an array with the correct dtype or by passing " f"a different default value that can be cast to the array's dtype (or both).", ) from exception else: existing_entry = self[key] if isinstance(existing_entry, Batch): raise ValueError( f"Cannot set sequence at key {key} because it is a nested batch, " f"can only set a subsequence of an array.", ) self[key][index] = arr else: if len(arr) != len(self): raise ValueError( f"Sequence length {len(arr)} does not match " f"batch length {len(self)}. 
For setting a subsequence with missing " f"entries filled up by default values, consider passing an index.", ) self[key] = arr def isnull(self) -> Self: return self.apply_values_transform(pd.isnull, inplace=False) def hasnull(self) -> bool: isnan_batch = self.isnull() is_any_null_batch = isnan_batch.apply_values_transform(np.any, inplace=False) def is_any_true(boolean_batch: BatchProtocol) -> bool: for val in boolean_batch.values(): if isinstance(val, Batch): if is_any_true(val): return True else: assert val.size == 1, "This shouldn't have happened, it's a bug!" # an unsized array with a boolean, e.g. np.array(False). behaves like the boolean itself if val: return True return False return is_any_true(is_any_null_batch) def dropnull(self) -> Self: # we need to use dicts since a batch retrieved for a single index has no length and cat fails # TODO: make cat work with batches containing scalars? sub_batches = [] for b in self: if b.hasnull(): continue # needed for cat to work b = b.apply_values_transform(np.atleast_1d) sub_batches.append(b) return Batch.cat(sub_batches) def replace_empty_batches_by_none(self) -> None: """Goes through the batch-tree" recursively and replaces empty batches by None. This is useful for extracting the structure of a batch without the actual data, especially in combination with `apply_values_transform` with a transform function a la `lambda x: None`. """ empty_batch = Batch() for key, val in self.items(): if isinstance(val, Batch): if val == empty_batch: self[key] = None else: val.replace_empty_batches_by_none() def to_at_least_2d(self) -> Self: """Ensures that all arrays and dists in the batch have at least 2 dimensions. This is useful for ensuring that all arrays in the batch can be concatenated along a new axis. 
""" result = self.apply_values_transform(np.atleast_2d, inplace=False) for key, val in self.items(): if isinstance(val, Distribution): result[key] = dist_to_atleast_2d(val) return result def _apply_batch_values_func_recursively( batch: TBatch, values_transform: Callable, inplace: bool = False, ) -> TBatch | None: """Applies the desired function on all values of the batch recursively. See docstring of the corresponding method in the Batch class for more details. """ result = batch if inplace else deepcopy(batch) for key, val in batch.__dict__.items(): if isinstance(val, Batch): result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False) else: result[key] = values_transform(val) if not inplace: return result return None ================================================ FILE: tianshou/data/buffer/__init__.py ================================================ def _backward_compatibility() -> None: import sys from . import buffer_base # backward compatibility with persisted buffers from v1 for determinism tests sys.modules["tianshou.data.buffer.base"] = buffer_base _backward_compatibility() ================================================ FILE: tianshou/data/buffer/buffer_base.py ================================================ from collections.abc import Sequence from typing import Any, ClassVar, Self, TypeVar, cast import h5py import numpy as np from sensai.util.pickle import setstate from tianshou.data import Batch from tianshou.data.batch import ( IndexType, alloc_by_keys_diff, create_value, log, ) from tianshou.data.types import RolloutBatchProtocol from tianshou.data.utils.converter import from_hdf5, to_hdf5 TBuffer = TypeVar("TBuffer", bound="ReplayBuffer") class MalformedBufferError(RuntimeError): pass class ReplayBuffer: """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. 
It stores all the data in a batch with circular-queue style. :param size: the maximum size of replay buffer. :param stack_num: the frame-stack sampling argument, should be greater than or equal to 1. Default to 1 (no stacking). :param ignore_obs_next: whether to not store obs_next. :param save_only_last_obs: only save the last obs/obs_next when it has a shape of (timestep, ...) because of temporal stacking. :param sample_avail: whether to sample only available indices when using the frame-stack sampling method. """ _reserved_keys = ( "obs", "act", "rew", "terminated", "truncated", "done", "obs_next", "info", "policy", ) _input_keys = ( "obs", "act", "rew", "terminated", "truncated", "obs_next", "info", "policy", ) _required_keys_for_add: ClassVar[set[str]] = { "obs", "act", "rew", "terminated", "truncated", "done", } def __init__( self, size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, random_seed: int = 42, **kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError ) -> None: # TODO: why do we need this? Just for readout? self.options: dict[str, Any] = { "stack_num": stack_num, "ignore_obs_next": ignore_obs_next, "save_only_last_obs": save_only_last_obs, "sample_avail": sample_avail, } super().__init__() self.maxsize = int(size) assert stack_num > 0, "stack_num should be greater than 0" self.stack_num = stack_num self._indices = np.arange(size) # TODO: remove double negation and different name self._save_obs_next = not ignore_obs_next self._save_only_last_obs = save_only_last_obs self._sample_avail = sample_avail self._meta = cast(RolloutBatchProtocol, Batch()) self._random_state = np.random.RandomState(random_seed) # Keep in sync with reset! 
self.last_index = np.array([0]) self._insertion_idx = self._size = 0 self._ep_return, self._ep_len, self._ep_start_idx = 0.0, 0, 0 def __setstate__(self, state: dict[str, Any]) -> None: setstate( ReplayBuffer, self, state, new_default_properties={"_random_state": np.random.RandomState(42)}, ) @property def subbuffer_edges(self) -> np.ndarray: """Edges of contained buffers, mostly needed as part of the VectorReplayBuffer interface. For the standard ReplayBuffer it is always [0, maxsize]. Transitions can be added to the buffer indefinitely, and one episode can "go over the edge". Having the edges available is useful for fishing out whole episodes from the buffer and for input validation. """ return np.array([0, self.maxsize], dtype=int) def _get_start_stop_tuples_for_edge_crossing_interval( self, start: int, stop: int, ) -> tuple[tuple[int, int], tuple[int, int]]: """Assumes that stop < start and retrieves tuples corresponding to the two slices that determine the interval within the buffer. Example: ------- >>> list(self.subbuffer_edges) == [0, 5, 10] >>> start = 4 >>> stop = 2 >>> self._get_start_stop_tuples_for_edge_crossing_interval(start, stop) ((4, 5), (0, 2)) The buffer sliced from 4 to 5 and then from 0 to 2 will contain the transitions corresponding to the provided start and stop values. """ if stop >= start: raise ValueError( f"Expected stop < start, but got {start=}, {stop=}. " f"For stop larger-equal than start this method should never be called. This can occur either due to an implementation error, " f"or due a bad configuration of the buffer that resulted in a single episode being so long that " f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). 
" f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the " f"degree of vectorization.", ) subbuffer_edges = cast(Sequence[int], self.subbuffer_edges) edge_after_start_idx = int(np.searchsorted(subbuffer_edges, start, side="left")) """This is the crossed edge""" if edge_after_start_idx == 0: raise ValueError( f"The start value should be larger than the first edge, but got {start=}, {subbuffer_edges[1]=}.", ) edge_after_start = subbuffer_edges[edge_after_start_idx] edge_before_stop = subbuffer_edges[edge_after_start_idx - 1] """It's the edge before the crossed edge""" if edge_before_stop >= stop: raise ValueError( f"The edge before the crossed edge should be smaller than the stop, but got {edge_before_stop=}, {stop=}.", ) return (start, edge_after_start), (edge_before_stop, stop) def get_buffer_indices(self, start: int, stop: int) -> np.ndarray: """Get the indices of the transitions in the buffer between start and stop. The special thing about this is that stop may actually be smaller than start, since one often is interested in a sequence of transitions that goes over a subbuffer edge. The main use case for this method is to retrieve an episode from the buffer, in which case start is the index of the first transition in the episode and stop is the index where `done` is True + 1. This can be done with the following code: .. code-block:: python episode_indices = buffer.get_buffer_indices(episode_start_index, episode_done_index + 1) episode = buffer[episode_indices] Even when `start` is smaller than `stop`, it will be validated that they are in the same subbuffer. Example: -------- >>> list(buffer.subbuffer_edges) == [0, 5, 10] >>> buffer.get_buffer_indices(start=2, stop=4) [2, 3] >>> buffer.get_buffer_indices(start=4, stop=2) [4, 0, 1] >>> buffer.get_buffer_indices(start=8, stop=7) [8, 9, 5, 6] >>> buffer.get_buffer_indices(start=1, stop=6) ValueError: Start and stop indices must be within the same subbuffer. 
>>> buffer.get_buffer_indices(start=8, stop=1) ValueError: Start and stop indices must be within the same subbuffer. :param start: The start index of the interval. :param stop: The stop index of the interval. :return: The indices of the transitions in the buffer between start and stop. """ start_left_edge = np.searchsorted(self.subbuffer_edges, start, side="right") - 1 stop_left_edge = np.searchsorted(self.subbuffer_edges, stop - 1, side="right") - 1 if start_left_edge != stop_left_edge: raise ValueError( f"Start and stop indices must be within the same subbuffer. " f"Got {start=} in subbuffer edge {start_left_edge} and {stop=} in subbuffer edge {stop_left_edge}.", ) if stop >= start: return np.arange(start, stop, dtype=int) else: ( (start, upper_edge), ( lower_edge, stop, ), ) = self._get_start_stop_tuples_for_edge_crossing_interval( start, stop, ) log.debug(f"{start=}, {upper_edge=}, {lower_edge=}, {stop=}") return np.concatenate( ( np.arange(start, upper_edge, dtype=int), np.arange(lower_edge, stop, dtype=int), ), ) def __len__(self) -> int: return self._size def __repr__(self) -> str: wrapped_batch_repr = self._meta.__repr__()[len(self._meta.__class__.__name__) :] return self.__class__.__name__ + wrapped_batch_repr def __getattr__(self, key: str) -> Any: try: return self._meta[key] except KeyError as exception: raise AttributeError from exception def __setattr__(self, key: str, value: Any) -> None: assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned" super().__setattr__(key, value) def save_hdf5(self, path: str, compression: str | None = None) -> None: """Save replay buffer to HDF5 file.""" with h5py.File(path, "w") as f: to_hdf5(self.__dict__, f, compression=compression) @classmethod def load_hdf5(cls, path: str, device: str | None = None) -> Self: """Load replay buffer from HDF5 file.""" with h5py.File(path, "r") as f: buf = cls.__new__(cls) buf.__setstate__(from_hdf5(f, device=device)) # type: ignore return buf 
@classmethod def from_data( cls, obs: h5py.Dataset, act: h5py.Dataset, rew: h5py.Dataset, terminated: h5py.Dataset, truncated: h5py.Dataset, done: h5py.Dataset, obs_next: h5py.Dataset, ) -> Self: size = len(obs) assert all( len(dset) == size for dset in [obs, act, rew, terminated, truncated, done, obs_next] ), "Lengths of all hdf5 datasets need to be equal." buf = cls(size) if size == 0: return buf batch = Batch( obs=obs, act=act, rew=rew, terminated=terminated, truncated=truncated, done=done, obs_next=obs_next, ) batch = cast(RolloutBatchProtocol, batch) buf.set_batch(batch) buf._size = size return buf def reset(self, keep_statistics: bool = False) -> None: """Clear all the data in replay buffer and episode statistics.""" # Keep in sync with init! self.last_index = np.array([0]) self._insertion_idx = self._size = self._ep_start_idx = 0 if not keep_statistics: self._ep_return, self._ep_len = 0.0, 0 # TODO: is this method really necessary? It's kinda dangerous, can accidentally # remove all references to collected data def set_batch(self, batch: RolloutBatchProtocol) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" assert len(batch) == self.maxsize and set(batch.get_keys()).issubset( self._reserved_keys, ), "Input batch doesn't meet ReplayBuffer's data form requirement." self._meta = batch def unfinished_index(self) -> np.ndarray: """Return the index of unfinished episode.""" last = (self._insertion_idx - 1) % self._size if self._size else 0 return np.array([last] if not self.done[last] and self._size else [], int) def prev(self, index: int | np.ndarray) -> np.ndarray: """Return the index of previous transition. The index won't be modified if it is the beginning of an episode. """ index = (index - 1) % self._size end_flag = self.done[index] | (index == self.last_index[0]) return (index + end_flag) % self._size def next(self, index: int | np.ndarray) -> np.ndarray: """Return the index of next transition. 
The index won't be modified if it is the end of an episode. """ end_flag = self.done[index] | (index == self.last_index[0]) return (index + (1 - end_flag)) % self._size def update(self, buffer: "ReplayBuffer") -> np.ndarray: """Move the data from the given buffer to current buffer. Return the updated indices. If update fails, return an empty array. """ if len(buffer) == 0 or self.maxsize == 0: return np.array([], int) stack_num, buffer.stack_num = buffer.stack_num, 1 from_indices = buffer.sample_indices(0) # get all available indices buffer.stack_num = stack_num if len(from_indices) == 0: return np.array([], int) updated_indices = [] for _ in range(len(from_indices)): updated_indices.append(self._insertion_idx) self.last_index[0] = self._insertion_idx self._insertion_idx = (self._insertion_idx + 1) % self.maxsize self._size = min(self._size + 1, self.maxsize) updated_indices = np.array(updated_indices) if len(self._meta.get_keys()) == 0: self._meta = create_value(buffer._meta, self.maxsize, stack=False) # type: ignore self._meta[updated_indices] = buffer._meta[from_indices] return updated_indices def _update_state_pre_add( self, rew: float | np.ndarray, done: bool, ) -> tuple[int, float, int, int]: """Update the buffer's state before adding one data batch. Updates the `_size` and `_insertion_idx`, adds the reward and len internally maintained `_ep_len` and `_ep_return`. If `done` is `True`, will reset `_ep_len` and `_ep_return` to zero, and set `_ep_start_idx` to `_insertion_idx` Returns a tuple with: 0. the index at which to insert the next transition, 1. the episode len (if done=True, otherwise 0) 2. the episode return (if done=True, otherwise 0) 3. the episode start index. 
""" self.last_index[0] = cur_insertion_idx = self._insertion_idx self._size = min(self._size + 1, self.maxsize) self._insertion_idx = (self._insertion_idx + 1) % self.maxsize self._ep_return += rew # type: ignore self._ep_len += 1 if self._ep_start_idx > len(self): raise MalformedBufferError( f"Encountered a starting index {self._ep_start_idx} that is outside " f"the currently available samples {len(self)=}. " f"The buffer is malformed. This might be caused by a bug or by manual modifications of the buffer " f"by users.", ) # return 0 for unfinished episodes if done: ep_return = self._ep_return ep_len = self._ep_len else: if isinstance(self._ep_return, np.ndarray): # type: ignore[unreachable] # TODO: [original remark by MischaPanch] Check whether the entire else case is really correct/necessary. # ep_return should be a scalar but is a numpy array. # This doesn't make sense for a ReplayBuffer, but currently tests of CachedReplayBuffer require # this behavior for some reason; it also occurs in the MARL notebook, for example. # Will return an array of zeros instead of a scalar zero. pass ep_return = np.zeros_like(self._ep_return) # type: ignore ep_len = 0 result = cur_insertion_idx, ep_return, ep_len, self._ep_start_idx if done: # prepare for next episode collection # set return and len to zero, set start idx to next insertion idx self._ep_return, self._ep_len, self._ep_start_idx = ( 0.0, 0, self._insertion_idx, ) return result def add( self, batch: RolloutBatchProtocol, buffer_ids: np.ndarray | list[int] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into replay buffer. :param batch: the input data batch. "obs", "act", "rew", "terminated", "truncated" are required keys. :param buffer_ids: id's of subbuffers, allowed here to be consistent with classes similar to :class:`~tianshou.data.buffer.vecbuf.VectorReplayBuffer`. 
Since the `ReplayBuffer` has a single subbuffer, if this is not None, it must be a single element with value 0. In that case, the batch is expected to have the shape (1, len(data)). Failure to adhere to this will result in a `ValueError`. Return `(current_index, episode_return, episode_length, episode_start_index)`. If the episode is not finished, the return value of episode_length and episode_reward is 0. """ # preprocess and copy batch into a new Batch object to avoid mutating the input # TODO: can't we just copy? Why do we need to rely on setting inside __dict__? new_batch = Batch() for key in batch.get_keys(): new_batch.__dict__[key] = batch[key] batch = new_batch batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) # has to be done after preprocess batch if not self._required_keys_for_add.issubset( batch.get_keys(), ): raise ValueError( f"Input batch must have the following keys: {self._required_keys_for_add}", ) batch_is_stacked = False """True when instead of passing a batch of shape (len(data)), a batch of shape (1, len(data)) is passed.""" if buffer_ids is not None: if len(buffer_ids) != 1 and buffer_ids[0] != 0: raise ValueError( "If `buffer_ids` is not None, it must be a single element with value 0 for the non-vectorized `ReplayBuffer`. 
" f"Got {buffer_ids=}.", ) if len(batch) != 1: raise ValueError( f"If `buffer_ids` is not None, the batch must have the shape (1, len(data)) but got {len(batch)=}.", ) batch_is_stacked = True # block dealing with exotic options that are currently only used for atari, see various TODOs about that # These options have interactions with the case when buffer_ids is not None if self._save_only_last_obs: batch.obs = batch.obs[:, -1] if batch_is_stacked else batch.obs[-1] if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: batch.obs_next = batch.obs_next[:, -1] if batch_is_stacked else batch.obs_next[-1] if batch_is_stacked: rew, done = batch.rew[0], batch.done[0] else: rew, done = batch.rew, batch.done insertion_idx, ep_return, ep_len, ep_start_idx = ( np.array([x]) for x in self._update_state_pre_add(rew, done) ) # TODO: improve this, don'r rely on try-except, instead process the batch if needed try: self._meta[insertion_idx] = batch except ValueError: stack = not batch_is_stacked batch.rew = batch.rew.astype(float) batch.done = batch.done.astype(bool) batch.terminated = batch.terminated.astype(bool) batch.truncated = batch.truncated.astype(bool) if len(self._meta.get_keys()) == 0: self._meta = create_value(batch, self.maxsize, stack) # type: ignore else: # dynamic key pops up in batch alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) self._meta[insertion_idx] = batch return insertion_idx, ep_return, ep_len, ep_start_idx def sample_indices(self, batch_size: int | None) -> np.ndarray: """Get a random sample of index with size = batch_size. Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled. :param batch_size: the number of indices to be sampled. If None, it will be set to the length of the buffer (i.e. return all available indices in a random order). 
""" if batch_size is None: batch_size = len(self) if self.stack_num == 1 or not self._sample_avail: # most often case if batch_size > 0: return self._random_state.choice(self._size, batch_size) # TODO: is this behavior really desired? if batch_size == 0: # construct current available indices return np.concatenate( [ np.arange(self._insertion_idx, self._size), np.arange(self._insertion_idx), ], ) return np.array([], int) # TODO: raise error on negative batch_size instead? if batch_size < 0: return np.array([], int) # TODO: simplify this code - shouldn't have such a large if-else # with many returns for handling different stack nums. # It is also not clear whether this is really necessary - frame stacking usually is handled # by environment wrappers (e.g. FrameStack) and not by the replay buffer. all_indices = prev_indices = np.concatenate( [ np.arange(self._insertion_idx, self._size), np.arange(self._insertion_idx), ], ) for _ in range(self.stack_num - 2): prev_indices = self.prev(prev_indices) all_indices = all_indices[prev_indices != self.prev(prev_indices)] if batch_size > 0: return self._random_state.choice(all_indices, batch_size) return all_indices def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]: """Get a random sample from buffer with size = batch_size. Return all the data in the buffer if batch_size is 0. :return: Sample data and its corresponding index inside the buffer. """ indices = self.sample_indices(batch_size) return self[indices], indices def get( self, index: int | list[int] | np.ndarray, key: str, default_value: Any = None, # TODO 1: this is only here because of atari, it should never be needed (can be solved with index) # and should be removed # TODO 2: does something entirely different from getitem # TODO 3: key should not be required stack_num: int | None = None, ) -> Batch | np.ndarray: """Return the stacked result. 
E.g., if you set ``key = "obs", stack_num = 4, index = t``, it returns the stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``. :param index: the index for getting stacked data. :param key: the key to get, should be one of the reserved_keys. :param default_value: if the given key's data is not found and default_value is set, return this default_value. :param stack_num: Default to self.stack_num. """ if key not in self._meta.get_keys() and default_value is not None: return default_value val = self._meta[key] if stack_num is None: stack_num = self.stack_num try: if stack_num == 1: # the most common case return val[index] stack = list[Any]() indices = np.array(index) if isinstance(index, list) else index # NOTE: stack_num > 1, so the range is not empty and indices is turned into # np.ndarray by self.prev for _ in range(stack_num): stack = [val[indices], *stack] indices = self.prev(indices) indices = cast(np.ndarray, indices) if isinstance(val, Batch): return Batch.stack(stack, axis=indices.ndim) return np.stack(stack, axis=indices.ndim) except IndexError as exception: if not (isinstance(val, Batch) and len(val.keys()) == 0): raise exception # val != Batch() return Batch() def __getitem__(self, index: IndexType) -> RolloutBatchProtocol: """Return a data batch: self[index]. If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, ...). """ # TODO: this is a seriously problematic hack leading to # buffer[slice] != buffer[np.arange(slice.start, slice.stop)] # Fix asap, high priority!!! 
if isinstance(index, slice): # change slice to np array # buffer[:] will get all available data indices = ( self.sample_indices(0) if index == slice(None) else self._indices[: len(self)][index] ) else: indices = index # type: ignore # raise KeyError first instead of AttributeError, # to support np.array([ReplayBuffer()]) obs = self.get(indices, "obs") if self._save_obs_next: obs_next = self.get(indices, "obs_next", Batch()) else: obs_next_indices = self.next(indices) obs_next = self.get(obs_next_indices, "obs", Batch()) # TODO: don't do this batch_dict = { "obs": obs, "act": self.act[indices], "rew": self.rew[indices], "terminated": self.terminated[indices], "truncated": self.truncated[indices], "done": self.done[indices], "obs_next": obs_next, "info": self.get(indices, "info", Batch()), # TODO: what's the use of this key? "policy": self.get(indices, "policy", Batch()), } # TODO: don't do this, reduce complexity. Why such a big difference between what is returned # and sub-batches of self._meta? 
missing_keys = set(self._meta.get_keys()) - set(self._input_keys) for key in missing_keys: batch_dict[key] = self._meta[key][indices] return cast(RolloutBatchProtocol, Batch(batch_dict)) def set_array_at_key( self, seq: np.ndarray, key: str, index: IndexType | None = None, default_value: float | None = None, ) -> None: self._meta.set_array_at_key(seq, key, index, default_value) def hasnull(self) -> bool: return self[:].hasnull() def isnull(self) -> RolloutBatchProtocol: return self[:].isnull() def dropnull(self) -> None: # TODO: may fail, needs more testing with VectorBuffers self._meta = self._meta.dropnull() self._size = len(self._meta) self._insertion_idx = len(self._meta) ================================================ FILE: tianshou/data/buffer/cached.py ================================================ import numpy as np from tianshou.data import ReplayBuffer, ReplayBufferManager from tianshou.data.types import RolloutBatchProtocol class CachedReplayBuffer(ReplayBufferManager): """CachedReplayBuffer contains a given main buffer and n cached buffers, ``cached_buffer_num * ReplayBuffer(size=max_episode_length)``. The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1] |``. The data is first stored in cached buffers. When an episode is terminated, the data will move to the main buffer and the corresponding cached buffer will be reset. :param main_buffer: the main buffer whose ``.update()`` function behaves normally. :param cached_buffer_num: number of ReplayBuffer needs to be created for cached buffer. :param max_episode_length: the maximum length of one episode, used in each cached buffer's maxsize. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" def __init__( self, main_buffer: ReplayBuffer, cached_buffer_num: int, max_episode_length: int, ) -> None: assert cached_buffer_num > 0 assert max_episode_length > 0 assert isinstance(main_buffer, ReplayBuffer) kwargs = main_buffer.options buffers = [main_buffer] + [ ReplayBuffer(max_episode_length, **kwargs) for _ in range(cached_buffer_num) ] super().__init__(buffer_list=buffers) self.main_buffer = self.buffers[0] self.cached_buffers = self.buffers[1:] self.cached_buffer_num = cached_buffer_num def add( self, batch: RolloutBatchProtocol, buffer_ids: np.ndarray | list[int] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. Each of the data's length (first dimension) must equal to the length of buffer_ids. By default the buffer_ids is [0, 1, ..., cached_buffer_num - 1]. Return (current_index, episode_reward, episode_length, episode_start_index) with each of the shape (len(buffer_ids), ...), where (current_index[i], episode_reward[i], episode_length[i], episode_start_index[i]) refers to the cached_buffer_ids[i]th cached buffer's corresponding episode result. 
""" if buffer_ids is None: cached_buffer_ids = np.arange(1, 1 + self.cached_buffer_num) else: # make sure it is np.ndarray, +1 means it's never the main buffer cached_buffer_ids = np.asarray(buffer_ids) + 1 insertion_idx, ep_return, ep_len, ep_start_idx = super().add( batch, buffer_ids=cached_buffer_ids, ) # find the terminated episode, move data from cached buf to main buf updated_insertion_idx, updated_ep_start_idx = [], [] done = np.logical_or(batch.terminated, batch.truncated) for buffer_idx in cached_buffer_ids[done]: index = self.main_buffer.update(self.buffers[buffer_idx]) if len(index) == 0: # unsuccessful move, replace with -1 index = [-1] updated_ep_start_idx.append(index[0]) updated_insertion_idx.append(index[-1]) self.buffers[buffer_idx].reset() self._lengths[0] = len(self.main_buffer) self._lengths[buffer_idx] = 0 self.last_index[0] = index[-1] self.last_index[buffer_idx] = self._offset[buffer_idx] insertion_idx[done] = updated_insertion_idx ep_start_idx[done] = updated_ep_start_idx return insertion_idx, ep_return, ep_len, ep_start_idx ================================================ FILE: tianshou/data/buffer/her.py ================================================ from collections.abc import Callable from typing import Any, Union, cast import numpy as np from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol from tianshou.data.types import RolloutBatchProtocol class HERReplayBuffer(ReplayBuffer): """Implementation of Hindsight Experience Replay. arXiv:1707.01495. HERReplayBuffer is to be used with goal-based environment where the observation is a dictionary with keys ``observation``, ``achieved_goal`` and ``desired_goal``. Currently support only HER's future strategy, online sampling. :param size: the size of the replay buffer. :param compute_reward_fn: a function that takes 2 ``np.array`` arguments, ``acheived_goal`` and ``desired_goal``, and returns rewards as ``np.array``. 
The two arguments are of shape (batch_size, ...original_shape) and the returned rewards must be of shape (batch_size,). :param horizon: the maximum number of steps in an episode. :param future_k: the 'k' parameter introduced in the paper. In short, there will be at most k episodes that are re-written for every 1 unaltered episode during the sampling. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ def __init__( self, size: int, compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], horizon: int, future_k: float = 8.0, **kwargs: Any, ) -> None: super().__init__(size, **kwargs) self.horizon = horizon self.future_p = 1 - 1 / future_k self.compute_reward_fn = compute_reward_fn self._original_meta = Batch() self._altered_indices = np.array([]) def _restore_cache(self) -> None: """Write cached original meta back to `self._meta`. It's called everytime before 'writing', 'sampling' or 'saving' the buffer. """ if not hasattr(self, "_altered_indices"): return if self._altered_indices.size == 0: return self._meta[self._altered_indices] = self._original_meta # Clean self._original_meta = Batch() self._altered_indices = np.array([]) def reset(self, keep_statistics: bool = False) -> None: self._restore_cache() return super().reset(keep_statistics) def save_hdf5(self, path: str, compression: str | None = None) -> None: self._restore_cache() return super().save_hdf5(path, compression) def set_batch(self, batch: RolloutBatchProtocol) -> None: self._restore_cache() return super().set_batch(batch) def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: self._restore_cache() return super().update(buffer) def add( self, batch: RolloutBatchProtocol, buffer_ids: np.ndarray | list[int] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: self._restore_cache() return super().add(batch, buffer_ids) def sample_indices(self, batch_size: int | None) -> np.ndarray: """Get a random sample of 
index with size = batch_size. Return all available indices in the buffer if batch_size is 0; return an \ empty numpy array if batch_size < 0 or no available index can be sampled. \ Additionally, some episodes of the sampled transitions will be re-written \ according to HER. """ self._restore_cache() indices = super().sample_indices(batch_size=batch_size) self.rewrite_transitions(indices.copy()) return indices def rewrite_transitions(self, indices: np.ndarray) -> None: """Re-write the goal of some sampled transitions' episodes according to HER. Currently applies only HER's 'future' strategy. The new goals will be written \ directly to the internal batch data temporarily and will be restored right \ before the next sampling or when using some of the buffer's method (e.g. \ `add`, `save_hdf5`, etc.). This is to make sure that n-step returns \ calculation etc., performs correctly without additional alteration. """ if indices.size == 0: return # Sort indices keeping chronological order indices[indices < self._insertion_idx] += self.maxsize indices = np.sort(indices) indices[indices >= self.maxsize] -= self.maxsize # Construct episode trajectories indices = [indices] for _ in range(self.horizon - 1): indices.append(self.next(indices[-1])) indices = np.stack(indices) # Calculate future timestep to use current = indices[0] terminal = indices[-1] episodes_len = (terminal - current + self.maxsize) % self.maxsize future_offset = np.random.uniform(size=len(indices[0])) * episodes_len future_offset = np.round(future_offset).astype(int) future_t = (current + future_offset) % self.maxsize # Compute indices # open indices are used to find longest, unique trajectories among # presented episodes unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1]) unique_ep_indices = indices[:, unique_ep_open_indices] # close indices are used to find max future_t among presented episodes unique_ep_close_indices = np.hstack([(unique_ep_open_indices - 1)[1:], len(terminal) - 1]) 
# episode indices that will be altered her_ep_indices = np.random.choice( len(unique_ep_open_indices), size=int(len(unique_ep_open_indices) * self.future_p), replace=False, ) # Cache original meta self._altered_indices = unique_ep_indices.copy() self._original_meta = self._meta[self._altered_indices].copy() # Copy original obs, ep_rew (and obs_next), and obs of future time step ep_obs = self[unique_ep_indices].obs # to satisfy mypy # TODO: add protocol covering these batches assert isinstance(ep_obs, Batch) ep_rew = self[unique_ep_indices].rew if self._save_obs_next: ep_obs_next = self[unique_ep_indices].obs_next # to satisfy mypy assert isinstance(ep_obs_next, Batch) future_obs = self[future_t[unique_ep_close_indices]].obs_next else: future_obs = self[self.next(future_t[unique_ep_close_indices])].obs future_obs = cast(BatchProtocol, future_obs) # Re-assign goals and rewards via broadcast assignment ep_obs.desired_goal[:, her_ep_indices] = future_obs.achieved_goal[None, her_ep_indices] if self._save_obs_next: ep_obs_next = cast(BatchProtocol, ep_obs_next) ep_obs_next.desired_goal[:, her_ep_indices] = future_obs.achieved_goal[ None, her_ep_indices, ] ep_rew[:, her_ep_indices] = self._compute_reward(ep_obs_next)[:, her_ep_indices] else: tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs assert isinstance(tmp_ep_obs_next, Batch) ep_rew[:, her_ep_indices] = self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices] # Sanity check assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape assert ep_rew.shape == unique_ep_indices.shape # Re-write meta assert isinstance(self._meta.obs, Batch) self._meta.obs[unique_ep_indices] = ep_obs if self._save_obs_next: self._meta.obs_next[unique_ep_indices] = ep_obs_next # type: ignore self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32) def _compute_reward(self, obs: BatchProtocol, lead_dims: int = 2) -> np.ndarray: lead_shape = 
obs.observation.shape[:lead_dims] g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:]) ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:]) rewards = self.compute_reward_fn(ag, g) return rewards.reshape(*lead_shape, *rewards.shape[1:]) ================================================ FILE: tianshou/data/buffer/manager.py ================================================ from collections.abc import Sequence from typing import Union, cast import numpy as np from numba import njit from overrides import override from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer from tianshou.data.batch import alloc_by_keys_diff, create_value from tianshou.data.types import RolloutBatchProtocol class ReplayBufferManager(ReplayBuffer): """ReplayBufferManager contains a list of ReplayBuffer with exactly the same configuration. These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory. :param buffer_list: a list of ReplayBuffer needed to be handled. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ def __init__(self, buffer_list: list[ReplayBuffer] | list[HERReplayBuffer]) -> None: self.buffer_num = len(buffer_list) self.buffers = np.array(buffer_list, dtype=object) last_index: list[int] = [] offset, size = [], 0 buffer_type = type(self.buffers[0]) kwargs = self.buffers[0].options for buf in self.buffers: buf = cast(ReplayBuffer, buf) assert len(buf._meta.get_keys()) == 0 assert isinstance(buf, buffer_type) assert buf.options == kwargs offset.append(size) if len(buf.last_index) != 1: raise ValueError( f"{self.__class__.__name__} only supports buffers with a single index " f"(non-vector buffers), but got {last_index=}. 
" f"Did you try to use a {self.__class__.__name__} within a {self.__class__.__name__}?", ) last_index.append(size + buf.last_index[0]) size += buf.maxsize super().__init__(size=size, **kwargs) self._offset = np.array(offset) self._extend_offset = np.array([*offset, size]) self._lengths = np.zeros_like(offset) self.last_index = np.array(last_index) self._compile() self._meta: RolloutBatchProtocol @property @override def subbuffer_edges(self) -> np.ndarray: return self._extend_offset def _compile(self) -> None: lens = last = index = np.array([0]) offset = np.array([0, 1]) done = np.array([False, False]) _prev_index(index, offset, done, last, lens) _next_index(index, offset, done, last, lens) def __len__(self) -> int: return int(self._lengths.sum()) def reset(self, keep_statistics: bool = False) -> None: # keep in sync with init! self.last_index = self._offset.copy() self._lengths = np.zeros_like(self._offset) for buf in self.buffers: buf.reset(keep_statistics=keep_statistics) def _set_batch_for_children(self) -> None: for offset, buf in zip(self._offset, self.buffers, strict=True): buf.set_batch(self._meta[offset : offset + buf.maxsize]) def set_batch(self, batch: RolloutBatchProtocol) -> None: super().set_batch(batch) self._set_batch_for_children() def unfinished_index(self) -> np.ndarray: return np.concatenate( [ buf.unfinished_index() + offset for offset, buf in zip(self._offset, self.buffers, strict=True) ], ) def prev(self, index: int | np.ndarray) -> np.ndarray: if isinstance(index, list | np.ndarray): return _prev_index( np.asarray(index), self._extend_offset, self.done, self.last_index, self._lengths, ) return _prev_index( np.array([index]), self._extend_offset, self.done, self.last_index, self._lengths, )[0] def next(self, index: int | np.ndarray) -> np.ndarray: if isinstance(index, list | np.ndarray): return _next_index( np.asarray(index), self._extend_offset, self.done, self.last_index, self._lengths, ) return _next_index( np.array([index]), 
self._extend_offset, self.done, self.last_index, self._lengths, )[0] def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" raise NotImplementedError def add( self, batch: RolloutBatchProtocol, buffer_ids: np.ndarray | list[int] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. Each of the data's length (first dimension) must equal to the length of buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, the return value of episode_length and episode_reward is 0. """ # preprocess batch new_batch = Batch() for key in set(self._reserved_keys).intersection(batch.get_keys()): new_batch.__dict__[key] = batch[key] batch = new_batch batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.get_keys()) if self._save_only_last_obs: batch.obs = batch.obs[:, -1] if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: batch.obs_next = batch.obs_next[:, -1] # get index if buffer_ids is None: buffer_ids = np.arange(self.buffer_num) insertion_indxS, ep_lens, ep_returns, ep_idxs = [], [], [], [] for batch_idx, buffer_id in enumerate(buffer_ids): # TODO: don't access private method! 
insertion_index, ep_return, ep_len, ep_start_idx = self.buffers[ buffer_id ]._update_state_pre_add( batch.rew[batch_idx], batch.done[batch_idx], ) offset_insertion_idx = insertion_index + self._offset[buffer_id] offset_ep_start_idx = ep_start_idx + self._offset[buffer_id] insertion_indxS.append(offset_insertion_idx) ep_lens.append(ep_len) ep_returns.append(ep_return) ep_idxs.append(offset_ep_start_idx) self.last_index[buffer_id] = insertion_index + self._offset[buffer_id] self._lengths[buffer_id] = len(self.buffers[buffer_id]) insertion_indxS = np.array(insertion_indxS) try: self._meta[insertion_indxS] = batch # TODO: don't do this! except ValueError: batch.rew = batch.rew.astype(float) batch.done = batch.done.astype(bool) batch.terminated = batch.terminated.astype(bool) batch.truncated = batch.truncated.astype(bool) if len(self._meta.get_keys()) == 0: self._meta = create_value(batch, self.maxsize, stack=False) # type: ignore else: # dynamic key pops up in batch alloc_by_keys_diff(self._meta, batch, self.maxsize, False) self._set_batch_for_children() self._meta[insertion_indxS] = batch return ( insertion_indxS, np.array(ep_returns), np.array(ep_lens), np.array(ep_idxs), ) def sample_indices(self, batch_size: int | None) -> np.ndarray: # TODO: simplify this code if batch_size is not None and batch_size < 0: # TODO: raise error instead? 
return np.array([], int) if self._sample_avail and self.stack_num > 1: all_indices = np.concatenate( [ buf.sample_indices(0) + offset for offset, buf in zip(self._offset, self.buffers, strict=True) ], ) if batch_size == 0: return all_indices if batch_size is None: batch_size = len(all_indices) return self._random_state.choice(all_indices, batch_size) if batch_size == 0 or batch_size is None: # get all available indices sample_num = np.zeros(self.buffer_num, int) else: buffer_idx = self._random_state.choice( self.buffer_num, batch_size, p=self._lengths / self._lengths.sum(), ) sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) # avoid batch_size > 0 and sample_num == 0 -> get child's all data sample_num[sample_num == 0] = -1 return np.concatenate( [ buf.sample_indices(int(bsz)) + offset for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True) ], ) # TODO: unintuitively, the order of inheritance has to stay this way for tests to pass # As also described in the todo below, this is a bad design and should be refactored class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): """PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with exactly the same configuration. These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory. :param buffer_list: a list of PrioritizedReplayBuffer needed to be handled. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: ReplayBufferManager.__init__(self, buffer_list) # type: ignore # last_index = copy(self.last_index) kwargs = buffer_list[0].options last_index_from_buffer_manager = self.last_index for buf in buffer_list: del buf.weight PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) # TODO: the line below is needed since we now set the last_index of the manager in init # (previously it was only set in reset), and it clashes with multiple inheritance # Initializing the ReplayBufferManager after the PrioritizedReplayBuffer would be a better solution, # but it currently leads to infinite recursion. This kind of multiple inheritance with overlapping # interfaces is evil and we should get rid of it self.last_index = last_index_from_buffer_manager class HERReplayBufferManager(ReplayBufferManager): """HERReplayBufferManager contains a list of HERReplayBuffer with exactly the same configuration. These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory. :param buffer_list: a list of HERReplayBuffer needed to be handled. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" def __init__(self, buffer_list: list[HERReplayBuffer]) -> None: super().__init__(buffer_list) def _restore_cache(self) -> None: for buf in self.buffers: buf._restore_cache() def save_hdf5(self, path: str, compression: str | None = None) -> None: self._restore_cache() return super().save_hdf5(path, compression) def set_batch(self, batch: RolloutBatchProtocol) -> None: self._restore_cache() return super().set_batch(batch) def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: self._restore_cache() return super().update(buffer) def add( self, batch: RolloutBatchProtocol, buffer_ids: np.ndarray | list[int] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: self._restore_cache() return super().add(batch, buffer_ids) @njit def _prev_index( index: np.ndarray, offset: np.ndarray, done: np.ndarray, last_index: np.ndarray, lengths: np.ndarray, ) -> np.ndarray: index = index % offset[-1] prev_index = np.zeros_like(index) # disable B905 until strict=True in zip is implemented in numba # https://github.com/numba/numba/issues/8943 for start, end, cur_len, last in zip( # noqa: B905 offset[:-1], offset[1:], lengths, last_index, ): mask = (start <= index) & (index < end) correct_cur_len = max(1, cur_len) if np.sum(mask) > 0: subind = index[mask] subind = (subind - start - 1) % correct_cur_len end_flag = done[subind + start] | (subind + start == last) prev_index[mask] = (subind + end_flag) % correct_cur_len + start return prev_index @njit def _next_index( index: np.ndarray, offset: np.ndarray, done: np.ndarray, last_index: np.ndarray, lengths: np.ndarray, ) -> np.ndarray: index = index % offset[-1] next_index = np.zeros_like(index) # disable B905 until strict=True in zip is implemented in numba # https://github.com/numba/numba/issues/8943 for start, end, cur_len, last in zip( # noqa: B905 offset[:-1], offset[1:], lengths, last_index, ): mask = (start <= index) & (index < end) correct_cur_len = max(1, cur_len) if np.sum(mask) > 
0: subind = index[mask] end_flag = done[subind] | (subind == last) next_index[mask] = (subind - start + 1 - end_flag) % correct_cur_len + start return next_index ================================================ FILE: tianshou/data/buffer/prio.py ================================================ from collections.abc import Sequence from typing import Any, cast import numpy as np import torch from tianshou.data import ReplayBuffer, SegmentTree, to_numpy from tianshou.data.batch import IndexType from tianshou.data.types import PrioBatchProtocol, RolloutBatchProtocol class PrioritizedReplayBuffer(ReplayBuffer): """Implementation of Prioritized Experience Replay. arXiv:1511.05952. :param alpha: the prioritization exponent. :param beta: the importance sample soft coefficient. :param weight_norm: whether to normalize returned weights with the maximum weight value within the batch. Default to True. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ def __init__( self, size: int, alpha: float, beta: float, weight_norm: bool = True, **kwargs: Any, ) -> None: # will raise KeyError in PrioritizedVectorReplayBuffer # super().__init__(size, **kwargs) ReplayBuffer.__init__(self, size, **kwargs) assert alpha > 0.0 assert beta >= 0.0 self._alpha, self._beta = alpha, beta self._max_prio = self._min_prio = 1.0 # save weight directly in this class instead of self._meta self.weight = SegmentTree(size) self.__eps = np.finfo(np.float32).eps.item() self.options.update(alpha=alpha, beta=beta) self._weight_norm = weight_norm def init_weight(self, index: int | np.ndarray) -> None: self.weight[index] = self._max_prio**self._alpha def update(self, buffer: ReplayBuffer) -> np.ndarray: indices = super().update(buffer) self.init_weight(indices) return indices def add( self, batch: RolloutBatchProtocol, buffer_ids: np.ndarray | list[int] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ptr, ep_rew, ep_len, ep_idx = 
super().add(batch, buffer_ids) self.init_weight(ptr) return ptr, ep_rew, ep_len, ep_idx def sample_indices(self, batch_size: int | None) -> np.ndarray: if batch_size is not None and batch_size > 0 and len(self) > 0: scalar = np.random.rand(batch_size) * self.weight.reduce() return self.weight.get_prefix_sum_idx(scalar) # type: ignore return super().sample_indices(batch_size) def get_weight(self, index: int | np.ndarray) -> float | np.ndarray: """Get the importance sampling weight. The "weight" in the returned Batch is the weight on loss function to debias the sampling process (some transition tuples are sampled more often so their losses are weighted less). """ # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) return (self.weight[index] / self._min_prio) ** (-self._beta) def update_weight(self, index: np.ndarray, new_weight: np.ndarray | torch.Tensor) -> None: """Update priority weight by index in this buffer. :param np.ndarray index: index you want to update weight. :param np.ndarray new_weight: new priority weight you want to update. 
""" weight = np.abs(to_numpy(new_weight)) + self.__eps self.weight[index] = weight**self._alpha self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) def __getitem__(self, index: IndexType) -> PrioBatchProtocol: indices: Sequence[int] | np.ndarray if isinstance(index, slice): # change slice to np array # buffer[:] will get all available data indices = ( self.sample_indices(0) if index == slice(None) else self._indices[: len(self)][index] ) else: indices = cast(np.ndarray, index) batch = super().__getitem__(indices) weight = self.get_weight(indices) # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154 batch.weight = weight / np.max(weight) if self._weight_norm else weight return cast(PrioBatchProtocol, batch) def sample(self, batch_size: int | None) -> tuple[PrioBatchProtocol, np.ndarray]: return cast(tuple[PrioBatchProtocol, np.ndarray], super().sample(batch_size=batch_size)) def set_beta(self, beta: float) -> None: self._beta = beta ================================================ FILE: tianshou/data/buffer/vecbuf.py ================================================ from typing import Any import numpy as np from tianshou.data import ( HERReplayBuffer, HERReplayBufferManager, PrioritizedReplayBuffer, PrioritizedReplayBufferManager, ReplayBuffer, ReplayBufferManager, ) class VectorReplayBuffer(ReplayBufferManager): """VectorReplayBuffer contains n ReplayBuffer with the same size. It is used for storing transition from different environments yet keeping the order of time. :param total_size: the total size of VectorReplayBuffer. :param buffer_num: the number of ReplayBuffer it uses, which are under the same configuration. Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail) are the same as :class:`~tianshou.data.ReplayBuffer`. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)] super().__init__(buffer_list) class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size. It is used for storing transition from different environments yet keeping the order of time. :param total_size: the total size of PrioritizedVectorReplayBuffer. :param buffer_num: the number of PrioritizedReplayBuffer it uses, which are under the same configuration. Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/ sample_avail) are the same as :class:`~tianshou.data.PrioritizedReplayBuffer`. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) buffer_list = [PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)] super().__init__(buffer_list) def set_beta(self, beta: float) -> None: for buffer in self.buffers: buffer.set_beta(beta) class HERVectorReplayBuffer(HERReplayBufferManager): """HERVectorReplayBuffer contains n HERReplayBuffer with same size. It is used for storing transition from different environments yet keeping the order of time. :param total_size: the total size of HERVectorReplayBuffer. :param buffer_num: the number of HERReplayBuffer it uses, which are under the same configuration. Other input arguments are the same as :class:`~tianshou.data.HERReplayBuffer`. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)] super().__init__(buffer_list) ================================================ FILE: tianshou/data/collector.py ================================================ import logging import time import warnings from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass, field from typing import Any, Generic, Optional, Protocol, Self, TypedDict, TypeVar, cast import gymnasium as gym import numpy as np import torch from overrides import override from torch.distributions import Categorical, Distribution from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.config import ENABLE_VALIDATION from tianshou.data import ( Batch, CachedReplayBuffer, ReplayBuffer, ReplayBufferManager, SequenceSummaryStats, VectorReplayBuffer, to_numpy, ) from tianshou.data.buffer.buffer_base import MalformedBufferError from tianshou.data.stats import compute_dim_to_summary_stats from tianshou.data.types import ( ActBatchProtocol, DistBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.utils.determinism import TraceLogger from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode log = logging.getLogger(__name__) DEFAULT_BUFFER_MAXSIZE = int(1e4) _TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") TScalarArrayShape = TypeVar("TScalarArrayShape") class CollectActionBatchProtocol(Protocol): """A protocol for results of computing actions from a batch of observations within a single collect step. All fields all have length R (the dist is a Distribution of batch size R), where R is the number of ready envs. 
""" act: np.ndarray | torch.Tensor act_normalized: np.ndarray | torch.Tensor policy_entry: Batch dist: Distribution | None hidden_state: np.ndarray | torch.Tensor | Batch | None class CollectStepBatchProtocol(RolloutBatchProtocol): """A batch of steps collected from a single collect step from multiple envs in parallel. All fields have length R (the dist is a Distribution of batch size R), where R is the number of ready envs. This is essentially the response of the vectorized environment to making a step with :class:`CollectActionBatchProtocol`. """ dist: Distribution | None class EpisodeBatchProtocol(RolloutBatchProtocol): """Marker interface for a batch containing a single episode. Instances are created by retrieving an episode from the buffer when the :class:`Collector` encounters `done=True`. """ def get_stddev_from_dist(dist: Distribution) -> torch.Tensor: """Return the standard deviation of the given distribution. Same as `dist.stddev` for all distributions except `Categorical`, where it is computed by assuming that the output values 0, ..., K have the corresponding numerical meaning. See `here `_ for a discussion on `stddev` and `mean` of `Categorical`. 
""" if isinstance(dist, Categorical): # torch doesn't implement stddev for Categorical, so we compute it ourselves probs = torch.atleast_2d(dist.probs) n_actions = probs.shape[-1] possible_actions = torch.arange(n_actions, device=dist.probs.device).float() mean = torch.sum(probs * possible_actions, dim=1) var = torch.sum(probs * (possible_actions - mean.unsqueeze(1)) ** 2, dim=1) stddev = torch.sqrt(var) if len(dist.batch_shape) == 0: return stddev return torch.atleast_2d(stddev).T return dist.stddev if dist is not None else torch.tensor([]) @dataclass(kw_only=True) class CollectStatsBase(DataclassPPrintMixin): """The most basic stats, often used for offline learning.""" n_collected_episodes: int = 0 """The number of collected episodes.""" n_collected_steps: int = 0 """The number of collected steps.""" @dataclass(kw_only=True) class CollectStats(CollectStatsBase): """A data structure for storing the statistics of rollouts. Custom stats collection logic can be implemented by subclassing this class and overriding the `update_` methods. Ideally, it is instantiated once with correct values and then never modified. However, during the collection process instances of modified using the `update_` methods. Then the arrays and their corresponding `_stats` fields may become out of sync (we don't update the stats after each update for performance reasons, only at the end of the collection). The same for the `collect_time` and `collect_speed`. In the `Collector`, :meth:`refresh_sequence_stats` and :meth:`set_collect_time` are is called at the end of the collection to update the stats. But for other use cases, the users should keep in mind to call this method manually if using the `update_` methods. 
""" collect_time: float = 0.0 """The time for collecting transitions.""" collect_speed: float = 0.0 """The speed of collecting (env_step per second).""" returns: np.ndarray = field(default_factory=lambda: np.array([], dtype=float)) """The collected episode returns.""" returns_stat: SequenceSummaryStats | None = None """Stats of the collected returns.""" lens: np.ndarray = field(default_factory=lambda: np.array([], dtype=int)) """The collected episode lengths.""" lens_stat: SequenceSummaryStats | None = None """Stats of the collected episode lengths.""" pred_dist_std_array: np.ndarray | None = None """The standard deviations of the predicted distributions.""" pred_dist_std_array_stat: dict[int, SequenceSummaryStats] | None = None """Stats of the standard deviations of the predicted distributions (maps action dim to stats)""" @classmethod def with_autogenerated_stats( cls, returns: np.ndarray, lens: np.ndarray, n_collected_episodes: int = 0, n_collected_steps: int = 0, collect_time: float = 0.0, collect_speed: float = 0.0, ) -> Self: """Return a new instance with the stats autogenerated from the given lists.""" returns_stat = SequenceSummaryStats.from_sequence(returns) if returns.size > 0 else None lens_stat = SequenceSummaryStats.from_sequence(lens) if lens.size > 0 else None return cls( n_collected_episodes=n_collected_episodes, n_collected_steps=n_collected_steps, collect_time=collect_time, collect_speed=collect_speed, returns=returns, returns_stat=returns_stat, lens=np.array(lens, int), lens_stat=lens_stat, ) def update_at_step_batch( self, step_batch: CollectStepBatchProtocol, refresh_sequence_stats: bool = False, ) -> None: self.n_collected_steps += len(step_batch) dist = step_batch.dist action_std: torch.Tensor | None = None if dist is not None: action_std = np.atleast_2d(to_numpy(get_stddev_from_dist(dist))) if self.pred_dist_std_array is None: self.pred_dist_std_array = np.atleast_2d(to_numpy(action_std)) else: self.pred_dist_std_array = np.concatenate( 
(self.pred_dist_std_array, np.atleast_2d(to_numpy(action_std))), ) if refresh_sequence_stats: self.refresh_std_array_stats() def update_at_episode_done( self, episode_batch: EpisodeBatchProtocol, # NOTE: in the MARL setting this is not actually a float but rather an array or list, see todo below episode_return: float, refresh_sequence_stats: bool = False, ) -> None: self.lens = np.concatenate((self.lens, [len(episode_batch)]), dtype=int) # type: ignore self.n_collected_episodes += 1 if self.returns.size == 0: # TODO: needed for non-1dim arrays returns that happen in the MARL setting # There are multiple places that assume the returns to be 1dim, so this is a hack # Since MARL support is currently not a priority, we should either raise an error or # implement proper support for it. At the moment tests like `test_collector_with_multi_agent` fail # when assuming 1d returns self.returns = np.array([episode_return], dtype=float) else: self.returns = np.concatenate((self.returns, [episode_return]), dtype=float) # type: ignore if refresh_sequence_stats: self.refresh_return_stats() self.refresh_len_stats() def set_collect_time(self, collect_time: float, update_collect_speed: bool = True) -> None: if collect_time < 0: raise ValueError(f"Collect time should be non-negative, but got {collect_time=}.") self.collect_time = collect_time if update_collect_speed: if collect_time == 0: log.error( "Collect time is 0, setting collect speed to 0. 
Did you make a rounding error?", ) self.collect_speed = 0.0 else: self.collect_speed = self.n_collected_steps / collect_time def refresh_return_stats(self) -> None: if self.returns.size > 0: self.returns_stat = SequenceSummaryStats.from_sequence(self.returns) else: self.returns_stat = None def refresh_len_stats(self) -> None: if self.lens.size > 0: self.lens_stat = SequenceSummaryStats.from_sequence(self.lens) else: self.lens_stat = None def refresh_std_array_stats(self) -> None: if self.pred_dist_std_array is not None and self.pred_dist_std_array.size > 0: # need to use .T because action dim supposed to be the first axis in compute_dim_to_summary_stats self.pred_dist_std_array_stat = compute_dim_to_summary_stats(self.pred_dist_std_array.T) else: self.pred_dist_std_array_stat = None def refresh_all_sequence_stats(self) -> None: self.refresh_return_stats() self.refresh_len_stats() self.refresh_std_array_stats() TCollectStats = TypeVar("TCollectStats", bound=CollectStats) def _nullable_slice(obj: _TArrLike, indices: np.ndarray) -> _TArrLike: """Return None, or the values at the given indices if the object is not None.""" if obj is not None: return obj[indices] # type: ignore[index, return-value] return None # type: ignore[unreachable] def _dict_of_arr_to_arr_of_dicts( dict_of_arr: dict[str, np.ndarray | dict], ) -> np.ndarray: return np.array(Batch(dict_of_arr).to_list_of_dicts()) def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: """TODO: this exists because of multiple bugs in Batch and to restore backwards compatibility. Batch should be fixed and this function should be removed asap!. 
""" if info_array.dtype != np.dtype("O"): raise ValueError( f"Expected info_array to have dtype=object, but got {info_array.dtype}.", ) truthy_info_indices = info_array.nonzero()[0] falsy_info_indices = set(range(len(info_array))) - set(truthy_info_indices) falsy_info_indices = np.array(list(falsy_info_indices), dtype=int) if len(falsy_info_indices) == len(info_array): return Batch() some_nonempty_info = None for info in info_array: if info: some_nonempty_info = info break info_array = copy(info_array) info_array[falsy_info_indices] = some_nonempty_info result_batch_parent = Batch(info=info_array) result_batch_parent.info[falsy_info_indices] = {} return result_batch_parent.info class BaseCollector(Generic[TCollectStats], ABC): """Used to collect data from a vector environment into a buffer using a given policy. .. note:: Please make sure the given environment has a time limitation if using `n_episode` collect option. .. note:: In past versions of Tianshou, the replay buffer passed to `__init__` was automatically reset. This is not done in the current implementation. """ def __init__( self, policy: Policy | Algorithm, env: BaseVectorEnv | gym.Env, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, # The typing is correct, there's a bug in mypy, see https://github.com/python/mypy/issues/3737 collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] raise_on_nan_in_buffer: bool = ENABLE_VALIDATION, ) -> None: """ :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch of actions from a batch of observations. :param env: a ``gymnasium.Env`` environment or a vectorized instance of the :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv` will be constructed internally from the passed env) :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. 
If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` of size :data:`DEFAULT_BUFFER_MAXSIZE` * (number of envs) as the default buffer. :param exploration_noise: determine whether the action needs to be modified with the corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. the rollout batch with this hook also modifies the data that is collected to the buffer! :param raise_on_nan_in_buffer: whether to raise a `RuntimeError` if NaNs are found in the buffer after a collection step. Especially useful when episode-level hooks are passed for making sure that nothing is broken during the collection. Consider setting to False if the NaN-check becomes a bottleneck. :param collect_stats_class: the class to use for collecting statistics. Allows customizing the stats collection logic by passing a subclass of :class:`CollectStats`. Changing this is rarely necessary and is mainly done by "power users". 
""" if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy env = DummyVectorEnv([lambda: env]) # type: ignore if buffer is None: buffer = VectorReplayBuffer(DEFAULT_BUFFER_MAXSIZE * len(env), len(env)) self.buffer: ReplayBuffer | ReplayBufferManager = buffer self.raise_on_nan_in_buffer = raise_on_nan_in_buffer self.policy = policy.policy if isinstance(policy, Algorithm) else policy self.env = cast(BaseVectorEnv, env) self.exploration_noise = exploration_noise self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 self._action_space = self.env.action_space self._is_closed = False self._validate_buffer() self.collect_stats_class = collect_stats_class def _validate_buffer(self) -> None: buf = self.buffer # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager. # We should probably rename the manager if isinstance(buf, ReplayBufferManager) and buf.buffer_num < self.env_num: raise ValueError( f"Buffer has only {buf.buffer_num} buffers, but at least {self.env_num=} are needed.", ) if isinstance(buf, CachedReplayBuffer) and buf.cached_buffer_num < self.env_num: raise ValueError( f"Buffer has only {buf.cached_buffer_num} cached buffers, but at least {self.env_num=} are needed.", ) # Non-VectorReplayBuffer. TODO: probably shouldn't rely on isinstance if not isinstance(buf, ReplayBufferManager): if buf.maxsize == 0: raise ValueError("Buffer maxsize should be greater than 0.") if self.env_num > 1: raise ValueError( f"Cannot use {type(buf).__name__} to collect from multiple envs ({self.env_num=}). 
" f"Please use the corresponding VectorReplayBuffer instead.", ) @property def env_num(self) -> int: return len(self.env) @property def action_space(self) -> gym.spaces.Space: return self._action_space def close(self) -> None: """Close the collector and the environment.""" self.env.close() self._is_closed = True def reset( self, reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Reset the environment, statistics, and data needed to start the collection. :param reset_buffer: if true, reset the replay buffer attached to the collector. :param reset_stats: if true, reset the statistics attached to the collector. :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) :return: The initial observation and info from the environment. """ obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs) if reset_buffer: self.reset_buffer() if reset_stats: self.reset_stat() self._is_closed = False return obs_NO, info_N def reset_stat(self) -> None: """Reset the statistic variables.""" self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 def reset_buffer(self, keep_statistics: bool = False) -> None: """Reset the data buffer.""" self.buffer.reset(keep_statistics=keep_statistics) def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Reset the environments and the initial obs, info, and hidden state of the collector.""" gym_reset_kwargs = gym_reset_kwargs or {} obs_NO, info_N = self.env.reset(**gym_reset_kwargs) # TODO: hack, wrap envpool envs such that they don't return a dict if isinstance(info_N, dict): # type: ignore[unreachable] # this can happen if the env is an envpool env. 
Then the thing returned by reset is a dict # with array entries instead of an array of dicts # We use Batch to turn it into an array of dicts info_N = _dict_of_arr_to_arr_of_dicts(info_N) # type: ignore[unreachable] return obs_NO, info_N @abstractmethod def _collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, gym_reset_kwargs: dict[str, Any] | None = None, ) -> TCollectStats: pass def collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> TCollectStats: """Collect the specified number of steps or episodes to the buffer. .. note:: One and only one collection specification is permitted, either ``n_step`` or ``n_episode``. To ensure an unbiased sampling result with the `n_episode` option, this function will first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` episodes, they will be collected evenly from each env. :param n_step: how many steps to collect. :param n_episode: how many episodes to collect. :param random: whether to sample randomly from the action space instead of using the policy for collecting data. :param render: the sleep time between rendering consecutive frames. :param reset_before_collect: whether to reset the environment before collecting data. (The collector needs the initial `obs` and `info` to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Only used if reset_before_collect is True. 
:return: The collected stats """ # check that exactly one of n_step or n_episode is set and that the other is larger than 0 self._validate_n_step_n_episode(n_episode, n_step) if reset_before_collect: self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) pre_collect_time = time.time() with torch_train_mode(self.policy, enabled=False): collect_stats = self._collect( n_step=n_step, n_episode=n_episode, random=random, render=render, gym_reset_kwargs=gym_reset_kwargs, ) collect_time = time.time() - pre_collect_time collect_stats.set_collect_time(collect_time, update_collect_speed=True) collect_stats.refresh_all_sequence_stats() if self.raise_on_nan_in_buffer and self.buffer.hasnull(): nan_batch = self.buffer.isnull().apply_values_transform(np.sum) raise MalformedBufferError( "NaN detected in the buffer. You can drop them with `buffer.dropnull()`. " f"This error is most often caused by an incorrect use of {EpisodeRolloutHook.__name__}" "together with the `n_steps` (instead of `n_episodes`) option, or by " f"an incorrect implementation of {StepHook.__name__}." 
"Here an overview of the number of NaNs per field: \n" f"{nan_batch}", ) return collect_stats def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None: if not n_step and not n_episode: raise ValueError( f"Only one of n_step and n_episode should be set to a value larger than zero " f"but got {n_step=}, {n_episode=}.", ) if n_step is None and n_episode is None: raise ValueError( "Exactly one of n_step and n_episode should be set but got None for both.", ) if n_step and n_step % self.env_num != 0: warnings.warn( f"{n_step=} is not a multiple of ({self.env_num=}), " "which may cause extra transitions being collected into the buffer.", ) if n_episode and self.env_num > n_episode: warnings.warn( f"{n_episode=} should be larger than {self.env_num=} to " f"collect at least one trajectory in each environment.", ) class Collector(BaseCollector[TCollectStats], Generic[TCollectStats]): """Collects transitions from a vectorized env by computing and applying actions batch-wise.""" # NAMING CONVENTION (mostly suffixes): # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, # the corresponding env is either reset or removed from the ready envs. # N - number of envs, always fixed and >= R. # R - number ready env ids. Note that this might change when envs get idle. # This can only happen in n_episode case, see explanation in the corresponding block. # For n_step, we always use all envs to collect the data, while for n_episode, # R will be at most n_episode at the beginning, but can decrease during the collection. # O - dimension(s) of observations # A - dimension(s) of actions # H - dimension(s) of hidden state # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. # S - number of surplus envs, i.e., envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. 
# local_index - selecting from the locally available environments. In more details: # Each env is associated to a number in [0,..., N-1]. At any moment there are R ready envs, # but they are not necessarily equal to [0, ..., R-1]. Let the R corresponding indices be # [r_0, ..., r_(R-1)] (each r_i is in [0, ... N-1]). If the local index is # [0, 1, 2], it means that we want to select envs [r_0, r_1, r_2]. # We will usually select from the ready envs by slicing like `ready_env_idx_R[local_index]` # global_index - the index in [0, ..., N-1]. Slicing a `_R` index by a local_index produces the # corresponding global index. In the example above: # 1. _R index is [r_0, ..., r_(R-1)] # 2. local_index is [0, 1, 2] # 3. global_index is [r_0, r_1, r_2] and can be used to select from an array of length N # def __init__( self, policy: Policy | Algorithm, env: gym.Env | BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, on_episode_done_hook: Optional["EpisodeRolloutHookProtocol"] = None, on_step_hook: Optional["StepHookProtocol"] = None, raise_on_nan_in_buffer: bool = ENABLE_VALIDATION, collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: """ :param policy: a tianshou policy or algorithm :param env: a ``gymnasium.Env`` environment or a vectorized instance of the :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv` will be constructed internally from the passed env) :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` of size :data:`DEFAULT_BUFFER_MAXSIZE` * (number of envs) as the default buffer. :param exploration_noise: determine whether the action needs to be modified with the corresponding policy's exploration noise. If so, "policy. 
exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. :param on_episode_done_hook: if passed will be executed when an episode is done. The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else). If a dict is returned by the hook it will be used to add new entries to the buffer for the episode that just ended. The values of the dict should be arrays with floats of the same length as the input rollout batch. Note that multiple hooks can be combined using :class:`EpisodeRolloutHookMerged`. A typical example of a hook is :class:`EpisodeRolloutHookMCReturn` which adds the Monte Carlo return as a field to the buffer. Care must be taken when using such hook, as for unfinished episodes one can easily end up with NaNs in the buffer. It is recommended to use the hooks only with the `n_episode` option in `collect`, or to strip the buffer of NaNs after the collection. :param on_step_hook: if passed will be executed after each step of the collection but before the resulting rollout batch is added to the buffer. The inputs to the hook will be the action distributions computed from the previous observations (following the :class:`CollectActionBatchProtocol`) using the policy, and the resulting rollout batch (following the :class:`RolloutBatchProtocol`). **Note** that modifying the rollout batch with this hook also modifies the data that is collected to the buffer! :param raise_on_nan_in_buffer: whether to raise a `RuntimeError` if NaNs are found in the buffer after a collection step. Especially useful when episode-level hooks are passed for making sure that nothing is broken during the collection. Consider setting to False if the NaN-check becomes a bottleneck. :param collect_stats_class: the class to use for collecting statistics. Allows customizing the stats collection logic by passing a subclass of :class:`CollectStats`. Changing this is rarely necessary and is mainly done by "power users". 
""" super().__init__( policy, env, buffer, exploration_noise=exploration_noise, collect_stats_class=collect_stats_class, raise_on_nan_in_buffer=raise_on_nan_in_buffer, ) self._pre_collect_obs_RO: np.ndarray | None = None self._pre_collect_info_R: np.ndarray | None = None self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None self._is_closed = False self._on_episode_done_hook = on_episode_done_hook self._on_step_hook = on_step_hook self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 def set_on_episode_done_hook(self, hook: Optional["EpisodeRolloutHookProtocol"]) -> None: self._on_episode_done_hook = hook def set_on_step_hook(self, hook: Optional["StepHookProtocol"]) -> None: self._on_step_hook = hook def get_on_episode_done_hook(self) -> Optional["EpisodeRolloutHookProtocol"]: return self._on_episode_done_hook def get_on_step_hook(self) -> Optional["StepHookProtocol"]: return self._on_step_hook def close(self) -> None: super().close() self._pre_collect_obs_RO = None self._pre_collect_info_R = None def run_on_episode_done( self, episode_batch: EpisodeBatchProtocol, ) -> dict[str, np.ndarray] | None: """Executes the `on_episode_done_hook` that was passed on init. One of the main uses of this public method is to allow users to override it in custom subclasses of :class:`Collector`. This way, they can override the init to no longer accept the `on_episode_done` provider. """ if self._on_episode_done_hook is not None: return self._on_episode_done_hook(episode_batch) return None def run_on_step_hook( self, action_batch: CollectActionBatchProtocol, rollout_batch: RolloutBatchProtocol, ) -> None: """Executes the instance's `on_step_hook`. One of the main uses of this public method is to allow users to override it in custom subclasses of the `Collector`. This way, they can override the init to no longer accept the `on_step_hook` provider. 
""" if self._on_step_hook is not None: self._on_step_hook(action_batch, rollout_batch) def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Reset the environments and the initial obs, info, and hidden state of the collector.""" obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs) # We assume that R = N when reset is called. # TODO: there is currently no mechanism that ensures this and it's a public method! self._pre_collect_obs_RO = obs_NO self._pre_collect_info_R = info_N self._pre_collect_hidden_state_RH = None return obs_NO, info_N def _compute_action_policy_hidden( self, random: bool, ready_env_ids_R: np.ndarray, last_obs_RO: np.ndarray, last_info_R: np.ndarray, last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, ) -> CollectActionBatchProtocol: """Returns the action, the normalized action, a "policy" entry, and the hidden state.""" if random: try: act_normalized_RA = np.array( [self._action_space[i].sample() for i in ready_env_ids_R], ) # TODO: test whether envpool env explicitly except TypeError: # envpool's action space is not for per-env act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) policy_R = Batch() hidden_state_RH = None # TODO: instead use a (uniform) Distribution instance that corresponds to sampling from action_space action_dist_R = None else: info_batch = _HACKY_create_info_batch(last_info_R) obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) act_batch_RA: ActBatchProtocol | DistBatchProtocol = self.policy( obs_batch_R, last_hidden_state_RH, ) act_RA = to_numpy(act_batch_RA.act) if self.exploration_noise: act_RA = self.policy.add_exploration_noise(act_RA, obs_batch_R) act_normalized_RA = self.policy.map_action(act_RA) # TODO: cleanup the whole policy in batch thing # todo policy_R can also be none, check policy_R = 
act_batch_RA.get("policy", Batch()) if not isinstance(policy_R, Batch): raise RuntimeError( f"The policy result should be a {Batch}, but got {type(policy_R)}", ) hidden_state_RH = act_batch_RA.get("state", None) # TODO: do we need the conditional? Would be better to just add hidden_state which could be None if hidden_state_RH is not None: policy_R.hidden_state = ( hidden_state_RH # save state into buffer through policy attr ) # can't use act_batch_RA.dist directly as act_batch_RA might not have that attribute action_dist_R = act_batch_RA.get("dist") return cast( CollectActionBatchProtocol, Batch( act=act_RA, act_normalized=act_normalized_RA, policy_entry=policy_R, dist=action_dist_R, hidden_state=hidden_state_RH, ), ) # TODO: reduce complexity, remove the noqa def _collect( # noqa: C901 self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, gym_reset_kwargs: dict[str, Any] | None = None, ) -> TCollectStats: """This method is currently very complex, but it's difficult to break it down into smaller chunks. Please read the block-comment of the class to understand the notation in the implementation. It does the collection by executing the following logic: 0. Keep track of n_step and n_episode for being able to stop the collection. 1. Create a CollectStats instance to store the statistics of the collection. 2. Compute actions (with policy or sampling from action space) for the R currently active envs. 3. Perform a step in these R envs. 4. Perform on-step hook on the result 5. Update the CollectStats (using `update_at_step_batch`) and the internal counters after the step 6. Add the resulting R transitions to the buffer 7. Find the D envs that reached done in the current iteration 8. Reset the envs that reached done 9. Extract episodes for the envs that reached done from the buffer 10. Perform on-episode-done hook. If it has a return, modify the transitions belonging to the episodes inside the buffer inplace 11. 
Update the CollectStats instance with the episodes from 9. by using `update_on_episode_done` 12. Prepare next step in while loop by saving the last observations and infos 13. Remove S surplus envs from collection mechanism, thereby reducing R to R-S, to increase performance 14. Update instance-level collection counters (contrary to counters with a lifetime of the collect execution) 15. Prepare for the next call of collect (save last observations and info to collector state) You can search for Step to find where it happens """ # TODO: can't do it init since AsyncCollector is currently a subclass of Collector if self.env.is_async: raise ValueError( f"Please use AsyncCollector for asynchronous environments. " f"Env class: {self.env.__class__.__name__}.", ) ready_env_ids_R: np.ndarray[Any, np.dtype[np.signedinteger]] """provides a mapping from local indices (indexing within `1, ..., R` where `R` is the number of ready envs) to global ones (indexing within `1, ..., num_envs`). So the entry i in this array is the global index of the i-th ready env.""" if n_step is not None: ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: if self.env_num > n_episode: log.warning( f"Number of episodes ({n_episode}) is smaller than the number of environments " f"({self.env_num}). This means that {self.env_num - n_episode} " f"environments (or, equivalently, parallel workers) will not be used!", ) ready_env_ids_R = np.arange(min(self.env_num, n_episode)) else: raise RuntimeError("Input validation failed, this is a bug and shouldn't have happened") if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: raise ValueError( "Initial obs and info should not be None. 
" "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.", ) # Step 0 # get the first obs to be the current obs in the n_step case as # episodes as a new call to collect does not restart trajectories # (which we also really don't want) step_count = 0 num_collected_episodes = 0 episode_returns: list[float] = [] episode_lens: list[int] = [] episode_start_indices: list[int] = [] # Step 1 collect_stats = self.collect_stats_class() # in case we select fewer episodes than envs, we run only some of them last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R) last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R) last_hidden_state_RH = _nullable_slice( self._pre_collect_hidden_state_RH, ready_env_ids_R, ) while True: # todo check if we need this when using cur_rollout_batch # if len(cur_rollout_batch) != len(ready_env_ids): # raise RuntimeError( # f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids" # f"{len(ready_env_ids)}. This should not happen and could be a bug!", # ) # restore the state: if the last state is None, it won't store # Step 2 # get the next action and related stats from the previous observation collect_action_computation_batch_R = self._compute_action_policy_hidden( random=random, ready_env_ids_R=ready_env_ids_R, last_obs_RO=last_obs_RO, last_info_R=last_info_R, last_hidden_state_RH=last_hidden_state_RH, ) TraceLogger.log(log, lambda: f"Action: {collect_action_computation_batch_R.act}") # Step 3 obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( collect_action_computation_batch_R.act_normalized, ready_env_ids_R, ) if isinstance(info_R, dict): # type: ignore[unreachable] # This can happen if the env is an envpool env. 
Then the info returned by step is a dict info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] done_R = np.logical_or(terminated_R, truncated_R) current_step_batch_R = cast( CollectStepBatchProtocol, Batch( obs=last_obs_RO, dist=collect_action_computation_batch_R.dist, act=collect_action_computation_batch_R.act, policy=collect_action_computation_batch_R.policy_entry, obs_next=obs_next_RO, rew=rew_R, terminated=terminated_R, truncated=truncated_R, done=done_R, info=info_R, ), ) # TODO: only makes sense if render_mode is human. # Also, doubtful whether it makes sense at all for true vectorized envs if render: self.env.render() if not np.isclose(render, 0): time.sleep(render) # Step 4 self.run_on_step_hook( collect_action_computation_batch_R, current_step_batch_R, ) # Step 5, collect statistics collect_stats.update_at_step_batch(current_step_batch_R) num_episodes_done_this_iter = np.sum(done_R) num_collected_episodes += num_episodes_done_this_iter step_count += len(ready_env_ids_R) # Step 6 # add data into the buffer. Since the buffer is essentially an array, we don't want # to add the dist. One should not have arrays of dists but rather a single, batch-wise dist. 
# Tianshou already implements slicing of dists, but we don't yet implement merging multiple # dists into one, which would be necessary to make a buffer with dists work properly batch_to_add_R = copy(current_step_batch_R) batch_to_add_R.pop("dist") batch_to_add_R = cast(RolloutBatchProtocol, batch_to_add_R) insertion_idx_R, ep_return_R, ep_len_R, ep_start_idx_R = self.buffer.add( batch_to_add_R, buffer_ids=ready_env_ids_R, ) # preparing for the next iteration # obs_next, info and hidden_state will be modified inplace in the code below, # so we copy to not affect the data in the buffer last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) last_hidden_state_RH = copy(collect_action_computation_batch_R.hidden_state) # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration # Resetting envs that reached done, or removing some of them from the collection if needed (see below) if num_episodes_done_this_iter > 0: # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays # D - number of envs that reached done in the rollout above # local_idx - see block comment on class level # Step 7 env_done_local_idx_D = np.where(done_R)[0] """Indexes which episodes are done within the ready envs, so it can be used for selecting from `..._R` arrays. Stands in contrast to the "global" index, which counts within all envs and is unsuitable for selecting from `..._R` arrays.""" episode_lens_D = ep_len_R[env_done_local_idx_D] episode_returns_D = ep_return_R[env_done_local_idx_D] episode_start_indices_D = ep_start_idx_R[env_done_local_idx_D] episode_lens.extend(episode_lens_D) episode_returns.extend(episode_returns_D) episode_start_indices.extend(episode_start_indices_D) # Step 8 # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
gym_reset_kwargs = gym_reset_kwargs or {} # The index env_done_idx_D was based on 0, ..., R # However, each env has an index in the context of the vectorized env and buffer. So the env 0 being done means # that some env of the corresponding "global" index was done. The mapping between "local" index in # 0,...,R and this global index is maintained by the ready_env_ids_R array. # See the class block comment for more details env_done_global_idx_D = ready_env_ids_R[env_done_local_idx_D] """Indexes which episodes are done within all envs, i.e., within the index `1, ..., num_envs`. It can be used to communicate with the vector env, where env ids are selected from this "global" index. Is not suited for selecting from the ready envs (`..._R` arrays), use the local counterpart instead. """ obs_reset_DO, info_reset_D = self.env.reset( env_id=env_done_global_idx_D, **gym_reset_kwargs, ) # Set the hidden state to zero or None for the envs that reached done # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of # this complex logic self._reset_hidden_state_based_on_type(env_done_local_idx_D, last_hidden_state_RH) # Step 9 # execute episode hooks for those envs which emitted 'done' for local_done_idx, cur_ep_return in zip( env_done_local_idx_D, episode_returns_D, strict=True, ): # retrieve the episode batch from the buffer using the episode start and stop indices ep_start_idx, ep_stop_idx = ( int(ep_start_idx_R[local_done_idx]), int(insertion_idx_R[local_done_idx] + 1), ) ep_index_array = self.buffer.get_buffer_indices(ep_start_idx, ep_stop_idx) ep_batch = cast(EpisodeBatchProtocol, self.buffer[ep_index_array]) # Step 10 episode_hook_additions = self.run_on_episode_done(ep_batch) if episode_hook_additions is not None: if n_episode is None: raise ValueError( "An on_episode_done_hook with non-empty returns is not supported for n_step collection." "Such hooks should only be used when collecting full episodes. 
Got a on_episode_done_hook " f"that would add the following fields to the buffer: {list(episode_hook_additions)}.", ) for key, episode_addition in episode_hook_additions.items(): self.buffer.set_array_at_key( episode_addition, key, index=ep_index_array, ) # executing the same logic in the episode-batch since stats computation # may depend on the presence of additional fields ep_batch.set_array_at_key( episode_addition, key, ) # Step 11 # Finally, update the stats collect_stats.update_at_episode_done( episode_batch=ep_batch, episode_return=cur_ep_return, ) # Step 12 # preparing for the next iteration last_obs_RO[env_done_local_idx_D] = obs_reset_DO last_info_R[env_done_local_idx_D] = info_reset_D # Step 13 # Handling the case when we have more ready envs than desired and are not done yet # # This can only happen if we are collecting a fixed number of episodes # If we have more ready envs than there are remaining episodes to collect, # we will remove some of them for the next rollout # One effect of this is the following: only envs that have completed an episode # in the last step can ever be removed from the ready envs. # Thus, this guarantees that each env will contribute at least one episode to the # collected data (the buffer). This effect was previous called "avoiding bias in selecting environments" # However, it is not at all clear whether this is actually useful or necessary. # Additional naming convention: # S - number of surplus envs # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones. # Changing R to R-S highly increases the complexity of the code. if n_episode: remaining_episodes_to_collect = n_episode - num_collected_episodes surplus_env_num = len(ready_env_ids_R) - remaining_episodes_to_collect if surplus_env_num > 0: # R becomes R-S here, preparing for the next iteration in while loop # Everything that was of length R needs to be filtered and become of length R-S. 
# Note that this won't be the last iteration, as one iteration equals one # step and we still need to collect the remaining episodes to reach the breaking condition. # creating the mask env_to_be_ignored_ind_local_S = env_done_local_idx_D[:surplus_env_num] env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) env_should_remain_R[env_to_be_ignored_ind_local_S] = False # stripping the "idle" indices, shortening the relevant quantities from R to R-S ready_env_ids_R = ready_env_ids_R[env_should_remain_R] last_obs_RO = last_obs_RO[env_should_remain_R] last_info_R = last_info_R[env_should_remain_R] if collect_action_computation_batch_R.hidden_state is not None: last_hidden_state_RH = last_hidden_state_RH[env_should_remain_R] # type: ignore[index] if (n_step and step_count >= n_step) or ( n_episode and num_collected_episodes >= n_episode ): break # Check if we screwed up somewhere if self.raise_on_nan_in_buffer and self.buffer.hasnull(): nan_batch = self.buffer.isnull().apply_values_transform(np.sum) raise MalformedBufferError( "NaN detected in the buffer. You can drop them with `buffer.dropnull()`. " "This error is most often caused by an incorrect use of `EpisodeRolloutHooks`" "together with the `n_steps` (instead of `n_episodes`) option, or by " "an incorrect implementation of `StepHook`." "Here an overview of the number of NaNs per field: \n" f"{nan_batch}", ) # Step 14 # update instance-lifetime counters, different from collect_stats self.collect_step += step_count self.collect_episode += num_collected_episodes # Step 15 if n_step: # persist for future collect iterations self._pre_collect_obs_RO = last_obs_RO self._pre_collect_info_R = last_info_R self._pre_collect_hidden_state_RH = last_hidden_state_RH elif n_episode: # reset envs and the _pre_collect fields self.reset_env(gym_reset_kwargs) # todo still necessary? 
return collect_stats @staticmethod def _reset_hidden_state_based_on_type( env_ind_local_D: np.ndarray, last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, ) -> None: if isinstance(last_hidden_state_RH, torch.Tensor): last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] elif isinstance(last_hidden_state_RH, np.ndarray): last_hidden_state_RH[env_ind_local_D] = ( None if last_hidden_state_RH.dtype == object else 0 ) elif isinstance(last_hidden_state_RH, Batch): last_hidden_state_RH.empty_(env_ind_local_D) # todo is this inplace magic and just working? class AsyncCollector(Collector[CollectStats]): """Async Collector handles async vector environment. Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation. """ def __init__( self, policy: Policy | Algorithm, env: BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, raise_on_nan_in_buffer: bool = True, ) -> None: if not env.is_async: # TODO: raise an exception? log.error( f"Please use {Collector.__name__} if not using async venv. " f"Env class: {env.__class__.__name__}", ) # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") super().__init__( policy, env, buffer, exploration_noise, collect_stats_class=CollectStats, raise_on_nan_in_buffer=raise_on_nan_in_buffer, ) # E denotes the number of parallel environments: self.env_num # At init, E=R but during collection R <= E # Keep in sync with reset! 
self._ready_env_ids_R: np.ndarray = np.arange(self.env_num) self._current_obs_in_all_envs_EO: np.ndarray | None = copy(self._pre_collect_obs_RO) self._current_info_in_all_envs_E: np.ndarray | None = copy(self._pre_collect_info_R) self._current_hidden_state_in_all_envs_EH: np.ndarray | torch.Tensor | Batch | None = copy( self._pre_collect_hidden_state_RH, ) self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) self._current_policy_in_all_envs_E: Batch | None = None def reset( self, reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Reset the environment, statistics, and data needed to start the collection. :param reset_buffer: if true, reset the replay buffer attached to the collector. :param reset_stats: if true, reset the statistics attached to the collector. :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) :return: The initial observation and info from the environment. """ # This sets the _pre_collect attrs result = super().reset( reset_buffer=reset_buffer, reset_stats=reset_stats, gym_reset_kwargs=gym_reset_kwargs, ) # Keep in sync with init! 
self._ready_env_ids_R = np.arange(self.env_num) # E denotes the number of parallel environments self.env_num self._current_obs_in_all_envs_EO = copy(self._pre_collect_obs_RO) self._current_info_in_all_envs_E = copy(self._pre_collect_info_R) self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) self._current_action_in_all_envs_EA = np.empty(self.env_num) self._current_policy_in_all_envs_E = None return result @override def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: # we need to step through the envs and wait until they are ready to be able to interact with them if self.env.waiting_id: self.env.step(None, id=self.env.waiting_id) return super().reset_env(gym_reset_kwargs=gym_reset_kwargs) @override def _collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: start_time = time.time() step_count = 0 num_collected_episodes = 0 episode_returns: list[float] = [] episode_lens: list[int] = [] episode_start_indices: list[int] = [] ready_env_ids_R = self._ready_env_ids_R # last_obs_RO= self._current_obs_in_all_envs_EO[ready_env_ids_R] # type: ignore[index] # last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R] # type: ignore[index] # last_hidden_state_RH = self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] # type: ignore[index] # last_obs_RO = self._pre_collect_obs_RO # last_info_R = self._pre_collect_info_R # last_hidden_state_RH = self._pre_collect_hidden_state_RH if self._current_obs_in_all_envs_EO is None or self._current_info_in_all_envs_E is None: raise RuntimeError( "Current obs or info array is None, did you call reset or pass reset_at_collect=True?", ) last_obs_RO = self._current_obs_in_all_envs_EO[ready_env_ids_R] last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R] last_hidden_state_RH = _nullable_slice( 
self._current_hidden_state_in_all_envs_EH, ready_env_ids_R, ) # Each iteration of the AsyncCollector is only stepping a subset of the # envs. The last observation/ hidden state of the ones not included in # the current iteration has to be retained. This is done by copying the while True: # todo do we need this? # todo extend to all current attributes but some could be None at init if self._current_obs_in_all_envs_EO is None: raise RuntimeError( "Current obs is None, did you call reset or pass reset_at_collect=True?", ) if ( not len(self._current_obs_in_all_envs_EO) == len(self._current_action_in_all_envs_EA) == self.env_num ): # major difference raise RuntimeError( f"{len(self._current_obs_in_all_envs_EO)=} and" f"{len(self._current_action_in_all_envs_EA)=} have to equal" f" {self.env_num=} as it tracks the current transition" f"in all envs", ) # get the next action collect_batch_R = self._compute_action_policy_hidden( random=random, ready_env_ids_R=ready_env_ids_R, last_obs_RO=last_obs_RO, last_info_R=last_info_R, last_hidden_state_RH=last_hidden_state_RH, ) # save act_RA/policy_R/ hidden_state_RH before env.step self._current_action_in_all_envs_EA[ready_env_ids_R] = collect_batch_R.act if self._current_policy_in_all_envs_E: self._current_policy_in_all_envs_E[ready_env_ids_R] = collect_batch_R.policy_entry else: self._current_policy_in_all_envs_E = collect_batch_R.policy_entry # first iteration if collect_batch_R.hidden_state is not None: if self._current_hidden_state_in_all_envs_EH is not None: # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat # and hope that if one of the two is a tensor, the other one is as well. 
self._current_hidden_state_in_all_envs_EH = cast( np.ndarray | Batch, self._current_hidden_state_in_all_envs_EH, ) self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = ( collect_batch_R.hidden_state ) else: self._current_hidden_state_in_all_envs_EH = collect_batch_R.hidden_state # step in env obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( collect_batch_R.act_normalized, ready_env_ids_R, ) done_R = np.logical_or(terminated_R, truncated_R) # Not all environments of the AsyncCollector might have performed a step in this iteration. # Change batch_of_envs_with_step_in_this_iteration here to reflect that ready_env_ids_R has changed. # This means especially that R is potentially changing every iteration try: ready_env_ids_R = cast(np.ndarray, info_R["env_id"]) # TODO: don't use bare Exception! except Exception: ready_env_ids_R = np.array([i["env_id"] for i in info_R]) current_iteration_batch = cast( RolloutBatchProtocol, Batch( obs=self._current_obs_in_all_envs_EO[ready_env_ids_R], act=self._current_action_in_all_envs_EA[ready_env_ids_R], policy=self._current_policy_in_all_envs_E[ready_env_ids_R], obs_next=obs_next_RO, rew=rew_R, terminated=terminated_R, truncated=truncated_R, done=done_R, info=info_R, ), ) if render: self.env.render() if render > 0 and not np.isclose(render, 0): time.sleep(render) # add data into the buffer ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( current_iteration_batch, buffer_ids=ready_env_ids_R, ) # collect statistics num_episodes_done_this_iter = np.sum(done_R) step_count += len(ready_env_ids_R) num_collected_episodes += num_episodes_done_this_iter # preparing for the next iteration # todo seem we can get rid of this last_sth stuff altogether last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) last_hidden_state_RH = copy( self._current_hidden_state_in_all_envs_EH[ready_env_ids_R], # type: ignore[index] ) if num_episodes_done_this_iter: env_ind_local_D = np.where(done_R)[0] env_ind_global_D = 
ready_env_ids_R[env_ind_local_D] episode_lens.extend(ep_len_R[env_ind_local_D]) episode_returns.extend(ep_rew_R[env_ind_local_D]) episode_start_indices.extend(ep_idx_R[env_ind_local_D]) # now we copy obs_next_RO to obs, but since there might be # finished episodes, we have to reset finished envs first. gym_reset_kwargs = gym_reset_kwargs or {} obs_reset_DO, info_reset_D = self.env.reset( env_id=env_ind_global_D, **gym_reset_kwargs, ) last_obs_RO[env_ind_local_D] = obs_reset_DO last_info_R[env_ind_local_D] = info_reset_D self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) # update based on the current transition in all envs self._current_obs_in_all_envs_EO[ready_env_ids_R] = last_obs_RO # this is a list, so loop over for idx, ready_env_id in enumerate(ready_env_ids_R): self._current_info_in_all_envs_E[ready_env_id] = last_info_R[idx] if self._current_hidden_state_in_all_envs_EH is not None: # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat # and hope that if one of the two is a tensor, the other one is as well. 
self._current_hidden_state_in_all_envs_EH = cast( np.ndarray | Batch, self._current_hidden_state_in_all_envs_EH, ) self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = last_hidden_state_RH else: self._current_hidden_state_in_all_envs_EH = last_hidden_state_RH if (n_step and step_count >= n_step) or ( n_episode and num_collected_episodes >= n_episode ): break # generate statistics self.collect_step += step_count self.collect_episode += num_collected_episodes collect_time = max(time.time() - start_time, 1e-9) self.collect_time += collect_time # persist for future collect iterations self._ready_env_ids_R = ready_env_ids_R return CollectStats.with_autogenerated_stats( returns=np.array(episode_returns), lens=np.array(episode_lens), n_collected_episodes=num_collected_episodes, n_collected_steps=step_count, ) class StepHookProtocol(Protocol): """A protocol for step hooks.""" def __call__( self, action_batch: CollectActionBatchProtocol, rollout_batch: RolloutBatchProtocol, ) -> None: """The function to call when the hook is executed.""" ... class StepHook(StepHookProtocol, ABC): """Marker interface for step hooks. All step hooks in Tianshou will inherit from it, but only the corresponding protocol will be used in type hints. This makes it possible to discover all hooks in the codebase by looking up the hierarchy of this class (or the protocol itself) while still allowing the user to pass something like a lambda function as a hook. """ @abstractmethod def __call__( self, action_batch: CollectActionBatchProtocol, rollout_batch: RolloutBatchProtocol, ) -> None: ... class StepHookAddActionDistribution(StepHook): """Adds the action distribution to the collected rollout batch under the field "action_dist". The field is also accessible as class variable `ACTION_DIST_KEY`. This hook be useful for algorithms that need the previously taken actions for training, like variants of imitation learning or DAGGER. 
""" ACTION_DIST_KEY = "action_dist" def __call__( self, action_batch: CollectActionBatchProtocol, rollout_batch: RolloutBatchProtocol, ) -> None: rollout_batch[self.ACTION_DIST_KEY] = action_batch.dist class EpisodeRolloutHookProtocol(Protocol): """A protocol for hooks (functions) that act on an entire collected episode. Can be used to add values to the buffer that are only known after the episode is finished. A prime example is something like the MC return to go. """ def __call__(self, episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray] | None: """Will be called by the collector when an episode is finished. If a dictionary is returned, the key-value pairs will be interpreted as new entries to be added to the episode batch (inside the buffer). In that case, the values should be arrays of the same length as the input `rollout_batch`. :param episode_batch: the batch of transitions that belong to the episode. :return: an optional dictionary containing new entries (of same len as `rollout_batch`) to be added to the buffer. """ ... class EpisodeRolloutHook(EpisodeRolloutHookProtocol, ABC): """Marker interface for episode hooks. All episode hooks in Tianshou will inherit from it, but only the corresponding protocol will be used in type hints. This makes it possible to discover all hooks in the codebase by looking up the hierarchy of this class (or the protocol itself) while still allowing the user to pass something like a lambda function as a hook. """ @abstractmethod def __call__(self, episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray] | None: ... class EpisodeRolloutHookMCReturn(EpisodeRolloutHook): """Adds the MC return to go as well as the full episode MC return to the transitions in the buffer. The latter will be constant for all transitions in the same episode and simply corresponds to the initial MC return to go. Useful for algorithms that rely on the monte carlo returns during training. 
""" MC_RETURN_TO_GO_KEY = "mc_return_to_go" FULL_EPISODE_MC_RETURN_KEY = "full_episode_mc_return" class OutputDict(TypedDict): mc_return_to_go: np.ndarray full_episode_mc_return: np.ndarray def __init__(self, gamma: float = 0.99): if not 0 <= gamma <= 1: raise ValueError(f"Expected 0 <= gamma <= 1, but got {gamma=}.") self.gamma = gamma def __call__( # type: ignore[override] self, episode_batch: RolloutBatchProtocol, ) -> "EpisodeRolloutHookMCReturn.OutputDict": mc_return_to_go = episode_mc_return_to_go(episode_batch.rew, self.gamma) full_episode_mc_return = np.full_like(mc_return_to_go, mc_return_to_go[0]) return self.OutputDict( mc_return_to_go=mc_return_to_go, full_episode_mc_return=full_episode_mc_return, ) class EpisodeRolloutHookMerged(EpisodeRolloutHook): """Combines multiple episode hooks into a single one. If all hooks return `None`, this hook will also return `None`. """ def __init__( self, *episode_rollout_hooks: EpisodeRolloutHookProtocol, check_overlapping_keys: bool = True, ): """ :param episode_rollout_hooks: the hooks to combine :param check_overlapping_keys: whether to check for overlapping keys in the output of the hooks and raise a `KeyError` if any are found. Set to `False` to disable this check (can be useful if this becomes a performance bottleneck). """ self.episode_rollout_hooks = episode_rollout_hooks self.check_overlapping_keys = check_overlapping_keys def __call__(self, episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray] | None: result: dict[str, np.ndarray] = {} for rollout_hook in self.episode_rollout_hooks: new_entries = rollout_hook(episode_batch) if new_entries is None: continue if self.check_overlapping_keys and ( duplicated_entries := set(new_entries).difference(result) ): raise KeyError( f"Combined rollout hook {rollout_hook} leads to previously " f"computed entries that would be overwritten: {duplicated_entries=}. 
" f"Consider combining hooks which will deliver non-overlapping entries to solve this.", ) result.update(new_entries) if not result: return None return result ================================================ FILE: tianshou/data/stats.py ================================================ import logging from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Optional import numpy as np from tianshou.utils.print import DataclassPPrintMixin if TYPE_CHECKING: from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.data import CollectStats, CollectStatsBase log = logging.getLogger(__name__) @dataclass(kw_only=True) class SequenceSummaryStats(DataclassPPrintMixin): """A data structure for storing the statistics of a sequence.""" mean: float std: float max: float min: float @classmethod def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats": if len(sequence) == 0: return cls(mean=0.0, std=0.0, max=0.0, min=0.0) if hasattr(sequence, "shape") and len(sequence.shape) > 1: log.warning( f"Sequence has shape {sequence.shape}, but only 1D sequences are supported. " "Stats will be computed from the flattened sequence. For computing stats " "for each dimension consider using the function `compute_dim_to_summary_stats`.", ) return cls( mean=float(np.mean(sequence)), std=float(np.std(sequence)), max=float(np.max(sequence)), min=float(np.min(sequence)), ) @classmethod def from_single_value(cls, value: float | int) -> "SequenceSummaryStats": return cls(mean=value, std=0.0, max=value, min=value) def compute_dim_to_summary_stats( arr: Sequence[Sequence[float]] | np.ndarray, ) -> dict[int, SequenceSummaryStats]: """Compute summary statistics for each dimension of a sequence. :param arr: a 2-dim arr (or sequence of sequences) from which to compute the statistics. :return: A dictionary of summary statistics for each dimension. 
""" stats = {} for dim, seq in enumerate(arr): stats[dim] = SequenceSummaryStats.from_sequence(seq) return stats @dataclass(kw_only=True) class TimingStats(DataclassPPrintMixin): """A data structure for storing timing statistics.""" total_time: float = 0.0 """The total time elapsed.""" train_time: float = 0.0 """The total time elapsed for training (collecting samples plus model update).""" train_time_collect: float = 0.0 """The total time elapsed for collecting training transitions.""" train_time_update: float = 0.0 """The total time elapsed for updating models.""" test_time: float = 0.0 """The total time elapsed for testing models.""" update_speed: float = 0.0 """The speed of updating (env_step per second).""" @dataclass(kw_only=True) class InfoStats(DataclassPPrintMixin): """A data structure for storing information about the learning process.""" update_step: int """The total number of update steps that have been taken.""" best_score: float """The best score over the test results. The one with the highest score will be considered the best model.""" best_reward: float """The best reward over the test results.""" best_reward_std: float """Standard deviation of the best reward over the test results.""" train_step: int """The total collected step of training collector.""" train_episode: int """The total collected episode of training collector.""" test_step: int """The total collected step of test collector.""" test_episode: int """The total collected episode of test collector.""" timing: TimingStats """The timing statistics.""" @dataclass(kw_only=True) class EpochStats(DataclassPPrintMixin): """A data structure for storing epoch statistics.""" epoch: int """The current epoch.""" train_collect_stat: Optional["CollectStatsBase"] """The statistics of the last call to the training collector.""" test_collect_stat: Optional["CollectStats"] """The statistics of the last call to the test collector.""" training_stat: Optional["TrainingStats"] """The statistics of the last 
model update step. Can be None if no model update is performed, typically in the last training iteration.""" info_stat: InfoStats """The information of the collector.""" ================================================ FILE: tianshou/data/types.py ================================================ from typing import Protocol import numpy as np import torch from tianshou.data import Batch from tianshou.data.batch import BatchProtocol, TArr, TObsArr TObs = TObsArr | BatchProtocol TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"] class ObsBatchProtocol(BatchProtocol, Protocol): """Observations of an environment that a policy can turn into actions. Typically used inside a policy's forward """ obs: TObs """the observations as generated by the environment in `step`. If it is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors)""" info: TArr """array of info dicts generated by the environment in `step`""" class RolloutBatchProtocol(ObsBatchProtocol, Protocol): """Typically, the outcome of sampling from a replay buffer.""" obs_next: TObs """the observations after obs as generated by the environment in `step`. If it is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors)""" act: TArr rew: np.ndarray terminated: TArr truncated: TArr class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol): """With added returns, usually computed with GAE.""" returns: torch.Tensor class PrioBatchProtocol(RolloutBatchProtocol, Protocol): """Contains weights that can be used for prioritized replay.""" weight: np.ndarray | torch.Tensor """can be used for prioritized replay.""" class RecurrentStateBatch(BatchProtocol, Protocol): """Used by RNNs in policies, contains `hidden` and `cell` fields.""" hidden: torch.Tensor cell: torch.Tensor class ActBatchProtocol(BatchProtocol, Protocol): """Simplest batch, just containing the action. 
Useful e.g., for random policy.""" act: TArr class ActStateBatchProtocol(ActBatchProtocol, Protocol): """Contains action and state (which can be None), useful for policies that can support RNNs.""" state: dict | BatchProtocol | np.ndarray | None """Hidden state of RNNs, or None if not using RNNs. Used for recurrent policies. At the moment support for recurrent is experimental!""" class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol): """In addition to state and action, contains model output: (logits).""" logits: torch.Tensor class FQFBatchProtocol(ModelOutputBatchProtocol, Protocol): """Model outputs, fractions and quantiles_tau - specific to the FQF model.""" fractions: torch.Tensor quantiles_tau: torch.Tensor class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol, Protocol): """Contains estimated advantages and values. Returns are usually computed from GAE of advantages by adding the value. """ adv: torch.Tensor v_s: torch.Tensor class DistBatchProtocol(ModelOutputBatchProtocol, Protocol): """Contains dist instances for actions (created by dist_fn). Usually categorical or normal. """ dist: torch.distributions.Distribution class DistLogProbBatchProtocol(DistBatchProtocol, Protocol): """Contains dist objects that can be sampled from and log_prob of taken action.""" log_prob: torch.Tensor class LogpOldProtocol(BatchWithAdvantagesProtocol, Protocol): """Contains logp_old, often needed for importance weights, in particular in PPO. Builds on batches that contain advantages and values. """ logp_old: torch.Tensor class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol): """Contains taus for algorithms using quantile regression. See e.g. 
https://arxiv.org/abs/1806.06923 """ taus: torch.Tensor class ImitationBatchProtocol(ModelOutputBatchProtocol, Protocol): """Similar to other batches, but contains `imitation_logits` and `q_value` fields.""" state: dict | Batch | np.ndarray | None q_value: torch.Tensor imitation_logits: torch.Tensor ================================================ FILE: tianshou/data/utils/__init__.py ================================================ ================================================ FILE: tianshou/data/utils/converter.py ================================================ import pickle from copy import deepcopy from numbers import Number from typing import Any, Union, no_type_check import h5py import numpy as np import torch from tianshou.data.batch import Batch, _parse_value # TODO: confusing name, could actually return a batch... # Overrides and generic types should be added # todo check for ActBatchProtocol @no_type_check def to_numpy(x: Any) -> Batch | np.ndarray: """Return an object without torch.Tensor.""" if isinstance(x, torch.Tensor): # most often case return x.detach().cpu().numpy() if isinstance(x, np.ndarray): # second often case return x if isinstance(x, np.number | np.bool_ | Number): return np.asanyarray(x) if x is None: return np.array(None, dtype=object) if isinstance(x, dict | Batch): x = Batch(x) if isinstance(x, dict) else deepcopy(x) x.to_numpy_() return x if isinstance(x, list | tuple): return to_numpy(_parse_value(x)) # fallback return np.asanyarray(x) @no_type_check def to_torch( x: Any, dtype: torch.dtype | None = None, device: str | int | torch.device = "cpu", ) -> Batch | torch.Tensor: """Return an object without np.ndarray.""" if isinstance(x, np.ndarray) and issubclass( x.dtype.type, np.bool_ | np.number, ): # most often case x = torch.from_numpy(x) if dtype is not None: x = x.type(dtype) return x.to(device) if isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype) return x.to(device) if isinstance(x, 
np.number | np.bool_ | Number): return to_torch(np.asanyarray(x), dtype, device) if isinstance(x, dict | Batch): x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x) x.to_torch_(dtype, device) return x if isinstance(x, list | tuple): return to_torch(_parse_value(x), dtype, device) # fallback raise TypeError(f"object {x} cannot be converted to torch.") @no_type_check def to_torch_as(x: Any, y: torch.Tensor) -> Batch | torch.Tensor: """Return an object without np.ndarray. Same as ``to_torch(x, dtype=y.dtype, device=y.device)``. """ assert isinstance(y, torch.Tensor) return to_torch(x, dtype=y.dtype, device=y.device) # Note: object is used as a proxy for objects that can be pickled # Note: mypy does not support cyclic definition currently Hdf5ConvertibleValues = Union[ int, float, Batch, np.ndarray, torch.Tensor, object, "Hdf5ConvertibleType", ] Hdf5ConvertibleType = dict[str, Hdf5ConvertibleValues] def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group, compression: str | None = None) -> None: """Copy object into HDF5 group.""" def to_hdf5_via_pickle( x: object, y: h5py.Group, key: str, compression: str | None = None, ) -> None: """Pickle, convert to numpy array and write to HDF5 dataset.""" data = np.frombuffer(pickle.dumps(x), dtype=np.byte) y.create_dataset(key, data=data, compression=compression) for k, v in x.items(): if isinstance(v, Batch | dict): # dicts and batches are both represented by groups subgrp = y.create_group(k) if isinstance(v, Batch): subgrp_data = v.__getstate__() subgrp.attrs["__data_type__"] = "Batch" else: subgrp_data = v to_hdf5(subgrp_data, subgrp, compression=compression) elif isinstance(v, torch.Tensor): # PyTorch tensors are written to datasets y.create_dataset(k, data=to_numpy(v), compression=compression) y[k].attrs["__data_type__"] = "Tensor" elif isinstance(v, np.ndarray): try: # NumPy arrays are written to datasets y.create_dataset(k, data=v, compression=compression) y[k].attrs["__data_type__"] = "ndarray" except TypeError: # 
If data type is not supported by HDF5 fall back to pickle. # This happens if dtype=object (e.g. due to entries being None) # and possibly in other cases like structured arrays. try: to_hdf5_via_pickle(v, y, k, compression=compression) except Exception as exception: raise RuntimeError( f"Attempted to pickle {v.__class__.__name__} due to " "data type not supported by HDF5 and failed.", ) from exception y[k].attrs["__data_type__"] = "pickled_ndarray" elif isinstance(v, int | float): # ints and floats are stored as attributes of groups y.attrs[k] = v else: # resort to pickle for any other type of object try: to_hdf5_via_pickle(v, y, k, compression=compression) except Exception as exception: raise NotImplementedError( f"No conversion to HDF5 for object of type '{type(v)}' " "implemented and fallback to pickle failed.", ) from exception y[k].attrs["__data_type__"] = v.__class__.__name__ def from_hdf5(x: h5py.Group, device: str | None = None) -> Hdf5ConvertibleValues: """Restore object from HDF5 group. Datasets are decoded according to their ``__data_type__`` attribute: ``"ndarray"`` yields a NumPy array, ``"Tensor"`` yields a torch tensor on ``device``, and any other value is unpickled. Groups are restored as a dict built from their attributes and (recursively) their members, or as a ``Batch`` if their ``__data_type__`` attribute is ``"Batch"``. .. warning:: Unpickling can execute arbitrary code; only load HDF5 files from trusted sources. :param x: the HDF5 group (or dataset, when recursing) to restore. :param device: the device on which restored torch tensors are placed. :return: the restored object. """ if isinstance(x, h5py.Dataset): # handle datasets if x.attrs["__data_type__"] == "ndarray": return np.array(x) if x.attrs["__data_type__"] == "Tensor": return torch.tensor(x, device=device) return pickle.loads(x[()]) # handle groups representing a dict or a Batch y = dict(x.attrs.items()) data_type = y.pop("__data_type__", None) for k, v in x.items(): y[k] = from_hdf5(v, device) return Batch(y) if data_type == "Batch" else y ================================================ FILE: tianshou/data/utils/segtree.py ================================================ import numpy as np from numba import njit class SegmentTree: """Implementation of Segment Tree. The segment tree stores an array ``arr`` with size ``n``. It supports value update and fast query of the sum for the interval ``[left, right)`` in O(log n) time. The detailed procedure is as follows: 1. Pad the array to have length of power of 2, so that leaf nodes in the \ segment tree have the same depth. 2.
Store the segment tree in a binary heap. :param size: the size of segment tree. """ def __init__(self, size: int) -> None: bound = 1 while bound < size: bound *= 2 self._size = size self._bound = bound self._value = np.zeros([bound * 2]) self._compile() def __len__(self) -> int: return self._size def __getitem__(self, index: int | np.ndarray) -> float | np.ndarray: """Return self[index].""" return self._value[index + self._bound] def __setitem__(self, index: int | np.ndarray, value: float | np.ndarray) -> None: """Update values in segment tree. Duplicate values in ``index`` are handled by numpy: later index overwrites previous ones. :: >>> a = np.array([1, 2, 3, 4]) >>> a[[0, 1, 0, 1]] = [4, 5, 6, 7] >>> print(a) [6 7 3 4] """ if isinstance(index, int): index, value = np.array([index]), np.array([value]) assert np.all(index >= 0) assert np.all(index < self._size) _setitem(self._value, index + self._bound, value) def reduce(self, start: int = 0, end: int | None = None) -> float: """Return operation(value[start:end]).""" if start == 0 and end is None: return self._value[1] if end is None: end = self._size if end < 0: end += self._size return _reduce(self._value, start + self._bound - 1, end + self._bound) def get_prefix_sum_idx(self, value: float | np.ndarray) -> int | np.ndarray: r"""Find the index with given value. Return the minimum index for each ``v`` in ``value`` so that :math:`v \le \mathrm{sums}_i`, where :math:`\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j`. .. warning:: Please make sure all of the values inside the segment tree are non-negative when using this function. 
""" assert np.all(value >= 0.0) assert np.all(value < self._value[1]) single = False if not isinstance(value, np.ndarray): value = np.array([value]) single = True index = _get_prefix_sum_idx(value, self._bound, self._value) return index.item() if single else index def _compile(self) -> None: f64 = np.array([0, 1], dtype=np.float64) f32 = np.array([0, 1], dtype=np.float32) i64 = np.array([0, 1], dtype=np.int64) _setitem(f64, i64, f64) _setitem(f64, i64, f32) _reduce(f64, 0, 1) _get_prefix_sum_idx(f64, 1, f64) _get_prefix_sum_idx(f32, 1, f64) @njit def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None: """Numba version, 4x faster: 0.1 -> 0.024.""" tree[index] = value while index[0] > 1: index //= 2 tree[index] = tree[index * 2] + tree[index * 2 + 1] @njit def _reduce(tree: np.ndarray, start: int, end: int) -> float: """Numba version, 2x faster: 0.009 -> 0.005.""" # nodes in (start, end) should be aggregated result = 0.0 while end - start > 1: # (start, end) interval is not empty if start % 2 == 0: result += tree[start + 1] start //= 2 if end % 2 == 1: result += tree[end - 1] end //= 2 return result @njit def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray: """Numba version (v0.51), 5x speed up with size=100000 and bsz=64. 
vectorized np: 0.0923 (numpy best) -> 0.024 (now) for-loop: 0.2914 -> 0.019 (but not so stable) """ index = np.ones(value.shape, dtype=np.int64) while index[0] < bound: index *= 2 lsons = sums[index] direct = lsons < value value -= lsons * direct index += direct index -= bound return index ================================================ FILE: tianshou/env/__init__.py ================================================ """Env package.""" from tianshou.env.gym_wrappers import ( ContinuousToDiscrete, MultiDiscreteToDiscrete, TruncatedAsTerminated, ) from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper from tianshou.env.venvs import ( BaseVectorEnv, DummyVectorEnv, RayVectorEnv, ShmemVectorEnv, SubprocVectorEnv, ) __all__ = [ "BaseVectorEnv", "ContinuousToDiscrete", "DummyVectorEnv", "MultiDiscreteToDiscrete", "PettingZooEnv", "RayVectorEnv", "ShmemVectorEnv", "SubprocVectorEnv", "TruncatedAsTerminated", "VectorEnvNormObs", "VectorEnvWrapper", ] ================================================ FILE: tianshou/env/atari/atari_network.py ================================================ from collections.abc import Callable, Sequence from typing import Any, TypeVar import numpy as np import torch from torch import nn from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.data import Batch from tianshou.data.types import TObs from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ActorFactory from tianshou.highlevel.module.core import ( TDevice, ) from tianshou.highlevel.module.intermediate import ( IntermediateModule, IntermediateModuleFactory, ) from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical from tianshou.utils.net.common import ( ActionReprNetWithVectorOutput, ) from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear from tianshou.utils.torch_utils import torch_device def 
layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module: """TODO.""" torch.nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) return layer T = TypeVar("T") class ScaledObsInputActionReprNet(ActionReprNetWithVectorOutput): def __init__(self, module: ActionReprNetWithVectorOutput, denom: float = 255.0) -> None: super().__init__(module.get_output_dim()) self.module = module self.denom = denom def forward( self, obs: TObs, state: T | None = None, info: dict[str, T] | None = None, ) -> tuple[torch.Tensor | Sequence[torch.Tensor], T | None]: if info is None: info = {} scaler = lambda arr: arr / self.denom if isinstance(obs, Batch): scaled_obs = obs.apply_values_transform(scaler) else: scaled_obs = scaler(obs) return self.module.forward(scaled_obs, state, info) class DQNet(ActionReprNetWithVectorOutput[Any]): """Reference: Human-level control through deep reinforcement learning.""" def __init__( self, c: int, h: int, w: int, action_shape: Sequence[int] | int, features_only: bool = False, output_dim_added_layer: int | None = None, layer_init: Callable[[nn.Module], nn.Module] = lambda x: x, ) -> None: # TODO: Add docstring if not features_only and output_dim_added_layer is not None: raise ValueError( "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", ) net = nn.Sequential( layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)), nn.ReLU(inplace=True), layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)), nn.ReLU(inplace=True), layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)), nn.ReLU(inplace=True), nn.Flatten(), ) with torch.no_grad(): base_cnn_output_dim = int(np.prod(net(torch.zeros(1, c, h, w)).shape[1:])) if not features_only: action_dim = int(np.prod(action_shape)) net = nn.Sequential( net, layer_init(nn.Linear(base_cnn_output_dim, 512)), nn.ReLU(inplace=True), layer_init(nn.Linear(512, action_dim)), ) output_dim = action_dim elif 
output_dim_added_layer is not None: net = nn.Sequential( net, layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)), nn.ReLU(inplace=True), ) output_dim = output_dim_added_layer else: output_dim = base_cnn_output_dim super().__init__(output_dim) self.net = net def forward( self, obs: TObs, state: T | None = None, info: dict[str, T] | None = None, ) -> tuple[torch.Tensor, T | None]: r"""Mapping: s -> Q(s, \*). For more info, see docstring of parent. """ device = torch_device(self) obs = torch.as_tensor(obs, device=device, dtype=torch.float32) return self.net(obs), state class C51Net(DQNet): """Reference: A distributional perspective on reinforcement learning.""" def __init__( self, *, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51, ) -> None: self.action_num = int(np.prod(action_shape)) super().__init__(c=c, h=h, w=w, action_shape=[self.action_num * num_atoms]) self.num_atoms = num_atoms def forward( self, obs: TObs, state: T | None = None, info: dict[str, T] | None = None, ) -> tuple[torch.Tensor, T | None]: r"""Mapping: x -> Z(x, \*).""" obs, state = super().forward(obs) obs = obs.view(-1, self.num_atoms).softmax(dim=-1) obs = obs.view(-1, self.action_num, self.num_atoms) return obs, state class RainbowNet(DQNet): """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.""" def __init__( self, *, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51, noisy_std: float = 0.5, is_dueling: bool = True, is_noisy: bool = True, ) -> None: super().__init__(c=c, h=h, w=w, action_shape=action_shape, features_only=True) self.action_num = int(np.prod(action_shape)) self.num_atoms = num_atoms def linear(x: int, y: int) -> NoisyLinear | nn.Linear: if is_noisy: return NoisyLinear(x, y, noisy_std) return nn.Linear(x, y) self.Q = nn.Sequential( linear(self.output_dim, 512), nn.ReLU(inplace=True), linear(512, self.action_num * self.num_atoms), ) self._is_dueling = is_dueling if self._is_dueling: self.V = 
nn.Sequential( linear(self.output_dim, 512), nn.ReLU(inplace=True), linear(512, self.num_atoms), ) self.output_dim = self.action_num * self.num_atoms def forward( self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, T | None]: obs, state = super().forward(obs) q = self.Q(obs) q = q.view(-1, self.action_num, self.num_atoms) if self._is_dueling: v = self.V(obs) v = v.view(-1, 1, self.num_atoms) logits = q - q.mean(dim=1, keepdim=True) + v else: logits = q probs = logits.softmax(dim=2) return probs, state class QRDQNet(DQNet): """Reference: Distributional Reinforcement Learning with Quantile Regression.""" def __init__( self, *, c: int, h: int, w: int, action_shape: Sequence[int] | int, num_quantiles: int = 200, ) -> None: self.action_num = int(np.prod(action_shape)) super().__init__(c=c, h=h, w=w, action_shape=[self.action_num * num_quantiles]) self.num_quantiles = num_quantiles def forward( self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, T | None]: obs, state = super().forward(obs) obs = obs.view(-1, self.action_num, self.num_quantiles) return obs, state class ActorFactoryAtariDQN(ActorFactory): USE_SOFTMAX_OUTPUT = False def __init__( self, scale_obs: bool = True, features_only: bool = False, output_dim_added_layer: int | None = None, ) -> None: self.output_dim_added_layer = output_dim_added_layer self.scale_obs = scale_obs self.features_only = features_only def create_module(self, envs: Environments, device: TDevice) -> DiscreteActor: c, h, w = envs.get_observation_shape() # type: ignore # only right shape is a sequence of length 3 action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) net: DQNet | ScaledObsInputActionReprNet net = DQNet( c=c, h=h, w=w, action_shape=action_shape, features_only=self.features_only, output_dim_added_layer=self.output_dim_added_layer, layer_init=layer_init, ) if self.scale_obs: 
net = ScaledObsInputActionReprNet(net) return DiscreteActor( preprocess_net=net, action_shape=envs.get_action_shape(), softmax_output=self.USE_SOFTMAX_OUTPUT, ).to(device) def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: return DistributionFunctionFactoryCategorical( is_probs_input=self.USE_SOFTMAX_OUTPUT, ).create_dist_fn(envs) class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory): def __init__(self, features_only: bool = False, net_only: bool = False) -> None: self.features_only = features_only self.net_only = net_only def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: obs_shape = envs.get_observation_shape() if isinstance(obs_shape, int): obs_shape = [obs_shape] assert len(obs_shape) == 3 c, h, w = obs_shape action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) dqn = DQNet( c=c, h=h, w=w, action_shape=action_shape, features_only=self.features_only, ).to(device) module = dqn.net if self.net_only else dqn return IntermediateModule(module, dqn.output_dim) class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN): def __init__(self) -> None: super().__init__(features_only=True, net_only=True) ================================================ FILE: tianshou/env/atari/atari_wrapper.py ================================================ # Borrow a lot from openai baselines: # https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py import logging import warnings from collections import deque from typing import Any, SupportsFloat import cv2 import gymnasium as gym import numpy as np from gymnasium import Env from tianshou.env import BaseVectorEnv from tianshou.highlevel.env import ( EnvFactoryRegistered, EnvMode, EnvPoolFactory, VectorEnvType, ) from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext envpool_is_available = True try: import envpool except ImportError: 
envpool_is_available = False envpool = None log = logging.getLogger(__name__) def _parse_reset_result(reset_result: Any) -> tuple[Any, dict, bool]: """Normalize the result of ``env.reset()``. :param reset_result: either an ``(obs, info)`` tuple whose second element is a dict, or just the observation. :return: a tuple ``(obs, info, contains_info)``; if ``reset_result`` did not contain an info dict, ``obs`` is ``reset_result`` itself, ``info`` is ``{}`` and ``contains_info`` is False. """ contains_info = ( isinstance(reset_result, tuple) and len(reset_result) == 2 and isinstance(reset_result[1], dict) ) if contains_info: return reset_result[0], reset_result[1], contains_info return reset_result, {}, contains_info def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]: """Return the abstract numpy dtype category of a Box space. :param obs_space: the Box space to inspect. :return: ``np.integer`` if the space's dtype is an integer type, ``np.floating`` if it is a floating-point type. :raises TypeError: for any other dtype. """ obs_space_dtype: type[np.integer] | type[np.floating] if np.issubdtype(obs_space.dtype, np.integer): obs_space_dtype = np.integer elif np.issubdtype(obs_space.dtype, np.floating): obs_space_dtype = np.floating else: raise TypeError( f"Unsupported observation space dtype: {obs_space.dtype}. " f"This might be a bug in tianshou or gymnasium, please report it!", ) return obs_space_dtype class NoopResetEnv(gym.Wrapper): """Sample initial states by taking random number of no-ops on reset. No-op is assumed to be action 0. :param gym.Env env: the environment to wrap. :param int noop_max: the maximum value of no-ops to run.
""" def __init__(self, env: gym.Env, noop_max: int = 30) -> None: super().__init__(env) self.noop_max = noop_max self.noop_action = 0 assert hasattr(env.unwrapped, "get_action_meanings") assert env.unwrapped.get_action_meanings()[0] == "NOOP" def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: _, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) for _ in range(noops): step_result = self.env.step(self.noop_action) if len(step_result) == 4: obs, rew, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) else: obs, rew, term, trunc, info = step_result done = term or trunc if done: obs, info, _ = _parse_reset_result(self.env.reset()) if return_info: return obs, info return obs, {} class MaxAndSkipEnv(gym.Wrapper): """Return only every `skip`-th frame (frameskipping) using most recent raw observations (for max pooling across time steps). :param gym.Env env: the environment to wrap. :param int skip: number of `skip`-th frame. """ def __init__(self, env: gym.Env, skip: int = 4) -> None: super().__init__(env) self._skip = skip def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: """Step the environment with the given action. Repeat action, sum reward, and max over last observations. 
""" obs_list = [] total_reward = 0.0 new_step_api = False for _ in range(self._skip): step_result = self.env.step(action) if len(step_result) == 4: obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) else: obs, reward, term, trunc, info = step_result done = term or trunc new_step_api = True obs_list.append(obs) total_reward += float(reward) if done: break max_frame = np.max(obs_list[-2:], axis=0) if new_step_api: return max_frame, total_reward, term, trunc, info return ( max_frame, total_reward, done, info.get("TimeLimit.truncated", False), info, ) class EpisodicLifeEnv(gym.Wrapper): """Make end-of-life == end-of-episode, but only reset on true game over. It helps the value estimation. :param gym.Env env: the environment to wrap. """ def __init__(self, env: gym.Env) -> None: super().__init__(env) self.lives = 0 self.was_real_done = True self._return_info = False def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: step_result = self.env.step(action) if len(step_result) == 4: obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) new_step_api = False else: obs, reward, term, trunc, info = step_result done = term or trunc new_step_api = True reward = float(reward) self.was_real_done = done # check current lives, make loss of life terminal, then update lives to # handle bonus lives assert hasattr(self.env.unwrapped, "ale") lives = self.env.unwrapped.ale.lives() if 0 < lives < self.lives: # for Qbert sometimes we stay in lives == 0 condition for a few # frames, so its important to keep lives > 0, so that we only reset # once the environment is actually done. 
done = True term = True self.lives = lives if new_step_api: return obs, reward, term, trunc, info return obs, reward, done, info.get("TimeLimit.truncated", False), info def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: """Calls the Gym environment reset, only when lives are exhausted. This way all states are still reachable even though lives are episodic, and the learner need not know about any of this behind-the-scenes. """ if self.was_real_done: obs, info, self._return_info = _parse_reset_result(self.env.reset(**kwargs)) else: # no-op step to advance from terminal/lost life state step_result = self.env.step(0) obs, info = step_result[0], step_result[-1] assert hasattr(self.env.unwrapped, "ale") self.lives = self.env.unwrapped.ale.lives() if self._return_info: return obs, info return obs, {} class FireResetEnv(gym.Wrapper): """Take action on reset for environments that are fixed until firing. Related discussion: https://github.com/openai/baselines/issues/240. :param gym.Env env: the environment to wrap. """ def __init__(self, env: gym.Env) -> None: super().__init__(env) assert hasattr(env.unwrapped, "get_action_meanings") assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 def reset(self, **kwargs: Any) -> tuple[Any, dict]: _, _, return_info = _parse_reset_result(self.env.reset(**kwargs)) obs = self.env.step(1)[0] return obs, {} class WarpFrame(gym.ObservationWrapper): """Warp frames to 84x84 as done in the Nature paper and later work. :param gym.Env env: the environment to wrap. 
""" def __init__(self, env: gym.Env) -> None: super().__init__(env) self.size = 84 obs_space = env.observation_space assert isinstance(obs_space, gym.spaces.Box) obs_space_dtype = get_space_dtype(obs_space) self.observation_space = gym.spaces.Box( low=np.min(obs_space.low), high=np.max(obs_space.high), shape=(self.size, self.size), dtype=obs_space_dtype, ) def observation(self, frame: np.ndarray) -> np.ndarray: """Returns the current observation from a frame.""" frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) class ScaledFloatFrame(gym.ObservationWrapper): """Normalize observations to 0~1. :param gym.Env env: the environment to wrap. """ def __init__(self, env: gym.Env) -> None: super().__init__(env) obs_space = env.observation_space assert isinstance(obs_space, gym.spaces.Box) low = np.min(obs_space.low) high = np.max(obs_space.high) self.bias = low self.scale = high - low self.observation_space = gym.spaces.Box( low=0.0, high=1.0, shape=obs_space.shape, dtype=np.float32, ) def observation(self, observation: np.ndarray) -> np.ndarray: return (observation - self.bias) / self.scale class ClipRewardEnv(gym.RewardWrapper): """clips the reward to {+1, 0, -1} by its sign. :param gym.Env env: the environment to wrap. """ def __init__(self, env: gym.Env) -> None: super().__init__(env) self.reward_range = (-1, 1) def reward(self, reward: SupportsFloat) -> int: """Bin reward to {+1, 0, -1} by its sign. Note: np.sign(0) == 0.""" return np.sign(float(reward)) class FrameStack(gym.Wrapper): """Stack n_frames last frames. :param gym.Env env: the environment to wrap. :param int n_frames: the number of frames to stack. 
""" def __init__(self, env: gym.Env, n_frames: int) -> None: super().__init__(env) self.n_frames: int = n_frames self.frames: deque[tuple[Any, ...]] = deque([], maxlen=n_frames) obs_space = env.observation_space obs_space_shape = env.observation_space.shape assert obs_space_shape is not None shape = (n_frames, *obs_space_shape) assert isinstance(obs_space, gym.spaces.Box) obs_space_dtype = get_space_dtype(obs_space) self.observation_space = gym.spaces.Box( low=np.min(obs_space.low), high=np.max(obs_space.high), shape=shape, dtype=obs_space_dtype, ) def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) for _ in range(self.n_frames): self.frames.append(obs) return (self._get_ob(), info) if return_info else (self._get_ob(), {}) def step(self, action: Any) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: step_result = self.env.step(action) done: bool if len(step_result) == 4: obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) new_step_api = False else: obs, reward, term, trunc, info = step_result new_step_api = True self.frames.append(obs) reward = float(reward) if new_step_api: return self._get_ob(), reward, term, trunc, info return ( self._get_ob(), reward, done, info.get("TimeLimit.truncated", False), info, ) def _get_ob(self) -> np.ndarray: # the original wrapper use `LazyFrames` but since we use np buffer, # it has no effect return np.stack(self.frames, axis=0) def wrap_deepmind( env: gym.Env, episode_life: bool = True, clip_rewards: bool = True, frame_stack: int = 4, scale: bool = False, warp_frame: bool = True, ) -> ( MaxAndSkipEnv | EpisodicLifeEnv | FireResetEnv | WarpFrame | ScaledFloatFrame | ClipRewardEnv | FrameStack ): """Configure environment for DeepMind-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). :param env: the Atari environment to wrap. 
:param bool episode_life: wrap the episode life wrapper. :param bool clip_rewards: wrap the reward clipping wrapper. :param int frame_stack: wrap the frame stacking wrapper. :param bool scale: wrap the scaling observation wrapper. :param bool warp_frame: wrap the grayscale + resize observation wrapper. :return: the wrapped atari environment. """ env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) assert hasattr(env.unwrapped, "get_action_meanings") # for mypy wrapped_env: ( MaxAndSkipEnv | EpisodicLifeEnv | FireResetEnv | WarpFrame | ScaledFloatFrame | ClipRewardEnv | FrameStack ) = env if episode_life: wrapped_env = EpisodicLifeEnv(wrapped_env) if "FIRE" in env.unwrapped.get_action_meanings(): wrapped_env = FireResetEnv(wrapped_env) if warp_frame: wrapped_env = WarpFrame(wrapped_env) if scale: wrapped_env = ScaledFloatFrame(wrapped_env) if clip_rewards: wrapped_env = ClipRewardEnv(wrapped_env) if frame_stack: wrapped_env = FrameStack(wrapped_env, frame_stack) return wrapped_env def make_atari_env( task: str, seed: int, num_training_envs: int, num_test_envs: int, scale: int | bool = False, frame_stack: int = 4, ) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]: """Wrapper function for Atari env. If EnvPool is installed, it will automatically switch to EnvPool's Atari env. :return: a tuple of (single env, training envs, test envs). 
""" env_factory = AtariEnvFactory(task, frame_stack, scale=bool(scale)) envs = env_factory.create_envs(num_training_envs, num_test_envs, seed=seed) return envs.env, envs.training_envs, envs.test_envs class AtariEnvFactory(EnvFactoryRegistered): def __init__( self, task: str, frame_stack: int, scale: bool = False, use_envpool_if_available: bool = True, venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO, ) -> None: assert "NoFrameskip" in task self.frame_stack = frame_stack self.scale = scale envpool_factory = None if use_envpool_if_available: if envpool_is_available: envpool_factory = self.EnvPoolFactoryAtari(self) log.info("Using envpool, because it available") else: log.info("Not using envpool, because it is not available") super().__init__( task=task, venv_type=venv_type, envpool_factory=envpool_factory, ) def _create_env(self, mode: EnvMode) -> gym.Env: env = super()._create_env(mode) is_train = mode == EnvMode.TRAINING return wrap_deepmind( env, episode_life=is_train, clip_rewards=is_train, frame_stack=self.frame_stack, scale=self.scale, ) class EnvPoolFactoryAtari(EnvPoolFactory): """Atari-specific envpool creation. Since envpool internally handles the functions that are implemented through the wrappers in `wrap_deepmind`, it sets the creation keyword arguments accordingly. """ def __init__(self, parent: "AtariEnvFactory") -> None: self.parent = parent if self.parent.scale: warnings.warn( "EnvPool does not include ScaledFloatFrame wrapper, " "please compensate by scaling inside your network's forward function (e.g. 
`x = x / 255.0` for Atari)", ) def _transform_task(self, task: str) -> str: task = super()._transform_task(task) # TODO: Maybe warn user, explain why this is needed return task.replace("NoFrameskip-v4", "-v5") def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict: kwargs = super()._transform_kwargs(kwargs, mode) is_train = mode == EnvMode.TRAINING kwargs["reward_clip"] = is_train kwargs["episodic_life"] = is_train kwargs["stack_num"] = self.parent.frame_stack return kwargs class AtariEpochStopCallback(EpochStopCallback): def __init__(self, task: str) -> None: self.task = task def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: env = context.envs.env if env.spec and env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold if "Pong" in self.task: return mean_rewards >= 20 return False ================================================ FILE: tianshou/env/gym_wrappers.py ================================================ from typing import Any, SupportsFloat import gymnasium as gym import numpy as np from packaging import version class ContinuousToDiscrete(gym.ActionWrapper): """Gym environment wrapper to take discrete action in a continuous environment. :param gym.Env env: gym environment with continuous action space. :param action_per_dim: number of discrete actions in each dimension of the action space. 
""" def __init__(self, env: gym.Env, action_per_dim: int | list[int]) -> None: super().__init__(env) assert isinstance(env.action_space, gym.spaces.Box) low, high = env.action_space.low, env.action_space.high if isinstance(action_per_dim, int): action_per_dim = [action_per_dim] * env.action_space.shape[0] assert len(action_per_dim) == env.action_space.shape[0] self.action_space = gym.spaces.MultiDiscrete(action_per_dim) self.mesh = np.array( [np.linspace(lo, hi, a) for lo, hi, a in zip(low, high, action_per_dim, strict=True)], dtype=object, ) def action(self, act: np.ndarray) -> np.ndarray: # modify act assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}." if len(act.shape) == 1: return np.array([self.mesh[i][a] for i, a in enumerate(act)]) return np.array([[self.mesh[i][a] for i, a in enumerate(a_)] for a_ in act]) class MultiDiscreteToDiscrete(gym.ActionWrapper): """Gym environment wrapper to take discrete action in multidiscrete environment. :param gym.Env env: gym environment with multidiscrete action space. """ def __init__(self, env: gym.Env) -> None: super().__init__(env) assert isinstance(env.action_space, gym.spaces.MultiDiscrete) nvec = env.action_space.nvec assert nvec.ndim == 1 self.bases = np.ones_like(nvec) for i in range(1, len(self.bases)): self.bases[i] = self.bases[i - 1] * nvec[-i] self.action_space = gym.spaces.Discrete(np.prod(nvec)) def action(self, act: np.ndarray) -> np.ndarray: converted_act = [] for b in np.flip(self.bases): converted_act.append(act // b) act = act % b return np.array(converted_act).transpose() class TruncatedAsTerminated(gym.Wrapper): """A wrapper that set ``terminated = terminated or truncated`` for ``step()``. It's intended to use with ``gym.wrappers.TimeLimit``. :param gym.Env env: gym environment. 
""" def __init__(self, env: gym.Env): super().__init__(env) if not version.parse(gym.__version__) >= version.parse("0.26.0"): raise OSError( f"TruncatedAsTerminated is not applicable with gym version \ {gym.__version__}", ) def step(self, act: np.ndarray) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]: observation, reward, terminated, truncated, info = super().step(act) terminated = terminated or truncated return observation, reward, terminated, truncated, info ================================================ FILE: tianshou/env/pettingzoo_env.py ================================================ import warnings from abc import ABC from typing import Any import pettingzoo from gymnasium import spaces from packaging import version from pettingzoo.utils.env import AECEnv from pettingzoo.utils.wrappers import BaseWrapper if version.parse(pettingzoo.__version__) < version.parse("1.21.0"): warnings.warn( f"You are using PettingZoo {pettingzoo.__version__}. " f"Future tianshou versions may not support PettingZoo<1.21.0. " f"Consider upgrading your PettingZoo version.", DeprecationWarning, ) class PettingZooEnv(AECEnv, ABC): """The interface for petting zoo environments which support multi-agent RL. Multi-agent environments must be wrapped as :class:`~tianshou.env.PettingZooEnv`. Here is the usage: :: env = PettingZooEnv(...) # obs is a dict containing obs, agent_id, and mask obs = env.reset() action = policy(obs) obs, rew, trunc, term, info = env.step(action) env.close() The available action's mask is set to True, otherwise it is set to False. 
""" def __init__(self, env: BaseWrapper): super().__init__() self.env = env # agent idx list self.agents = self.env.possible_agents self.agent_idx = {} for i, agent_id in enumerate(self.agents): self.agent_idx[agent_id] = i self.rewards = [0] * len(self.agents) # Get first observation space, assuming all agents have equal space self.observation_space: Any = self.env.observation_space(self.agents[0]) # Get first action space, assuming all agents have equal space self.action_space: Any = self.env.action_space(self.agents[0]) assert all( self.env.observation_space(agent) == self.observation_space for agent in self.agents ), ( "Observation spaces for all agents must be identical. Perhaps " "SuperSuit's pad_observations wrapper can help (usage: " "`supersuit.pad_observations_v0(env)`" ) assert all(self.env.action_space(agent) == self.action_space for agent in self.agents), ( "Action spaces for all agents must be identical. Perhaps " "SuperSuit's pad_action_space wrapper can help (usage: " "`supersuit.pad_action_space_v0(env)`" ) self.reset() def reset(self, *args: Any, **kwargs: Any) -> tuple[dict, dict]: self.env.reset(*args, **kwargs) observation, reward, terminated, truncated, info = self.env.last(self) if isinstance(observation, dict) and "action_mask" in observation: observation_dict = { "agent_id": self.env.agent_selection, "obs": observation["observation"], "mask": [obm == 1 for obm in observation["action_mask"]], } else: if isinstance(self.action_space, spaces.Discrete): observation_dict = { "agent_id": self.env.agent_selection, "obs": observation, "mask": [True] * self.env.action_space(self.env.agent_selection).n, } else: observation_dict = { "agent_id": self.env.agent_selection, "obs": observation, } return observation_dict, info def step(self, action: Any) -> tuple[dict, list[int], bool, bool, dict]: self.env.step(action) observation, rew, term, trunc, info = self.env.last() if isinstance(observation, dict) and "action_mask" in observation: obs = { 
"agent_id": self.env.agent_selection, "obs": observation["observation"], "mask": [obm == 1 for obm in observation["action_mask"]], } else: if isinstance(self.action_space, spaces.Discrete): obs = { "agent_id": self.env.agent_selection, "obs": observation, "mask": [True] * self.env.action_space(self.env.agent_selection).n, } else: obs = {"agent_id": self.env.agent_selection, "obs": observation} for agent_id, reward in self.env.rewards.items(): self.rewards[self.agent_idx[agent_id]] = reward return obs, self.rewards, term, trunc, info def close(self) -> None: self.env.close() def seed(self, seed: Any = None) -> None: try: self.env.seed(seed) except (NotImplementedError, AttributeError): self.env.reset(seed=seed) def render(self) -> Any: return self.env.render() ================================================ FILE: tianshou/env/utils.py ================================================ from typing import Any import cloudpickle import gymnasium import numpy as np from tianshou.env.pettingzoo_env import PettingZooEnv ENV_TYPE = gymnasium.Env | PettingZooEnv gym_new_venv_step_type = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] class CloudpickleWrapper: """A cloudpickle wrapper used in SubprocVectorEnv.""" def __init__(self, data: Any) -> None: self.data = data def __getstate__(self) -> str: return cloudpickle.dumps(self.data) def __setstate__(self, data: str) -> None: self.data = cloudpickle.loads(data) ================================================ FILE: tianshou/env/venv_wrappers.py ================================================ from typing import Any import numpy as np import torch from tianshou.env.utils import gym_new_venv_step_type from tianshou.env.venvs import GYM_RESERVED_KEYS, BaseVectorEnv from tianshou.utils import RunningMeanStd class VectorEnvWrapper(BaseVectorEnv): """Base class for vectorized environments wrapper.""" # Note: No super call because this is a wrapper with overridden __getattribute__ # It's not a "true" subclass of 
BaseVectorEnv but it does extend its interface, so # it can be used as a drop-in replacement # noinspection PyMissingConstructor def __init__(self, venv: BaseVectorEnv) -> None: self.venv = venv self.is_async = venv.is_async def __len__(self) -> int: return len(self.venv) def __getattribute__(self, key: str) -> Any: if key in GYM_RESERVED_KEYS: # reserved keys in gym.Env return getattr(self.venv, key) return super().__getattribute__(key) def get_env_attr( self, key: str, id: int | list[int] | np.ndarray | None = None, ) -> list[Any]: return self.venv.get_env_attr(key, id) def set_env_attr( self, key: str, value: Any, id: int | list[int] | np.ndarray | None = None, ) -> None: return self.venv.set_env_attr(key, value, id) def reset( self, env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, ) -> tuple[np.ndarray, np.ndarray]: return self.venv.reset(env_id, **kwargs) def step( self, action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: return self.venv.step(action, id) def seed(self, seed: int | list[int] | None = None) -> list[list[int] | None]: return self.venv.seed(seed) def render(self, **kwargs: Any) -> list[Any]: return self.venv.render(**kwargs) def close(self) -> None: self.venv.close() class VectorEnvNormObs(VectorEnvWrapper): """An observation normalization wrapper for vectorized environments. :param update_obs_rms: whether to update obs_rms. Default to True. """ def __init__(self, venv: BaseVectorEnv, update_obs_rms: bool = True) -> None: super().__init__(venv) # initialize observation running mean/std self.update_obs_rms = update_obs_rms self.obs_rms = RunningMeanStd() def reset( self, env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, ) -> tuple[np.ndarray, np.ndarray]: obs, info = self.venv.reset(env_id, **kwargs) if isinstance(obs, tuple): # type: ignore raise TypeError( "Tuple observation space is not supported. 
", "Please change it to array or dict space", ) if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs) obs = self._norm_obs(obs) return obs, info def step( self, action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: step_results = self.venv.step(action, id) if self.obs_rms and self.update_obs_rms: self.obs_rms.update(step_results[0]) return (self._norm_obs(step_results[0]), *step_results[1:]) def _norm_obs(self, obs: np.ndarray) -> np.ndarray: if self.obs_rms: return self.obs_rms.norm(obs) # type: ignore return obs def set_obs_rms(self, obs_rms: RunningMeanStd) -> None: """Set with given observation running mean/std.""" self.obs_rms = obs_rms def get_obs_rms(self) -> RunningMeanStd: """Return observation running mean/std.""" return self.obs_rms ================================================ FILE: tianshou/env/venvs.py ================================================ from collections.abc import Callable, Sequence from typing import Any, Literal import gymnasium as gym import numpy as np import torch from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type from tianshou.env.worker import ( DummyEnvWorker, EnvWorker, RayEnvWorker, SubprocEnvWorker, ) GYM_RESERVED_KEYS = [ "metadata", "reward_range", "spec", "action_space", "observation_space", ] class BaseVectorEnv: """Base class for vectorized environments. Usage: :: env_num = 8 envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)]) assert len(envs) == env_num It accepts a list of environment generators. In other words, an environment generator ``efn`` of a specific task means that ``efn()`` returns the environment of the given task, for example, ``gym.make(task)``. All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`. 
Here are some other usages: :: envs.seed(2) # which is equal to the next line envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env obs = envs.reset() # reset all environments obs = envs.reset([0, 5, 7]) # reset 3 specific environments obs, rew, done, info = envs.step([1] * 8) # step synchronously envs.render() # render all environments envs.close() # close all environments .. warning:: If you use your own environment, please make sure the ``seed`` method is set up properly, e.g., :: def seed(self, seed): np.random.seed(seed) Otherwise, the outputs of these envs may be the same with each other. :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the i-th env. :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a worker which contains the i-th env. :param wait_num: use in asynchronous simulation if the time cost of ``env.step`` varies with time and synchronously waiting for all environments to finish a step is time-wasting. In that case, we can return when ``wait_num`` environments finish a step and keep on simulation in these environments. If ``None``, asynchronous simulation is disabled; else, ``1 <= wait_num <= env_num``. :param timeout: use in asynchronous simulation same as above, in each vectorized step it only deal with those environments spending time within ``timeout`` seconds. """ def __init__( self, env_fns: Sequence[Callable[[], ENV_TYPE]], worker_fn: Callable[[Callable[[], ENV_TYPE]], EnvWorker], wait_num: int | None = None, timeout: float | None = None, ) -> None: self._env_fns = env_fns # A VectorEnv contains a pool of EnvWorkers, which corresponds to # interact with the given envs (one worker <-> one env). 
self.workers = [worker_fn(fn) for fn in env_fns] self.worker_class = type(self.workers[0]) assert issubclass(self.worker_class, EnvWorker) assert all(isinstance(w, self.worker_class) for w in self.workers) self.env_num = len(env_fns) self.wait_num = wait_num or len(env_fns) assert 1 <= self.wait_num <= len(env_fns), ( f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}" ) self.timeout = timeout assert self.timeout is None or self.timeout > 0, ( f"timeout is {timeout}, it should be positive if provided!" ) self.is_async = self.wait_num != len(env_fns) or timeout is not None self.waiting_conn: list[EnvWorker] = [] # environments in self.ready_id is actually ready # but environments in self.waiting_id are just waiting when checked, # and they may be ready now, but this is not known until we check it # in the step() function self.waiting_id: list[int] = [] # all environments are ready in the beginning self.ready_id = list(range(self.env_num)) self.is_closed = False def _assert_is_not_closed(self) -> None: assert not self.is_closed, ( f"Methods of {self.__class__.__name__} cannot be called after close." ) def __len__(self) -> int: """Return len(self), which is the number of environments.""" return self.env_num def __getattribute__(self, key: str) -> Any: """Switch the attribute getter depending on the key. Any class who inherits ``gym.Env`` will inherit some attributes, like ``action_space``. However, we would like the attribute lookup to go straight into the worker (in fact, this vector env's action_space is always None). """ if key in GYM_RESERVED_KEYS: # reserved keys in gym.Env return self.get_env_attr(key) return super().__getattribute__(key) def get_env_attr( self, key: str, id: int | list[int] | np.ndarray | None = None, ) -> list[Any]: """Get an attribute from the underlying environments. If id is an int, retrieve the attribute denoted by key from the environment underlying the worker at index id. The result is returned as a list with one element. 
Otherwise, retrieve the attribute for all workers at indices id and return a list that is ordered correspondingly to id. :param str key: The key of the desired attribute. :param id: Indice(s) of the desired worker(s). Default to None for all env_id. :return list: The list of environment attributes. """ self._assert_is_not_closed() id = self._wrap_id(id) if self.is_async: self._assert_id(id) return [self.workers[j].get_env_attr(key) for j in id] def set_env_attr( self, key: str, value: Any, id: int | list[int] | np.ndarray | None = None, ) -> None: """Set an attribute in the underlying environments. If id is an int, set the attribute denoted by key from the environment underlying the worker at index id to value. Otherwise, set the attribute for all workers at indices id. :param str key: The key of the desired attribute. :param Any value: The new value of the attribute. :param id: Indice(s) of the desired worker(s). Default to None for all env_id. """ self._assert_is_not_closed() id = self._wrap_id(id) if self.is_async: self._assert_id(id) for j in id: self.workers[j].set_env_attr(key, value) def _wrap_id( self, id: int | list[int] | np.ndarray | None = None, ) -> list[int] | np.ndarray: if id is None: return list(range(self.env_num)) return [id] if np.isscalar(id) else id # type: ignore def _assert_id(self, id: list[int] | np.ndarray) -> None: for i in id: assert i not in self.waiting_id, ( f"Cannot interact with environment {i} which is stepping now." ) assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}." # TODO: for now, has to be kept in sync with reset in EnvPoolMixin # In particular, can't rename env_id to env_ids def reset( self, env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, ) -> tuple[np.ndarray, np.ndarray]: """Reset the state of some envs and return initial observations. 
If id is None, reset the state of all the environments and return initial observations, otherwise reset the specific environments with the given id, either an int or a list. """ self._assert_is_not_closed() env_id = self._wrap_id(env_id) if self.is_async: self._assert_id(env_id) # send(None) == reset() in worker for id in env_id: self.workers[id].send(None, **kwargs) ret_list = [self.workers[id].recv() for id in env_id] assert ( isinstance(ret_list[0], tuple | list) and len(ret_list[0]) == 2 and isinstance(ret_list[0][1], dict) ), "The environment does not adhere to the Gymnasium's API." obs_list = [r[0] for r in ret_list] if isinstance(obs_list[0], tuple): # type: ignore raise TypeError( "Tuple observation space is not supported. ", "Please change it to array or dict space", ) try: obs = np.stack(obs_list) except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) infos = np.array([r[1] for r in ret_list]) return obs, infos def step( self, action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: """Run one timestep of some environments' dynamics. If id is None, run one timestep of all the environments` dynamics; otherwise run one timestep for some environments with given id, either an int or a list. When the end of episode is reached, you are responsible for calling reset(id) to reset this environment`s state. Accept a batch of action and return a tuple (batch_obs, batch_rew, batch_done, batch_info) in numpy format. :param numpy.ndarray action: a batch of action provided by the agent. If the venv is async, the action can be None, which will result in all arrays in the returned tuple being empty. 
:return: A tuple consisting of either: * ``obs`` a numpy.ndarray, the agent's observation of current environments * ``rew`` a numpy.ndarray, the amount of rewards returned after \ previous actions * ``terminated`` a numpy.ndarray, whether these episodes have been \ terminated * ``truncated`` a numpy.ndarray, whether these episodes have been truncated * ``info`` a numpy.ndarray, contains auxiliary diagnostic \ information (helpful for debugging, and sometimes learning) For the async simulation: Provide the given action to the environments. The action sequence should correspond to the ``id`` argument, and the ``id`` argument should be a subset of the ``env_id`` in the last returned ``info`` (initially they are env_ids of all the environments). If action is None, fetch unfinished step() calls instead. """ self._assert_is_not_closed() id = self._wrap_id(id) if not self.is_async: if action is None: raise ValueError("action must be not-None for non-async") assert len(action) == len(id) for i, j in enumerate(id): self.workers[j].send(action[i]) result = [] for j in id: env_return = self.workers[j].recv() env_return[-1]["env_id"] = j result.append(env_return) else: if action is not None: self._assert_id(id) assert len(action) == len(id) for act, env_id in zip(action, id, strict=True): self.workers[env_id].send(act) self.waiting_conn.append(self.workers[env_id]) self.waiting_id.append(env_id) self.ready_id = [x for x in self.ready_id if x not in id] ready_conns: list[EnvWorker] = [] while not ready_conns: ready_conns = self.worker_class.wait(self.waiting_conn, self.wait_num, self.timeout) result = [] for conn in ready_conns: waiting_index = self.waiting_conn.index(conn) self.waiting_conn.pop(waiting_index) env_id = self.waiting_id.pop(waiting_index) # env_return can be (obs, reward, done, info) or # (obs, reward, terminated, truncated, info) env_return = conn.recv() env_return[-1]["env_id"] = env_id # Add `env_id` to info result.append(env_return) 
self.ready_id.append(env_id) obs_list, rew_list, term_list, trunc_list, info_list = tuple(zip(*result, strict=True)) try: obs_stack = np.stack(obs_list) except ValueError: # different len(obs) obs_stack = np.array(obs_list, dtype=object) return ( obs_stack, np.stack(rew_list), np.stack(term_list), np.stack(trunc_list), np.stack(info_list), ) def seed(self, seed: int | list[int] | None = None) -> list[list[int] | None]: """Set the seed for all environments. Accept ``None``, an int (which will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list. :return: The list of seeds used in this env's random number generators. The first value in the list should be the "main" seed, or the value which a reproducer pass to "seed". """ self._assert_is_not_closed() seed_list: list[None] | list[int] if seed is None: seed_list = [seed] * self.env_num elif isinstance(seed, int): seed_list = [seed + i for i in range(self.env_num)] else: seed_list = seed return [w.seed(s) for w, s in zip(self.workers, seed_list, strict=True)] def render(self, **kwargs: Any) -> list[Any]: """Render all of the environments.""" self._assert_is_not_closed() if self.is_async and len(self.waiting_id) > 0: raise RuntimeError( f"Environments {self.waiting_id} are still stepping, cannot render them now.", ) return [w.render(**kwargs) for w in self.workers] def close(self) -> None: """Close all of the environments. This function will be called only once (if not, it will be called during garbage collected). This way, ``close`` of all workers can be assured. """ self._assert_is_not_closed() for w in self.workers: w.close() self.is_closed = True class DummyVectorEnv(BaseVectorEnv): """Dummy vectorized environment wrapper, implemented in for-loop. This has the same interface as true vectorized environment, but the rollout does not happen in parallel. So, all workers just wait for each other and the environment is as efficient as using a single environment. 
This can be useful for testing or for demonstration purposes. A rare use-case would be using vector based interface, but parallelization is not desired (e.g. because of too much overhead). However, in such cases one should consider using a single environment. .. seealso:: Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ def __init__( self, env_fns: Sequence[Callable[[], ENV_TYPE]], wait_num: int | None = None, timeout: float | None = None, ) -> None: super().__init__(env_fns, DummyEnvWorker, wait_num, timeout) class SubprocVectorEnv(BaseVectorEnv): """Vectorized environment wrapper based on subprocess. .. seealso:: Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. Additional arguments are: :param share_memory: whether to share memory between the main process and the worker process. Allows for shared buffers to exchange observations :param context: the context to use for multiprocessing. Usually it's fine to use the default context, but `spawn` as well as `fork` can have non-obvious side effects, see for example https://github.com/google-deepmind/mujoco/issues/742, or https://github.com/Farama-Foundation/Gymnasium/issues/222. Consider using 'fork' when using macOS and additional parallelization, for example via joblib. Defaults to None, which will use the default system context. """ def __init__( self, env_fns: Sequence[Callable[[], ENV_TYPE]], wait_num: int | None = None, timeout: float | None = None, share_memory: bool = False, context: Literal["fork", "spawn"] | None = None, ) -> None: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=share_memory, context=context) super().__init__( env_fns, worker_fn, wait_num, timeout, ) class ShmemVectorEnv(BaseVectorEnv): """Optimized SubprocVectorEnv with shared buffers to exchange observations. ShmemVectorEnv has exactly the same API as SubprocVectorEnv. .. 
seealso:: Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ def __init__( self, env_fns: Sequence[Callable[[], ENV_TYPE]], wait_num: int | None = None, timeout: float | None = None, ) -> None: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=True) super().__init__(env_fns, worker_fn, wait_num, timeout) class RayVectorEnv(BaseVectorEnv): """Vectorized environment wrapper based on ray. This is a choice to run distributed environments in a cluster. .. seealso:: Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ def __init__( self, env_fns: Sequence[Callable[[], ENV_TYPE]], wait_num: int | None = None, timeout: float | None = None, ) -> None: try: import ray except ImportError as exception: raise ImportError( "Please install ray to support RayVectorEnv: pip install ray", ) from exception if not ray.is_initialized(): ray.init() super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), wait_num, timeout) ================================================ FILE: tianshou/env/worker/__init__.py ================================================ # isort:skip_file # NOTE: Import order is important to avoid circular import errors! 
from tianshou.env.worker.worker_base import EnvWorker from tianshou.env.worker.dummy import DummyEnvWorker from tianshou.env.worker.ray import RayEnvWorker from tianshou.env.worker.subproc import SubprocEnvWorker __all__ = [ "DummyEnvWorker", "EnvWorker", "RayEnvWorker", "SubprocEnvWorker", ] ================================================ FILE: tianshou/env/worker/dummy.py ================================================ from collections.abc import Callable from typing import Any import gymnasium as gym import numpy as np from tianshou.env.worker import EnvWorker class DummyEnvWorker(EnvWorker): """Dummy worker used in sequential vector environments.""" def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self.env = env_fn() super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: return getattr(self.env.unwrapped, key) def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env.unwrapped, key, value) def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: if "seed" in kwargs: super().seed(kwargs["seed"]) return self.env.reset(**kwargs) @staticmethod def wait( # type: ignore workers: list["DummyEnvWorker"], wait_num: int, timeout: float | None = None, ) -> list["DummyEnvWorker"]: # Sequential EnvWorker objects are always ready return workers def send(self, action: np.ndarray | None, **kwargs: Any) -> None: if action is None: self.result = self.env.reset(**kwargs) else: self.result = self.env.step(action) # type: ignore def seed(self, seed: int | None = None) -> list[int] | None: super().seed(seed) try: return self.env.seed(seed) # type: ignore except (AttributeError, NotImplementedError): self.env.reset(seed=seed) return [seed] # type: ignore def render(self, **kwargs: Any) -> Any: return self.env.render(**kwargs) def close_env(self) -> None: self.env.close() ================================================ FILE: tianshou/env/worker/ray.py ================================================ # mypy: disable-error-code=unused-ignore import 
contextlib from collections.abc import Callable from typing import Any import gymnasium as gym import numpy as np from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type from tianshou.env.worker import EnvWorker with contextlib.suppress(ImportError): import ray class _SetAttrWrapper(gym.Wrapper): def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env.unwrapped, key, value) def get_env_attr(self, key: str) -> Any: return getattr(self.env, key) class RayEnvWorker(EnvWorker): """Ray worker used in RayVectorEnv.""" def __init__( self, env_fn: Callable[[], ENV_TYPE], ) -> None: # TODO: is ENV_TYPE actually correct? self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn()) # type: ignore super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: return ray.get(self.env.get_env_attr.remote(key)) # type: ignore def set_env_attr(self, key: str, value: Any) -> None: ray.get(self.env.set_env_attr.remote(key, value)) # type: ignore def reset(self, **kwargs: Any) -> Any: if "seed" in kwargs: super().seed(kwargs["seed"]) return ray.get(self.env.reset.remote(**kwargs)) # type: ignore @staticmethod def wait( # type: ignore workers: list["RayEnvWorker"], wait_num: int, timeout: float | None = None, ) -> list["RayEnvWorker"]: results = [x.result for x in workers] ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) # type: ignore return [workers[results.index(result)] for result in ready_results] def send(self, action: np.ndarray | None, **kwargs: Any) -> None: # self.result is actually a handle if action is None: self.result = self.env.reset.remote(**kwargs) # type: ignore else: self.result = self.env.step.remote(action) # type: ignore def recv(self) -> gym_new_venv_step_type: return ray.get(self.result) # type: ignore def seed(self, seed: int | None = None) -> list[int] | None: super().seed(seed) try: return ray.get(self.env.seed.remote(seed)) # type: ignore except (AttributeError, NotImplementedError): 
self.env.reset.remote(seed=seed) # type: ignore return None def render(self, **kwargs: Any) -> Any: return ray.get(self.env.render.remote(**kwargs)) # type: ignore def close_env(self) -> None: ray.get(self.env.close.remote()) # type: ignore ================================================ FILE: tianshou/env/worker/subproc.py ================================================ import ctypes import multiprocessing import time from collections.abc import Callable from multiprocessing import connection from multiprocessing.context import BaseContext from typing import Any, Literal import gymnasium as gym import numpy as np from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type from tianshou.env.worker import EnvWorker # mypy: disable-error-code="unused-ignore" _NP_TO_CT = { np.bool_: ctypes.c_bool, np.uint8: ctypes.c_uint8, np.uint16: ctypes.c_uint16, np.uint32: ctypes.c_uint32, np.uint64: ctypes.c_uint64, np.int8: ctypes.c_int8, np.int16: ctypes.c_int16, np.int32: ctypes.c_int32, np.int64: ctypes.c_int64, np.float32: ctypes.c_float, np.float64: ctypes.c_double, } class ShArray: """Wrapper of multiprocessing Array. 
Example usage: :: import numpy as np import multiprocessing as mp from tianshou.env.worker.subproc import ShArray ctx = mp.get_context('fork') # set an explicit context arr = ShArray(np.dtype(np.float32), (2, 3), ctx) arr.save(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)) print(arr.get()) """ def __init__(self, dtype: np.generic, shape: tuple[int], ctx: BaseContext | None) -> None: if ctx is None: ctx = multiprocessing.get_context() self.arr = ctx.Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore self.dtype = dtype self.shape = shape def save(self, ndarray: np.ndarray) -> None: assert isinstance(ndarray, np.ndarray) dst = self.arr.get_obj() dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape) # type: ignore np.copyto(dst_np, ndarray) def get(self) -> np.ndarray: obj = self.arr.get_obj() return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore def _setup_buf(space: gym.Space, ctx: BaseContext) -> dict | tuple | ShArray: if isinstance(space, gym.spaces.Dict): return {k: _setup_buf(v, ctx) for k, v in space.spaces.items()} if isinstance(space, gym.spaces.Tuple): assert isinstance(space.spaces, tuple) return tuple([_setup_buf(t, ctx) for t in space.spaces]) return ShArray(space.dtype, space.shape, ctx) # type: ignore def _worker( parent: connection.Connection, p: connection.Connection, env_fn_wrapper: CloudpickleWrapper, obs_bufs: dict | tuple | ShArray | None = None, ) -> None: def _encode_obs( obs: dict | tuple | np.ndarray, buffer: dict | tuple | ShArray, ) -> None: if isinstance(buffer, ShArray): # if buffer is an ShArray, obs must be array-like obs = np.asarray(obs, dtype=buffer.dtype) buffer.save(obs) elif isinstance(obs, tuple) and isinstance(buffer, tuple): for o, b in zip(obs, buffer, strict=True): _encode_obs(o, b) elif isinstance(obs, dict) and isinstance(buffer, dict): for k in obs: _encode_obs(obs[k], buffer[k]) parent.close() env = env_fn_wrapper.data() try: while True: try: cmd, data = p.recv() 
except EOFError: # the pipe has been closed p.close() break if cmd == "step": env_return = env.step(data) if obs_bufs is not None: _encode_obs(env_return[0], obs_bufs) env_return = (None, *env_return[1:]) p.send(env_return) elif cmd == "reset": obs, info = env.reset(**data) if obs_bufs is not None: _encode_obs(obs, obs_bufs) obs = None p.send((obs, info)) elif cmd == "close": p.send(env.close()) p.close() break elif cmd == "render": p.send(env.render(**data) if hasattr(env, "render") else None) elif cmd == "seed": if hasattr(env, "seed"): p.send(env.seed(data)) else: env.action_space.seed(seed=data) env.reset(seed=data) p.send(None) elif cmd == "getattr": p.send(getattr(env, data) if hasattr(env, data) else None) elif cmd == "setattr": setattr(env.unwrapped, data["key"], data["value"]) else: p.close() raise NotImplementedError except KeyboardInterrupt: p.close() class SubprocEnvWorker(EnvWorker): """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" def __init__( self, env_fn: Callable[[], gym.Env], share_memory: bool = False, context: BaseContext | Literal["fork", "spawn"] | None = None, ) -> None: if not isinstance(context, BaseContext): context = multiprocessing.get_context(context) self.parent_remote, self.child_remote = context.Pipe() self.share_memory = share_memory self.buffer: dict | tuple | ShArray | None = None assert hasattr(context, "Process") # for mypy if self.share_memory: dummy = env_fn() obs_space = dummy.observation_space dummy.close() del dummy self.buffer = _setup_buf(obs_space, context) args = ( self.parent_remote, self.child_remote, CloudpickleWrapper(env_fn), self.buffer, ) self.process = context.Process(target=_worker, args=args, daemon=True) self.process.start() self.child_remote.close() super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: self.parent_remote.send(["getattr", key]) return self.parent_remote.recv() def set_env_attr(self, key: str, value: Any) -> None: self.parent_remote.send(["setattr", {"key": 
key, "value": value}]) def _decode_obs(self) -> dict | tuple | np.ndarray: def decode_obs( buffer: dict | tuple | ShArray | None, ) -> dict | tuple | np.ndarray: if isinstance(buffer, ShArray): return buffer.get() if isinstance(buffer, tuple): return tuple([decode_obs(b) for b in buffer]) if isinstance(buffer, dict): return {k: decode_obs(v) for k, v in buffer.items()} raise NotImplementedError return decode_obs(self.buffer) @staticmethod def wait( # type: ignore workers: list["SubprocEnvWorker"], wait_num: int, timeout: float | None = None, ) -> list["SubprocEnvWorker"]: remain_conns = conns = [x.parent_remote for x in workers] ready_conns: list[connection.Connection] = [] remain_time, t1 = timeout, time.time() while len(remain_conns) > 0 and len(ready_conns) < wait_num: if timeout: remain_time = timeout - (time.time() - t1) if remain_time <= 0: break # connection.wait hangs if the list is empty new_ready_conns = connection.wait(remain_conns, timeout=remain_time) # type: ignore ready_conns.extend(new_ready_conns) # type: ignore remain_conns = [conn for conn in remain_conns if conn not in ready_conns] # type: ignore return [workers[conns.index(con)] for con in ready_conns] # type: ignore def send(self, action: np.ndarray | None, **kwargs: Any) -> None: if action is None: if "seed" in kwargs: super().seed(kwargs["seed"]) self.parent_remote.send(["reset", kwargs]) else: self.parent_remote.send(["step", action]) def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]: result = self.parent_remote.recv() if isinstance(result, tuple): if len(result) == 2: obs, info = result if self.share_memory: obs = self._decode_obs() return obs, info obs = result[0] if self.share_memory: obs = self._decode_obs() # TODO: figure out the typing issue, simplify and document this method return (obs, *result[1:]) obs = result if self.share_memory: obs = self._decode_obs() return obs def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: if "seed" in kwargs: 
super().seed(kwargs["seed"]) self.parent_remote.send(["reset", kwargs]) result = self.parent_remote.recv() if isinstance(result, tuple): obs, info = result if self.share_memory: obs = self._decode_obs() return obs, info obs = result if self.share_memory: obs = self._decode_obs() return obs def seed(self, seed: int | None = None) -> list[int] | None: super().seed(seed) self.parent_remote.send(["seed", seed]) return self.parent_remote.recv() def render(self, **kwargs: Any) -> Any: self.parent_remote.send(["render", kwargs]) return self.parent_remote.recv() def close_env(self) -> None: try: self.parent_remote.send(["close", None]) # mp may be deleted so it may raise AttributeError self.parent_remote.recv() self.process.join() except (BrokenPipeError, EOFError, AttributeError): pass # ensure the subproc is terminated self.process.terminate() ================================================ FILE: tianshou/env/worker/worker_base.py ================================================ from abc import ABC, abstractmethod from collections.abc import Callable from typing import Any import gymnasium as gym import numpy as np from tianshou.env.utils import gym_new_venv_step_type class EnvWorker(ABC): """An abstract worker for an environment.""" def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False self.result: gym_new_venv_step_type | tuple[np.ndarray, dict] self.action_space = self.get_env_attr("action_space") self.is_reset = False @abstractmethod def get_env_attr(self, key: str) -> Any: pass @abstractmethod def set_env_attr(self, key: str, value: Any) -> None: pass @abstractmethod def send(self, action: np.ndarray | None) -> None: """Send action signal to low-level worker. When action is None, it indicates sending "reset" signal; otherwise it indicates "step" signal. The paired return value from "recv" function is determined by such kind of different signal. 
""" def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]: """Receive result from low-level worker. If the last "send" function sends a NULL action, it only returns a single observation; otherwise it returns a tuple of (obs, rew, done, info) or (obs, rew, terminated, truncated, info), based on whether the environment is using the old step API or the new one. """ return self.result @abstractmethod def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: pass def step(self, action: np.ndarray) -> gym_new_venv_step_type: """Perform one timestep of the environment's dynamic. "send" and "recv" are coupled in sync simulation, so users only call "step" function. But they can be called separately in async simulation, i.e. someone calls "send" first, and calls "recv" later. """ self.send(action) return self.recv() # type: ignore @staticmethod def wait( workers: list["EnvWorker"], wait_num: int, timeout: float | None = None, ) -> list["EnvWorker"]: """Given a list of workers, return those ready ones.""" raise NotImplementedError def seed(self, seed: int | None = None) -> list[int] | None: """ Seeds the environment's action space sampler. NOTE: This does *not* seed the environment itself. :param seed: the random seed :return: a list containing the resulting seed used """ return self.action_space.seed(seed) @abstractmethod def render(self, **kwargs: Any) -> Any: """Render the environment.""" @abstractmethod def close_env(self) -> None: pass def close(self) -> None: if self.is_closed: return self.is_closed = True self.close_env() ================================================ FILE: tianshou/evaluation/__init__.py ================================================ ================================================ FILE: tianshou/evaluation/launcher.py ================================================ """Provides a basic interface for launching experiments. 
The API is experimental and subject to change!.""" import logging from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from copy import copy from dataclasses import asdict, dataclass from enum import Enum from typing import TYPE_CHECKING, Literal from joblib import Parallel, delayed from tianshou.data import InfoStats if TYPE_CHECKING: from tianshou.highlevel.experiment import Experiment log = logging.getLogger(__name__) @dataclass class JoblibConfig: n_jobs: int = -1 """The maximum number of concurrently running jobs. If -1, all CPUs are used.""" backend: Literal["loky", "multiprocessing", "threading"] | None = "loky" """Allows to hard-code backend, None means it will be inferred automatically.""" verbose: int = 10 """If greater than zero, prints progress messages.""" def default_experiment_execution(exp: "Experiment") -> InfoStats | None: """The default execution simply runs the experiment and returns the trainer result.""" return exp.run().trainer_result class ExpLauncher(ABC): """Base interface for launching multiple experiments simultaneously.""" def __init__( self, experiment_runner: Callable[ ["Experiment"], InfoStats | None ] = default_experiment_execution, ): """ :param experiment_runner: determines how an experiment is to be executed. Overriding the default can be useful, e.g., for using high-level interfaces to set up an experiment (or an experiment collection) and tinkering with it prior to execution. This need often arises when prototyping with mechanisms that are not yet supported by the high-level interfaces. Deviation from the default allows arbitrary things to happen during experiment execution, so use this option with caution!. 
""" self.experiment_runner = experiment_runner def get_name(self) -> str: """Returns the name of the launcher.""" return self.__class__.__name__.replace("Launcher", "").lower() @abstractmethod def _launch(self, experiments: Sequence["Experiment"]) -> list[InfoStats | None]: """Should call `self.experiment_runner` for each experiment in experiments and aggregate the results.""" def _safe_execute(self, exp: "Experiment") -> InfoStats | None | Literal["failed"]: try: return self.experiment_runner(exp) except BaseException as e: log.error(f"Failed to run experiment {exp}.", exc_info=e) return "failed" @staticmethod def _return_from_successful_and_failed_exps( successful_exp_stats: list[InfoStats | None], failed_exps: list["Experiment"], ) -> list[InfoStats | None]: if not successful_exp_stats: raise RuntimeError("All experiments failed, see error logs for more details.") if failed_exps: log.error( f"Failed to run the following " f"{len(failed_exps)}/{len(successful_exp_stats) + len(failed_exps)} experiments: {failed_exps}. " f"See the logs for more details. " f"Returning the results of {len(successful_exp_stats)} successful experiments.", ) return successful_exp_stats def launch(self, experiments: Sequence["Experiment"]) -> list[InfoStats | None]: """Will return the results of successfully executed experiments. If a single experiment is passed, will not use parallelism and run it in the main process. Failed experiments will be logged, and a RuntimeError is only raised if all experiments have failed. 
""" if len(experiments) == 1: log.info( "A single experiment is being run, will not use parallelism and run it in the main process.", ) return [self.experiment_runner(experiments[0])] return self._launch(experiments) class SequentialExpLauncher(ExpLauncher): """Convenience wrapper around a simple for loop to run experiments sequentially.""" def _launch(self, experiments: Sequence["Experiment"]) -> list[InfoStats | None]: successful_exp_stats = [] failed_exps = [] for exp in experiments: exp_stats = self._safe_execute(exp) if exp_stats == "failed": failed_exps.append(exp) else: successful_exp_stats.append(exp_stats) # noinspection PyTypeChecker return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps) class JoblibExpLauncher(ExpLauncher): def __init__( self, joblib_cfg: JoblibConfig | None = None, experiment_runner: Callable[ ["Experiment"], InfoStats | None ] = default_experiment_execution, ) -> None: super().__init__(experiment_runner=experiment_runner) self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig() # Joblib's backend is hard-coded to loky since the threading backend produces different results # TODO: fix this if self.joblib_cfg.backend != "loky": log.warning( f"Ignoring the user provided joblib backend {self.joblib_cfg.backend} and using loky instead. 
" f"The current implementation requires loky to work and will be relaxed soon", ) self.joblib_cfg.backend = "loky" def _launch(self, experiments: Sequence["Experiment"]) -> list[InfoStats | None]: results = Parallel(**asdict(self.joblib_cfg))( delayed(self._safe_execute)(exp) for exp in experiments ) successful_exps = [] failed_exps = [] for exp, result in zip(experiments, results, strict=True): if result == "failed": failed_exps.append(exp) else: successful_exps.append(result) return self._return_from_successful_and_failed_exps(successful_exps, failed_exps) class RegisteredExpLauncher(Enum): JOBLIB = "JOBLIB" SEQUENTIAL = "SEQUENTIAL" def create_launcher(self) -> ExpLauncher: match self: case RegisteredExpLauncher.JOBLIB: return JoblibExpLauncher() case RegisteredExpLauncher.SEQUENTIAL: return SequentialExpLauncher() case _: raise NotImplementedError( f"Launcher {self} is not yet implemented.", ) ================================================ FILE: tianshou/evaluation/rliable_evaluation.py ================================================ """The rliable-evaluation module provides a high-level interface to evaluate the results of an experiment with multiple runs on different seeds using the rliable library. The API is experimental and subject to change!. """ import json import os from collections.abc import Iterator from dataclasses import asdict, dataclass, fields from typing import Literal, cast import matplotlib.pyplot as plt import numpy as np import scipy.stats as sst from rliable import library as rly from rliable import plot_utils from sensai.util import logging from tianshou.utils import TensorboardLogger from tianshou.utils.logger.logger_base import DataScope log = logging.getLogger(__name__) @dataclass class EvaluationSequenceEntry: """A single entry in an evaluation sequence, representing data collected at a fixed environment step. 
""" # the structure expected in benchmark.js env_step: int """The number of environment steps at which the evaluation was performed.""" rew: float """The mean episode return at the given env_step. Called `rew` (confusingly) to be consistent with Tianshou's internal naming conventions.""" rew_std: float """The standard deviation of the episode returns at the given env_step, computed from multiple runs.""" iqm: float """The interquartile mean (IQM) of the episode returns at the given env_step, computed from multiple runs.""" iqm_confidence_interval: tuple[float, float] """The 95% confidence interval of the IQM of the episode returns at the given env_step.""" @dataclass class LoggedSummaryData: mean: np.ndarray std: np.ndarray max: np.ndarray min: np.ndarray @dataclass class LoggedCollectStats: env_step: np.ndarray | None = None n_collected_episodes: np.ndarray | None = None n_collected_steps: np.ndarray | None = None collect_time: np.ndarray | None = None collect_speed: np.ndarray | None = None returns_stat: LoggedSummaryData | None = None lens_stat: LoggedSummaryData | None = None @classmethod def from_data_dict(cls, data: dict) -> "LoggedCollectStats": """Create a LoggedCollectStats object from a dictionary. Converts SequenceSummaryStats from dict format to dataclass format and ignores fields that are not present. """ dataclass_data = {} field_names = [f.name for f in fields(cls)] for k, v in data.items(): if k not in field_names: log.info( f"Key {k} in data dict is not a valid field of LoggedCollectStats, ignoring it.", ) continue if isinstance(v, dict): v = LoggedSummaryData(**v) dataclass_data[k] = v return cls(**dataclass_data) @dataclass class MultiRunExperimentResult: """The result of multiple runs of an experiment (runs usually just differing by random seeds) that can be used with the rliable library. 
Glossary: - R: number of runs (typically, equal to the number of different seeds) - E: number of environment steps at which evaluation results were computed, i.e., the evaluation points n_1, n_2, ..., n_E """ exp_dir: str """The base directory where each sub-directory contains the results of one experiment run.""" exp_name: str """The name of the experiment, typically the name of the algorithm or the experiment directory basename.""" test_episode_returns_RE: np.ndarray """The test episode returns for each run of the experiment, where each row corresponds to one run.""" training_episode_returns_RE: np.ndarray """The training episode returns for each run of the experiment, where each row corresponds to one run.""" test_env_steps_E: np.ndarray """The environment steps at which the test episodes were evaluated.""" training_env_steps_E: np.ndarray """The environment steps at which the training episodes were evaluated.""" @classmethod def load_from_disk( cls, exp_dir: str, exp_name: str | None = None, max_env_step: int | None = None, ) -> "MultiRunExperimentResult": """Load the experiment result from disk. :param exp_dir: The directory from where the experiment results are restored. :param exp_name: The name of the experiment. If not passed, will be inferred from the experiment directory name. :param max_env_step: The maximum number of environment steps to consider. If None, all data is considered. Note: if the experiments have different numbers of steps, the minimum number is used. 
""" test_episode_returns_RE = [] training_episode_returns_RE = [] test_env_steps_E = None """The number of steps of the test run, will try extracting it either from the loaded stats or from loaded arrays.""" training_env_steps_E = None """The number of steps of the training run, will try extracting it from the loaded stats or from loaded arrays.""" if exp_name is None: exp_name = os.path.basename(os.path.normpath(exp_dir)) from tianshou.highlevel.experiment import Experiment # TODO: test_env_steps_E should not be defined in a loop and overwritten at each iteration # just for retrieving them. We might need a cleaner directory structure. for entry in os.scandir(exp_dir): if entry.name.startswith(".") or not entry.is_dir(): continue try: logger_factory = Experiment.from_directory(entry.path).logger_factory # only retrieve logger class to prevent creating another tfevent file logger_cls = logger_factory.get_logger_class() # Usually this means from low-level API except FileNotFoundError: log.info( f"Could not find persisted experiment in {entry.path}, using default logger.", ) logger_cls = TensorboardLogger data = logger_cls.restore_logged_data(entry.path) if not data: raise ValueError(f"Could not restore data from {entry.path}.") if DataScope.TEST not in data or not data[DataScope.TEST]: continue restored_test_data = data[DataScope.TEST] restored_train_data = data[DataScope.TRAINING] assert isinstance(restored_test_data, dict) assert isinstance(restored_train_data, dict) for restored_data, scope in zip( [restored_test_data, restored_train_data], [DataScope.TEST, DataScope.TRAINING], strict=True, ): if not isinstance(restored_data, dict): raise RuntimeError( f"Expected entry with key {scope} data to be a dictionary, " f"but got {restored_data=}.", ) test_data = LoggedCollectStats.from_data_dict(restored_test_data) training_data = LoggedCollectStats.from_data_dict(restored_train_data) if test_data.returns_stat is not None: 
test_episode_returns_RE.append(test_data.returns_stat.mean) test_env_steps_E = test_data.env_step if training_data.returns_stat is not None: training_episode_returns_RE.append(training_data.returns_stat.mean) training_env_steps_E = training_data.env_step test_data_found = True training_data_found = True if not test_episode_returns_RE or test_env_steps_E is None: log.warning(f"No test experiment data found in {exp_dir}.") test_data_found = False if not training_episode_returns_RE or training_env_steps_E is None: log.warning(f"No train experiment data found in {exp_dir}.") training_data_found = False if not test_data_found and not training_data_found: raise RuntimeError(f"No test or train data found in {exp_dir}.") min_training_data_len = min([len(arr) for arr in training_episode_returns_RE]) if max_env_step is not None: min_training_data_len = min(min_training_data_len, max_env_step) min_test_data_len = min([len(arr) for arr in test_episode_returns_RE]) if max_env_step is not None: min_test_data_len = min(min_test_data_len, max_env_step) assert test_env_steps_E is not None assert training_env_steps_E is not None test_env_steps_E = test_env_steps_E[:min_test_data_len] training_env_steps_E = training_env_steps_E[:min_training_data_len] if max_env_step: # find the index at which the maximum env step is reached with searchsorted min_test_data_len = int(np.searchsorted(test_env_steps_E, max_env_step)) min_training_data_len = int(np.searchsorted(training_env_steps_E, max_env_step)) test_env_steps_E = test_env_steps_E[:min_test_data_len] training_env_steps_E = training_env_steps_E[:min_training_data_len] test_episode_returns_RE = np.array( [arr[:min_test_data_len] for arr in test_episode_returns_RE] ) training_episode_returns_RE = np.array( [arr[:min_training_data_len] for arr in training_episode_returns_RE] ) return cls( test_episode_returns_RE=test_episode_returns_RE, test_env_steps_E=test_env_steps_E, exp_dir=exp_dir, exp_name=exp_name, 
training_episode_returns_RE=training_episode_returns_RE, training_env_steps_E=training_env_steps_E, ) def _get_env_steps_and_returns( self, scope: DataScope = DataScope.TEST, ) -> tuple[np.ndarray, np.ndarray]: if scope == DataScope.TEST: return self.test_env_steps_E, self.test_episode_returns_RE elif scope == DataScope.TRAINING: return self.training_env_steps_E, self.training_episode_returns_RE else: raise ValueError(f"Invalid scope {scope}, should be either 'TEST' or 'TRAINING'.") def _get_data_in_rliable_format( self, algo_name: str | None = None, score_thresholds: np.ndarray | None = None, scope: DataScope = DataScope.TEST, ) -> tuple[dict, np.ndarray, np.ndarray]: """Return the data in the format expected by the rliable library. :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm is set to the experiment dir. :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred from the minimum and maximum test episode returns. :return: A tuple score_dict (algo_name->returns), env_steps, and score_thresholds. """ env_steps, returns = self._get_env_steps_and_returns(scope=scope) if score_thresholds is None: score_thresholds = np.linspace( np.min(returns), np.max(returns), 101, ) if algo_name is None: algo_name = os.path.basename(self.exp_dir) score_dict = {algo_name: returns} return score_dict, env_steps, score_thresholds def _compute_iqm_scores( self, scope: DataScope = DataScope.TEST, ) -> tuple[dict, dict]: """Compute the IQM scores and confidence intervals for the experiment in a format expected by the rliable library. :param scope: The scope of the evaluation, either 'TEST' or 'TRAIN'. 
:return: A tuple of dicts, each with a single entry, (self.exp_name->iqm_scores, self.exp_name->iqm_confidence_intervals), where confidence intervals is an array of shape (2 x E) where the first row contains the lower bounds while the second row contains the upper bound of 95% CIs. """ score_dict, _, _ = self._get_data_in_rliable_format( algo_name=self.exp_name, score_thresholds=None, scope=scope, ) compute_iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0) return rly.get_interval_estimates(score_dict, compute_iqm) def eval_results( self, algo_name: str | None = None, score_thresholds: np.ndarray | None = None, save_as_json: bool = True, save_plots: bool = True, show_plots: bool = True, scope: DataScope = DataScope.TEST, ax_iqm_sample_efficiency: plt.Axes | None = None, ax_performance_profile: plt.Axes | None = None, algo2color: dict[str, str] | None = None, ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]: """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile. :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm is set to the experiment dir. :param score_thresholds: The score thresholds for the performance profile. If None, they will be inferred from the minimum and maximum test episode returns. :param save_as_json: whether to save the evaluation results as a JSON file (in a format compatible by the Tianshou benchmarking visualization) in the experiment directory. :param save_plots: whether to save the plots to the experiment directory. :param show_plots: whether to display the plots. :param scope: The scope of the evaluation, either 'TEST' or 'TRAIN'. :param ax_iqm_sample_efficiency: The axis to plot the IQM sample efficiency curve on. If None, a new figure is created. :param ax_performance_profile: The axis to plot the performance profile on. If None, a new figure is created. 
:param algo2color: A dictionary mapping algorithm names to colors. Useful for plotting the evaluations of multiple algorithms in the same figure, e.g., by first creating an ax_iqm and ax_profile with one evaluation and then passing them into the other evaluation. Same as the `colors` kwarg in the rliable plotting utils. :return: The created figures and axes in the order: fig_iqm, ax_iqm, fig_profile, ax_profile. """ iqm_scores, iqm_confidence_intervals = self._compute_iqm_scores(scope=scope) # Plot IQM sample efficiency curve if ax_iqm_sample_efficiency is None: fig_iqm_sample_efficiency, ax_iqm_sample_efficiency = plt.subplots( ncols=1, figsize=(7, 5), constrained_layout=True ) else: fig_iqm_sample_efficiency = ax_iqm_sample_efficiency.get_figure() # type: ignore score_dict, env_steps, score_thresholds = self._get_data_in_rliable_format( algo_name=algo_name, score_thresholds=score_thresholds, scope=scope, ) plot_utils.plot_sample_efficiency_curve( env_steps, iqm_scores, iqm_confidence_intervals, algorithms=None, xlabel="env step", ylabel="IQM episode return", ax=ax_iqm_sample_efficiency, colors=algo2color, ) if show_plots: plt.show(block=False) if save_plots: iqm_sample_efficiency_curve_path = os.path.abspath( os.path.join( self.exp_dir, f"iqm_sample_efficiency_curve_{scope.value}.png", ), ) log.info(f"Saving iqm sample efficiency curve to {iqm_sample_efficiency_curve_path}.") fig_iqm_sample_efficiency.savefig(iqm_sample_efficiency_curve_path) final_score_dict = {algo: returns[:, [-1]] for algo, returns in score_dict.items()} score_distributions, score_distributions_cis = rly.create_performance_profile( final_score_dict, score_thresholds, ) # Plot score distributions if ax_performance_profile is None: fig_performance_profile, ax_performance_profile = plt.subplots( ncols=1, figsize=(7, 5), constrained_layout=True ) else: fig_performance_profile = ax_performance_profile.get_figure() # type: ignore plot_utils.plot_performance_profiles( score_distributions, 
score_thresholds, performance_profile_cis=score_distributions_cis, xlabel=r"Episode return $(\tau)$", ax=ax_performance_profile, ) if save_plots: profile_curve_path = os.path.abspath( os.path.join(self.exp_dir, f"performance_profile_{scope.value}.png"), ) log.info(f"Saving performance profile curve to {profile_curve_path}.") fig_performance_profile.savefig(profile_curve_path) if show_plots: plt.show(block=False) if save_as_json: json_path = os.path.abspath( os.path.join(self.exp_dir, f"rliable_evaluation_{scope.value.lower()}.json"), ) log.info(f"Saving rliable evaluation results to {json_path}.") eval_results_dict_sequence = [ asdict(eval_entry) for eval_entry in self.to_evaluation_sequence(scope=scope) ] with open(json_path, "w", encoding="utf-8") as f: json.dump(eval_results_dict_sequence, f, indent=4) return ( fig_iqm_sample_efficiency, ax_iqm_sample_efficiency, fig_performance_profile, ax_performance_profile, ) def to_evaluation_sequence( self, scope: DataScope = DataScope.TEST ) -> Iterator[EvaluationSequenceEntry]: """Convert the experiment result to EvaluationSequence. :param scope: The scope of the evaluation, either 'TEST' or 'TRAIN'. :return: The rliable EvaluationSequence. 
""" env_steps_E, returns_RE = self._get_env_steps_and_returns(scope=scope) iqm_scores_dict, iqm_confidence_intervals_dict = self._compute_iqm_scores(scope=scope) iqm_scores = iqm_scores_dict[self.exp_name] iqm_confidence_intervals = iqm_confidence_intervals_dict[self.exp_name] for i, env_step in enumerate(env_steps_E): returns_R = returns_RE[:, i] returns_mean, returns_std = np.mean(returns_R), np.std(returns_R) yield EvaluationSequenceEntry( env_step=env_step, rew_std=returns_std, rew=returns_mean, iqm=iqm_scores[i], iqm_confidence_interval=tuple(iqm_confidence_intervals[:, i]), ) def load_and_eval_experiment( log_dir: str, show_plots: bool = True, save_plots: bool = True, save_as_json: bool = True, scope: DataScope | Literal["both"] = DataScope.TEST, max_env_step: int | None = None, ) -> MultiRunExperimentResult: """Evaluate the experiments in the given log directory using the rliable API and return the loaded results object. By default, will persist the evaluation results as plots and JSON files in the experiment directory. :param log_dir: The directory containing the experiment results. :param show_plots: whether to display plots. :param save_plots: whether to save plots to the `log_dir`. :param save_as_json: whether to save the evaluation results as a JSON file (in a format compatible by the Tianshou benchmarking visualization) in the experiment directory. :param scope: The scope of the evaluation (training or test) or 'both'. :param max_env_step: The maximum number of environment steps to consider. If None, all data is considered. Note: if the experiments have different numbers of steps, the minimum number is used. 
""" rliable_result = MultiRunExperimentResult.load_from_disk(log_dir, max_env_step=max_env_step) scopes = [scope] if scope == "both": scopes = [DataScope.TEST, DataScope.TRAINING] for scope in scopes: scope = cast(DataScope, scope) rliable_result.eval_results( show_plots=show_plots, save_plots=save_plots, save_as_json=save_as_json, scope=scope, ) return rliable_result ================================================ FILE: tianshou/exploration/__init__.py ================================================ from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise __all__ = [ "BaseNoise", "GaussianNoise", "OUNoise", ] ================================================ FILE: tianshou/exploration/random.py ================================================ from abc import ABC, abstractmethod from collections.abc import Sequence import numpy as np class BaseNoise(ABC): """The action noise base class.""" @abstractmethod def reset(self) -> None: """Reset to the initial state.""" @abstractmethod def __call__(self, size: Sequence[int]) -> np.ndarray: """Generate new noise.""" raise NotImplementedError class GaussianNoise(BaseNoise): """The vanilla Gaussian process, for exploration in DDPG by default.""" def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None: self._mu = mu assert sigma >= 0, "Noise std should not be negative." self._sigma = sigma def __call__(self, size: Sequence[int]) -> np.ndarray: return np.random.normal(self._mu, self._sigma, size) def reset(self) -> None: pass class OUNoise(BaseNoise): """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG. Usage: :: # init self.noise = OUNoise() # generate noise noise = self.noise(logits.shape, eps) For required parameters, you can refer to the stackoverflow page. However, our experiment result shows that (similar to OpenAI SpinningUp) using vanilla Gaussian process has little difference from using the Ornstein-Uhlenbeck process. 
""" def __init__( self, mu: float = 0.0, sigma: float = 0.3, theta: float = 0.15, dt: float = 1e-2, x0: float | np.ndarray | None = None, ) -> None: super().__init__() self._mu = mu self._alpha = theta * dt self._beta = sigma * np.sqrt(dt) self._x0 = x0 self.reset() def reset(self) -> None: """Reset to the initial state.""" self._x = self._x0 def __call__(self, size: Sequence[int], mu: float | None = None) -> np.ndarray: """Generate new noise. Return an numpy array which size is equal to ``size``. """ if self._x is None or (isinstance(self._x, np.ndarray) and self._x.shape != size): self._x = 0.0 if mu is None: mu = self._mu r = self._beta * np.random.normal(size=size) self._x = self._x + self._alpha * (mu - self._x) + r return self._x # type: ignore ================================================ FILE: tianshou/highlevel/__init__.py ================================================ ================================================ FILE: tianshou/highlevel/algorithm.py ================================================ import logging import typing from abc import ABC, abstractmethod from typing import Any, Generic, TypeVar, cast import gymnasium import torch from sensai.util.string import ToStringMixin from tianshou.algorithm import ( A2C, DDPG, DQN, IQN, NPG, PPO, REDQ, SAC, TD3, TRPO, Algorithm, DiscreteSAC, Reinforce, ) from tianshou.algorithm.algorithm_base import ( OffPolicyAlgorithm, OnPolicyAlgorithm, Policy, ) from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.modelfree.iqn import IQNPolicy from tianshou.algorithm.modelfree.redq import REDQPolicy from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.modelfree.sac import SACPolicy from tianshou.data import ReplayBuffer, VectorReplayBuffer from tianshou.data.collector import 
BaseCollector from tianshou.highlevel.config import ( OffPolicyTrainingConfig, OnPolicyTrainingConfig, TrainingConfig, ) from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ( ActorFactory, ) from tianshou.highlevel.module.core import ( ModuleFactory, TDevice, ) from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory from tianshou.highlevel.params.algorithm_params import ( A2CParams, DDPGParams, DiscreteSACParams, DQNParams, IQNParams, NPGParams, Params, ParamsMixinActorAndDualCritics, ParamsMixinSingleModel, ParamTransformerData, PPOParams, REDQParams, ReinforceParams, SACParams, TD3Params, TRPOParams, ) from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory from tianshou.highlevel.params.collector import ( CollectorFactory, CollectorFactoryDefault, ) from tianshou.highlevel.params.optim import OptimizerFactoryFactory from tianshou.highlevel.persistence import PolicyPersistence from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext from tianshou.highlevel.world import World from tianshou.trainer import ( OffPolicyTrainer, OffPolicyTrainerParams, OnPolicyTrainer, OnPolicyTrainerParams, Trainer, ) from tianshou.utils.net.discrete import DiscreteActor CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" TParams = TypeVar("TParams", bound=Params) TActorCriticParams = TypeVar( "TActorCriticParams", bound=Params | ParamsMixinSingleModel, ) TActorDualCriticsParams = TypeVar( "TActorDualCriticsParams", bound=Params | ParamsMixinActorAndDualCritics, ) TDiscreteCriticOnlyParams = TypeVar( "TDiscreteCriticOnlyParams", bound=Params | ParamsMixinSingleModel, ) TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) TPolicy = TypeVar("TPolicy", bound=Policy) TTrainingConfig = TypeVar("TTrainingConfig", bound=TrainingConfig) log = logging.getLogger(__name__) class AlgorithmFactory(ABC, ToStringMixin, Generic[TTrainingConfig]): """Factory for the creation 
of an :class:`Algorithm` instance, its policy, trainer as well as collectors.""" def __init__(self, training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory): self.training_config = training_config self.optim_factory = optim_factory self.algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self.collector_factory: CollectorFactory = CollectorFactoryDefault() def set_collector_factory(self, collector_factory: CollectorFactory) -> None: self.collector_factory = collector_factory def create_train_test_collectors( self, algorithm: Algorithm, envs: Environments, reset_collectors: bool = True, ) -> tuple[BaseCollector, BaseCollector]: """ Creates the collectors for training and test environments. :param algorithm: the algorithm :param envs: the environments wrapper :param reset_collectors: Whether to reset the collectors before returning them. Setting to True means that the envs will be reset as well. :return: a tuple of (training_collector, test_collector) """ buffer_size = self.training_config.buffer_size training_envs = envs.training_envs buffer: ReplayBuffer if len(training_envs) > 1: buffer = VectorReplayBuffer( buffer_size, len(training_envs), stack_num=self.training_config.replay_buffer_stack_num, save_only_last_obs=self.training_config.replay_buffer_save_only_last_obs, ignore_obs_next=self.training_config.replay_buffer_ignore_obs_next, ) else: buffer = ReplayBuffer( buffer_size, stack_num=self.training_config.replay_buffer_stack_num, save_only_last_obs=self.training_config.replay_buffer_save_only_last_obs, ignore_obs_next=self.training_config.replay_buffer_ignore_obs_next, ) training_collector = self.collector_factory.create_collector( algorithm, training_envs, buffer, exploration_noise=True, ) test_collector = self.collector_factory.create_collector(algorithm, envs.test_envs) if reset_collectors: training_collector.reset() test_collector.reset() return training_collector, 
test_collector def set_policy_wrapper_factory( self, policy_wrapper_factory: AlgorithmWrapperFactory | None, ) -> None: self.algorithm_wrapper_factory = policy_wrapper_factory def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: self.trainer_callbacks = callbacks @staticmethod def _create_policy_from_args( constructor: type[TPolicy], params_dict: dict, policy_params: list[str], **kwargs: Any, ) -> TPolicy: params = {p: params_dict.pop(p) for p in policy_params} return constructor(**params, **kwargs) @abstractmethod def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: pass def create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: algorithm = self._create_algorithm(envs, device) if self.algorithm_wrapper_factory is not None: algorithm = self.algorithm_wrapper_factory.create_wrapped_algorithm( algorithm, envs, self.optim_factory, device, ) return algorithm @abstractmethod def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> Trainer: pass class OnPolicyAlgorithmFactory(AlgorithmFactory[OnPolicyTrainingConfig], ABC): def create_trainer( self, world: World, policy_persistence: PolicyPersistence, ) -> OnPolicyTrainer: training_config = self.training_config callbacks = self.trainer_callbacks context = TrainingContext(world.algorithm, world.envs, world.logger) train_fn = ( callbacks.epoch_train_callback.get_trainer_fn(context) if callbacks.epoch_train_callback else None ) test_fn = ( callbacks.epoch_test_callback.get_trainer_fn(context) if callbacks.epoch_test_callback else None ) stop_fn = ( callbacks.epoch_stop_callback.get_trainer_fn(context) if callbacks.epoch_stop_callback else None ) algorithm = cast(OnPolicyAlgorithm, world.algorithm) assert world.training_collector is not None return algorithm.create_trainer( OnPolicyTrainerParams( training_collector=world.training_collector, test_collector=world.test_collector, max_epochs=training_config.max_epochs, 
epoch_num_steps=training_config.epoch_num_steps, update_step_num_repetitions=training_config.update_step_num_repetitions, test_step_num_episodes=training_config.test_step_num_episodes, batch_size=training_config.batch_size, collection_step_num_env_steps=training_config.collection_step_num_env_steps, save_best_fn=policy_persistence.get_save_best_fn(world), save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world), logger=world.logger, test_in_training=training_config.test_in_training, training_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, verbose=False, ) ) class OffPolicyAlgorithmFactory(AlgorithmFactory[OffPolicyTrainingConfig], ABC): def create_trainer( self, world: World, policy_persistence: PolicyPersistence, ) -> OffPolicyTrainer: training_config = self.training_config callbacks = self.trainer_callbacks context = TrainingContext(world.algorithm, world.envs, world.logger) train_fn = ( callbacks.epoch_train_callback.get_trainer_fn(context) if callbacks.epoch_train_callback else None ) test_fn = ( callbacks.epoch_test_callback.get_trainer_fn(context) if callbacks.epoch_test_callback else None ) stop_fn = ( callbacks.epoch_stop_callback.get_trainer_fn(context) if callbacks.epoch_stop_callback else None ) algorithm = cast(OffPolicyAlgorithm, world.algorithm) assert world.training_collector is not None return algorithm.create_trainer( OffPolicyTrainerParams( training_collector=world.training_collector, test_collector=world.test_collector, max_epochs=training_config.max_epochs, epoch_num_steps=training_config.epoch_num_steps, collection_step_num_env_steps=training_config.collection_step_num_env_steps, test_step_num_episodes=training_config.test_step_num_episodes, batch_size=training_config.batch_size, save_best_fn=policy_persistence.get_save_best_fn(world), logger=world.logger, update_step_num_gradient_steps_per_sample=training_config.update_step_num_gradient_steps_per_sample, test_in_training=training_config.test_in_training, training_fn=train_fn, 
test_fn=test_fn, stop_fn=stop_fn, verbose=False, ) ) class ReinforceAlgorithmFactory(OnPolicyAlgorithmFactory): def __init__( self, params: ReinforceParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactoryFactory, ): super().__init__(training_config, optim_factory) self.params = params self.actor_factory = actor_factory self.optim_factory = optim_factory def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: actor = self.actor_factory.create_module(envs, device) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, optim_factory_default=self.optim_factory, ), ) dist_fn = self.actor_factory.create_dist_fn(envs) assert dist_fn is not None policy = self._create_policy_from_args( ProbabilisticActorPolicy, kwargs, ["action_scaling", "action_bound_method", "deterministic_eval"], actor=actor, dist_fn=dist_fn, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), ) return Reinforce( policy=policy, **kwargs, ) class ActorCriticOnPolicyAlgorithmFactory( OnPolicyAlgorithmFactory, Generic[TActorCriticParams, TAlgorithm], ): def __init__( self, params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory, ): super().__init__(training_config, optim_factory=optimizer_factory) self.params = params self.actor_factory = actor_factory self.critic_factory = critic_factory self.optim_factory = optimizer_factory self.critic_use_action = False @abstractmethod def _get_algorithm_class(self) -> type[TAlgorithm]: pass @typing.no_type_check def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: actor = self.actor_factory.create_module(envs, device) critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, 
optim_factory_default=self.optim_factory, ), ) kwargs["actor"] = actor kwargs["critic"] = critic kwargs["action_space"] = envs.get_action_space() kwargs["observation_space"] = envs.get_observation_space() kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs) return kwargs def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: params = self._create_kwargs(envs, device) policy = self._create_policy_from_args( ProbabilisticActorPolicy, params, [ "actor", "dist_fn", "action_space", "deterministic_eval", "observation_space", "action_scaling", "action_bound_method", ], ) algorithm_class = self._get_algorithm_class() return algorithm_class(policy=policy, **params) class A2CAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[A2CParams, A2C]): def _get_algorithm_class(self) -> type[A2C]: return A2C class PPOAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[PPOParams, PPO]): def _get_algorithm_class(self) -> type[PPO]: return PPO class NPGAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[NPGParams, NPG]): def _get_algorithm_class(self) -> type[NPG]: return NPG class TRPOAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[TRPOParams, TRPO]): def _get_algorithm_class(self) -> type[TRPO]: return TRPO class DiscreteCriticOnlyOffPolicyAlgorithmFactory( OffPolicyAlgorithmFactory, Generic[TDiscreteCriticOnlyParams, TAlgorithm], ): def __init__( self, params: TDiscreteCriticOnlyParams, training_config: OffPolicyTrainingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactoryFactory, ): super().__init__(training_config, optim_factory) self.params = params self.model_factory = model_factory self.optim_factory = optim_factory @abstractmethod def _get_algorithm_class(self) -> type[TAlgorithm]: pass @abstractmethod def _create_policy( self, model: torch.nn.Module, params: dict, action_space: gymnasium.spaces.Discrete, observation_space: gymnasium.spaces.Space, ) -> Policy: pass @typing.no_type_check def _create_algorithm(self, envs: 
Environments, device: TDevice) -> TAlgorithm: model = self.model_factory.create_module(envs, device) params_dict = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, optim_factory_default=self.optim_factory, ), ) envs.get_type().assert_discrete(self) action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space()) policy = self._create_policy(model, params_dict, action_space, envs.get_observation_space()) algorithm_class = self._get_algorithm_class() return algorithm_class( policy=policy, **params_dict, ) class DQNAlgorithmFactory(DiscreteCriticOnlyOffPolicyAlgorithmFactory[DQNParams, DQN]): def _create_policy( self, model: torch.nn.Module, params: dict, action_space: gymnasium.spaces.Discrete, observation_space: gymnasium.spaces.Space, ) -> Policy: return self._create_policy_from_args( constructor=DiscreteQLearningPolicy, params_dict=params, policy_params=["eps_training", "eps_inference"], model=model, action_space=action_space, observation_space=observation_space, ) def _get_algorithm_class(self) -> type[DQN]: return DQN class IQNAlgorithmFactory(DiscreteCriticOnlyOffPolicyAlgorithmFactory[IQNParams, IQN]): def _create_policy( self, model: torch.nn.Module, params: dict, action_space: gymnasium.spaces.Discrete, observation_space: gymnasium.spaces.Space, ) -> Policy: return self._create_policy_from_args( IQNPolicy, params, [ "sample_size", "online_sample_size", "target_sample_size", "eps_training", "eps_inference", ], model=model, action_space=action_space, observation_space=observation_space, ) def _get_algorithm_class(self) -> type[IQN]: return IQN class DDPGAlgorithmFactory(OffPolicyAlgorithmFactory): def __init__( self, params: DDPGParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optim_factory: OptimizerFactoryFactory, ): super().__init__(training_config, optim_factory) self.critic_factory = critic_factory self.actor_factory = actor_factory self.params = params 
self.optim_factory = optim_factory def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: actor = self.actor_factory.create_module(envs, device) critic = self.critic_factory.create_module( envs, device, True, ) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, optim_factory_default=self.optim_factory, ), ) policy = self._create_policy_from_args( ContinuousDeterministicPolicy, kwargs, ["exploration_noise", "action_scaling", "action_bound_method"], actor=actor, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), ) return DDPG( policy=policy, critic=critic, **kwargs, ) class REDQAlgorithmFactory(OffPolicyAlgorithmFactory): def __init__( self, params: REDQParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic_ensemble_factory: CriticEnsembleFactory, optim_factory: OptimizerFactoryFactory, ): super().__init__(training_config, optim_factory) self.critic_ensemble_factory = critic_ensemble_factory self.actor_factory = actor_factory self.params = params self.optim_factory = optim_factory def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: envs.get_type().assert_continuous(self) actor = self.actor_factory.create_module( envs, device, ) critic_ensemble = self.critic_ensemble_factory.create_module( envs, device, self.params.ensemble_size, True, ) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, optim_factory_default=self.optim_factory, ), ) action_space = cast(gymnasium.spaces.Box, envs.get_action_space()) policy = self._create_policy_from_args( REDQPolicy, kwargs, [ "exploration_noise", "deterministic_eval", "action_scaling", "action_bound_method", ], actor=actor, action_space=action_space, observation_space=envs.get_observation_space(), ) return REDQ( policy=policy, critic=critic_ensemble, **kwargs, ) class ActorDualCriticsOffPolicyAlgorithmFactory( OffPolicyAlgorithmFactory, 
Generic[TActorDualCriticsParams, TAlgorithm, TPolicy], ): def __init__( self, params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory, ): super().__init__(training_config, optim_factory) self.params = params self.actor_factory = actor_factory self.critic1_factory = critic1_factory self.critic2_factory = critic2_factory self.optim_factory = optim_factory @abstractmethod def _get_algorithm_class(self) -> type[TAlgorithm]: pass def _get_discrete_last_size_use_action_shape(self) -> bool: return True @staticmethod def _get_critic_use_action(envs: Environments) -> bool: return envs.get_type().is_continuous() @abstractmethod def _create_policy( self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict ) -> TPolicy: pass @typing.no_type_check def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: actor = self.actor_factory.create_module(envs, device) use_action_shape = self._get_discrete_last_size_use_action_shape() critic_use_action = self._get_critic_use_action(envs) critic1 = self.critic1_factory.create_module( envs, device, use_action=critic_use_action, discrete_last_size_use_action_shape=use_action_shape, ) critic2 = self.critic2_factory.create_module( envs, device, use_action=critic_use_action, discrete_last_size_use_action_shape=use_action_shape, ) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, optim_factory_default=self.optim_factory, ), ) policy = self._create_policy(actor, envs, kwargs) algorithm_class = self._get_algorithm_class() return algorithm_class( policy=policy, critic=critic1, critic2=critic2, **kwargs, ) class SACAlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, SACPolicy]): def _create_policy( self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict ) -> SACPolicy: return 
self._create_policy_from_args( SACPolicy, params, ["exploration_noise", "deterministic_eval", "action_scaling"], actor=actor, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), ) def _get_algorithm_class(self) -> type[SAC]: return SAC class DiscreteSACAlgorithmFactory( ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, DiscreteSACPolicy] ): def _create_policy( self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict ) -> DiscreteSACPolicy: return self._create_policy_from_args( DiscreteSACPolicy, params, ["deterministic_eval"], actor=actor, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), ) def _get_algorithm_class(self) -> type[DiscreteSAC]: return DiscreteSAC class TD3AlgorithmFactory( ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, ContinuousDeterministicPolicy] ): def _create_policy( self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict ) -> ContinuousDeterministicPolicy: return self._create_policy_from_args( ContinuousDeterministicPolicy, params, ["exploration_noise", "action_scaling", "action_bound_method"], actor=actor, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), ) def _get_algorithm_class(self) -> type[TD3]: return TD3 ================================================ FILE: tianshou/highlevel/config.py ================================================ import logging import multiprocessing from dataclasses import dataclass from sensai.util.pickle import setstate from sensai.util.string import ToStringMixin log = logging.getLogger(__name__) @dataclass(kw_only=True) class TrainingConfig(ToStringMixin): """Training configuration.""" max_epochs: int = 100 """ the (maximum) number of epochs to run training for. 
An **epoch** is the outermost iteration level and each epoch consists of a number of training steps and one test step, where each training step * [for the online case] collects environment steps/transitions (**collection step**), adding them to the (replay) buffer (see :attr:`collection_step_num_env_steps` and :attr:`collection_step_num_episodes`) * performs an **update step** via the RL algorithm being used, which can involve one or more actual gradient updates, depending on the algorithm and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate agent performance. Training may be stopped early if the stop criterion is met (see :attr:`~tianshou.trainer.trainer.TrainerParams.stop_fn`). For online training, the number of training steps in each epoch is indirectly determined by :attr:`epoch_num_steps`: As many training steps will be performed as are required in order to reach :attr:`epoch_num_steps` total steps in the training environments. Specifically, if the number of transitions collected per step is `c` (see :attr:`collection_step_num_env_steps`) and :attr:`epoch_num_steps` is set to `s`, then the number of training steps per epoch is `ceil(s / c)`. Therefore, if `max_epochs = e`, the total number of environment steps taken during training can be computed as `e * ceil(s / c) * c`. For offline training, the number of training steps per epoch is equal to :attr:`epoch_num_steps`. """ epoch_num_steps: int = 30000 """ For an online algorithm, this is the total number of environment steps to be collected per epoch, and, for an offline algorithm, it is the total number of training steps to take per epoch. See :attr:`max_epochs` for an explanation of epoch semantics. """ num_training_envs: int = -1 """the number of training environments to use. 
If set to -1, use number of CPUs/threads.""" num_test_envs: int = 1 """the number of test environments to use""" test_step_num_episodes: int = -1 """the total number of episodes to collect in each test step (across all test environments). -1 means this will be set to the number of test environments, i.e. each test environment will run exactly one episode per test step. """ buffer_size: int = 4096 """the total size of the sample/replay buffer, in which environment steps (transitions) are stored""" collection_step_num_env_steps: int | None = 2048 """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same (non-zero) number of transitions. Specifically, if this is set to `n` and `m` training environments are used, then the total number of transitions collected per collection step is `ceil(n / m) * m =: c`. See :attr:`max_epochs` for information on the total number of environment steps being collected during training. """ collection_step_num_episodes: int | None = None """ the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected. This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. 
""" start_timesteps: int = 0 """ the number of environment steps to collect before the actual training loop begins """ start_timesteps_random: bool = False """ whether to use a random policy (instead of the initial or restored policy to be trained) when collecting the initial :attr:`start_timesteps` environment steps before training """ replay_buffer_ignore_obs_next: bool = False """whether to ignore the `obs_next` field in the collected samples when storing them in the buffer and instead use the one-in-the-future of `obs` as the next observation. This can be useful for very large observations, like for Atari, in order to save RAM. However, setting this to True **may introduce an error** at the last steps of episodes! Should only be used in exceptional cases and only when you know what you are doing. Currently only used in Atari examples and may be removed in the future! """ replay_buffer_save_only_last_obs: bool = False """if True, for the case where the environment outputs stacked frames (e.g. because it is using a `FrameStack` wrapper), save only the most recent frame so as not to duplicate observations in buffer memory. Specifically, if the environment outputs observations `obs` with shape (N, ...), only obs[-1] of shape (...) will be stored. Frame stacking with a fixed number of frames can then be recreated at the buffer level by setting :attr:`replay_buffer_stack_num`. Note: Currently only used in Atari examples and may be removed in the future! """ replay_buffer_stack_num: int = 1 """ the number of consecutive environment observations to stack and use as the observation input to the agent for each time step. Setting this to a value greater than 1 can help agents learn temporal aspects (e.g. velocities of moving objects for which only positions are observed). Note: it is recommended to do this stacking on the environment level by using something like gymnasium's `FrameStack` instead. 
Setting this to larger than one in conjunction with :attr:`replay_buffer_save_only_last_obs` means that stacking will be recreated at the buffer level, which is more memory-efficient. Currently only used in Atari examples and may be removed in the future! """ def __setstate__(self, state: dict) -> None: setstate( TrainingConfig, self, state, renamed_properties={"num_train_envs": "num_training_envs"} ) def __post_init__(self) -> None: if self.num_training_envs == -1: self.num_training_envs = multiprocessing.cpu_count() if self.test_step_num_episodes == 0 and self.num_test_envs != 0: log.warning( f"Number of test episodes is set to 0, " f"but number of test environments is ({self.num_test_envs}). " f"This can cause unnecessary memory usage.", ) if self.test_step_num_episodes == -1: log.debug( f"Setting test_step_num_episodes to num_test_envs ({self.num_test_envs}) since it was -1." ) self.test_step_num_episodes = self.num_test_envs if ( self.test_step_num_episodes != 0 and self.test_step_num_episodes % self.num_test_envs != 0 ): log.warning( f"Number of test episodes ({self.test_step_num_episodes} " f"is not divisible by the number of test environments ({self.num_test_envs}). " f"This can cause unnecessary memory usage, it is recommended to adjust this.", ) assert ( sum( [ self.collection_step_num_env_steps is not None, self.collection_step_num_episodes is not None, ] ) == 1 ), ( "Only one of `collection_step_num_env_steps` and `collection_step_num_episodes` can be set.", ) @dataclass(kw_only=True) class OnlineTrainingConfig(TrainingConfig): collection_step_num_env_steps: int | None = 2048 """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. 
Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same (non-zero) number of transitions. Specifically, if this is set to `n` and `m` training environments are used, then the total number of transitions collected per collection step is `ceil(n / m) * m =: c`. See :attr:`max_epochs` for information on the total number of environment steps being collected during training. """ collection_step_num_episodes: int | None = None """ the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected. This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. """ test_in_training: bool = False """ Whether to apply a test step within a training step depending on the early stopping criterion (see :meth:`~tianshou.highlevel.Experiment.with_epoch_stop_callback`) being satisfied based on the data collected within the training step. Specifically, after each collect step, we check whether the early stopping criterion would be satisfied by data we collected (provided that at least one episode was indeed completed, such that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step (collecting :attr:`test_step_num_episodes` episodes in order to evaluate performance), and if the early stopping criterion is also satisfied based on the test data, we stop training early. 
""" def __setstate__(self, state: dict) -> None: setstate( OnlineTrainingConfig, self, state, renamed_properties={"test_in_train": "test_in_training"}, ) @dataclass(kw_only=True) class OnPolicyTrainingConfig(OnlineTrainingConfig): batch_size: int | None = 64 """ Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, a form of regularization). Set ``batch_size=None`` for the full buffer that was collected within the training step to be used for the gradient update (no mini-batching). """ update_step_num_repetitions: int = 1 """ controls, within one update step of an on-policy algorithm, the number of times the full collected data is applied for gradient updates, i.e. if the parameter is 5, then the collected data shall be used five times to update the policy within the same update step. """ @dataclass(kw_only=True) class OffPolicyTrainingConfig(OnlineTrainingConfig): batch_size: int = 64 """ the the number of environment steps/transitions to sample from the buffer for a gradient update. """ update_step_num_gradient_steps_per_sample: float = 1.0 """ the number of gradient steps to perform per sample collected (see :attr:`collection_step_num_env_steps`). Specifically, if this is set to `u` and the number of samples collected in the preceding collection step is `n`, then `round(u * n)` gradient steps will be performed. 
""" ================================================ FILE: tianshou/highlevel/env.py ================================================ import logging import platform from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from enum import Enum from typing import Any, TypeAlias, cast import gymnasium as gym import gymnasium.spaces import numpy as np from gymnasium import Env from sensai.util.pickle import setstate from sensai.util.string import ToStringMixin from tianshou.env import ( BaseVectorEnv, DummyVectorEnv, RayVectorEnv, SubprocVectorEnv, ) from tianshou.highlevel.persistence import Persistence from tianshou.utils.net.common import TActionShape TObservationShape: TypeAlias = int | Sequence[int] log = logging.getLogger(__name__) class EnvType(Enum): """Enumeration of environment types.""" CONTINUOUS = "continuous" DISCRETE = "discrete" def is_discrete(self) -> bool: return self == EnvType.DISCRETE def is_continuous(self) -> bool: return self == EnvType.CONTINUOUS def assert_continuous(self, requiring_entity: Any) -> None: if not self.is_continuous(): raise AssertionError(f"{requiring_entity} requires continuous environments") def assert_discrete(self, requiring_entity: Any) -> None: if not self.is_discrete(): raise AssertionError(f"{requiring_entity} requires discrete environments") @staticmethod def from_env(env: Env) -> "EnvType": if isinstance(env.action_space, gymnasium.spaces.Discrete): return EnvType.DISCRETE elif isinstance(env.action_space, gymnasium.spaces.Box): return EnvType.CONTINUOUS else: raise Exception(f"Unsupported environment type with action space {env.action_space}") class EnvMode(Enum): """Indicates the purpose for which an environment is created.""" TRAINING = "training" TEST = "test" WATCH = "watch" class VectorEnvType(Enum): DUMMY = "dummy" """Vectorized environment without parallelization; environments are processed sequentially""" SUBPROC = "subproc" """Parallelization based on `subprocess`""" 
SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem" """Parallelization based on `subprocess` with shared memory""" SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork" """Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn` by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)""" RAY = "ray" """Parallelization based on the `ray` library""" SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto" """Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise""" def create_venv( self, factories: Sequence[Callable[[], gym.Env]], ) -> BaseVectorEnv: match self: case VectorEnvType.DUMMY: return DummyVectorEnv(factories) case VectorEnvType.SUBPROC: return SubprocVectorEnv(factories) case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT: return SubprocVectorEnv(factories, share_memory=True) case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT: return SubprocVectorEnv(factories, share_memory=True, context="fork") case VectorEnvType.SUBPROC_SHARED_MEM_AUTO: if platform.system().lower() == "windows": selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT else: selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT return selected_venv_type.create_venv(factories) case VectorEnvType.RAY: return RayVectorEnv(factories) case _: raise NotImplementedError(self) class Environments(ToStringMixin, ABC): """Represents (vectorized) environments for a learning process.""" def __init__( self, env: gym.Env, training_envs: BaseVectorEnv, test_envs: BaseVectorEnv, watch_env: BaseVectorEnv | None = None, ): self.env = env self.training_envs = training_envs self.test_envs = test_envs self.watch_env = watch_env self.persistence: Sequence[Persistence] = [] @staticmethod def from_factory_and_type( factory_fn: Callable[[EnvMode], gym.Env], env_type: EnvType, venv_type: VectorEnvType, num_training_envs: int, num_test_envs: int, 
create_watch_env: bool = False, ) -> "Environments": """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete). :param factory_fn: the factory for a single environment instance :param env_type: the type of environments created by `factory_fn` :param venv_type: the vector environment type to use for parallelization :param num_training_envs: the number of training environments to create :param num_test_envs: the number of test environments to create :param create_watch_env: whether to create an environment for watching the agent :return: the instance """ training_envs = venv_type.create_venv( [lambda: factory_fn(EnvMode.TRAINING)] * num_training_envs, ) test_envs = venv_type.create_venv( [lambda: factory_fn(EnvMode.TEST)] * num_test_envs, ) if create_watch_env: watch_env = VectorEnvType.DUMMY.create_venv([lambda: factory_fn(EnvMode.WATCH)]) else: watch_env = None env = factory_fn(EnvMode.TRAINING) match env_type: case EnvType.CONTINUOUS: return ContinuousEnvironments(env, training_envs, test_envs, watch_env) case EnvType.DISCRETE: return DiscreteEnvironments(env, training_envs, test_envs, watch_env) case _: raise ValueError(f"Environment type {env_type} not handled") def _tostring_includes(self) -> list[str]: return [] def _tostring_additional_entries(self) -> dict[str, Any]: return self.info() def info(self) -> dict[str, Any]: return { "action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape(), } def set_persistence(self, *p: Persistence) -> None: """Associates the given persistence handlers which may persist and restore environment-specific information. 
:param p: persistence handlers """ self.persistence = p @abstractmethod def get_action_shape(self) -> TActionShape: pass @abstractmethod def get_observation_shape(self) -> TObservationShape: pass def get_action_space(self) -> gym.Space: return self.env.action_space def get_observation_space(self) -> gym.Space: return self.env.observation_space @abstractmethod def get_type(self) -> EnvType: pass class ContinuousEnvironments(Environments): """Represents (vectorized) continuous environments.""" def __init__( self, env: gym.Env, training_envs: BaseVectorEnv, test_envs: BaseVectorEnv, watch_env: BaseVectorEnv | None = None, ): super().__init__(env, training_envs, test_envs, watch_env) self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) @staticmethod def from_factory( factory_fn: Callable[[EnvMode], gym.Env], venv_type: VectorEnvType, num_training_envs: int, num_test_envs: int, create_watch_env: bool = False, ) -> "ContinuousEnvironments": """Creates an instance from a factory function that creates a single instance. :param factory_fn: the factory for a single environment instance :param venv_type: the vector environment type to use for parallelization :param num_training_envs: the number of training environments to create :param num_test_envs: the number of test environments to create :param create_watch_env: whether to create an environment for watching the agent :return: the instance """ return cast( ContinuousEnvironments, Environments.from_factory_and_type( factory_fn, EnvType.CONTINUOUS, venv_type, num_training_envs, num_test_envs, create_watch_env, ), ) def info(self) -> dict[str, Any]: d = super().info() d["max_action"] = self.max_action return d @staticmethod def _get_continuous_env_info( env: gym.Env, ) -> tuple[tuple[int, ...], tuple[int, ...], float]: if not isinstance(env.action_space, gym.spaces.Box): raise ValueError( "Only environments with continuous action space are supported here. 
" f"But got env with action space: {env.action_space.__class__}.", ) state_shape = env.observation_space.shape or env.observation_space.n # type: ignore if not state_shape: raise ValueError("Observation space shape is not defined") action_shape = env.action_space.shape max_action = env.action_space.high[0] return state_shape, action_shape, max_action def get_action_shape(self) -> TActionShape: return self.action_shape def get_observation_shape(self) -> TObservationShape: return self.state_shape def get_type(self) -> EnvType: return EnvType.CONTINUOUS class DiscreteEnvironments(Environments): """Represents (vectorized) discrete environments.""" def __init__( self, env: gym.Env, training_envs: BaseVectorEnv, test_envs: BaseVectorEnv, watch_env: BaseVectorEnv | None = None, ): super().__init__(env, training_envs, test_envs, watch_env) self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore self.action_shape = env.action_space.shape or env.action_space.n # type: ignore @staticmethod def from_factory( factory_fn: Callable[[EnvMode], gym.Env], venv_type: VectorEnvType, num_training_envs: int, num_test_envs: int, create_watch_env: bool = False, ) -> "DiscreteEnvironments": """Creates an instance from a factory function that creates a single instance. 
:param factory_fn: the factory for a single environment instance :param venv_type: the vector environment type to use for parallelization :param num_training_envs: the number of training environments to create :param num_test_envs: the number of test environments to create :param create_watch_env: whether to create an environment for watching the agent :return: the instance """ return cast( DiscreteEnvironments, Environments.from_factory_and_type( factory_fn, EnvType.DISCRETE, venv_type, num_training_envs, num_test_envs, create_watch_env, ), ) def get_action_shape(self) -> TActionShape: return self.action_shape def get_observation_shape(self) -> TObservationShape: return self.observation_shape def get_type(self) -> EnvType: return EnvType.DISCRETE class EnvPoolFactory: """A factory for the creation of envpool-based vectorized environments, which can be used in conjunction with :class:`EnvFactoryRegistered`. """ def _transform_task(self, task: str) -> str: return task def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict: """Transforms gymnasium keyword arguments to be envpool-compatible. :param kwargs: keyword arguments that would normally be passed to `gymnasium.make`. :param mode: the environment mode :return: the transformed keyword arguments """ kwargs = dict(kwargs) if "render_mode" in kwargs: del kwargs["render_mode"] return kwargs def create_venv( self, task: str, num_envs: int, mode: EnvMode, seed: int, kwargs: dict, ) -> BaseVectorEnv: import envpool envpool_task = self._transform_task(task) envpool_kwargs = self._transform_kwargs(kwargs, mode) return envpool.make_gymnasium( envpool_task, num_envs=num_envs, seed=seed, **envpool_kwargs, ) class EnvFactory(ToStringMixin, ABC): def __init__(self, venv_type: VectorEnvType): """Main interface for the creation of environments (in various forms). :param venv_type: the type of vectorized environment to use for train and test environments. 
`WATCH` environments are always created as `DUMMY` vector environments. """ self.venv_type = venv_type @staticmethod def _create_rng(seed: int | None) -> np.random.Generator: """ Creates a random number generator with the given seed. :param seed: the seed to use; if None, a random seed will be used :return: the random number generator """ return np.random.default_rng(seed=seed) @staticmethod def _next_seed(rng: np.random.Generator) -> int: """ Samples a random seed from the given random number generator. :param rng: the random number generator :return: the sampled random seed """ # int32 is needed for envpool compatibility return int(rng.integers(0, 2**31, dtype=np.int32)) @abstractmethod def _create_env(self, mode: EnvMode) -> Env: """Creates a single environment for the given mode. :param mode: the mode :return: an environment """ def create_env(self, mode: EnvMode, seed: int | None = None) -> Env: """ Creates a single environment for the given mode. :param mode: the mode :param seed: the random seed to use for the environment; if None, the seed will not be specified, and gymnasium will use a random seed. :return: the environment """ env = self._create_env(mode) # initialize the environment with the given seed (if any) if seed is not None: rng = self._create_rng(seed) env.np_random = rng # also set the seed member within the environment such that it can be retrieved # (gymnasium's random seed handling is, unfortunately, broken) if hasattr(env, "_np_random_seed"): env._np_random_seed = seed return env def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv: """Create vectorized environments. :param num_envs: the number of environments :param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env. 
:return: the vectorized environments """ rng = self._create_rng(seed) def create_factory_fn() -> Callable[[], Env]: # create a factory function that uses a sampled random seed return lambda random_seed=self._next_seed(rng): self.create_env(mode, seed=random_seed) # type: ignore # create the vectorized environment, seeded appropriately if mode == EnvMode.WATCH: venv = VectorEnvType.DUMMY.create_venv([create_factory_fn()]) else: venv = self.venv_type.create_venv([create_factory_fn() for _ in range(num_envs)]) # seed the action samplers venv.seed([self._next_seed(rng) for _ in range(num_envs)]) return venv def create_envs( self, num_training_envs: int, num_test_envs: int, create_watch_env: bool = False, seed: int | None = None, ) -> Environments: """Create environments for learning. :param num_training_envs: the number of training environments :param num_test_envs: the number of test environments :param create_watch_env: whether to create an environment for watching the agent :param seed: the random seed to use for environment creation :return: the environments """ rng = self._create_rng(seed) env = self.create_env(EnvMode.TRAINING) training_envs = self.create_venv( num_training_envs, EnvMode.TRAINING, seed=self._next_seed(rng) ) test_envs = self.create_venv(num_test_envs, EnvMode.TEST, seed=self._next_seed(rng)) watch_env = ( self.create_venv(1, EnvMode.WATCH, seed=self._next_seed(rng)) if create_watch_env else None ) match EnvType.from_env(env): case EnvType.DISCRETE: return DiscreteEnvironments(env, training_envs, test_envs, watch_env) case EnvType.CONTINUOUS: return ContinuousEnvironments(env, training_envs, test_envs, watch_env) case _: raise ValueError class EnvFactoryRegistered(EnvFactory): """Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make` (or via `envpool.make_gymnasium`). 
""" def __init__( self, *, task: str, venv_type: VectorEnvType, envpool_factory: EnvPoolFactory | None = None, render_mode_training: str | None = None, render_mode_test: str | None = None, render_mode_watch: str = "human", **make_kwargs: Any, ): """:param task: the gymnasium task/environment identifier :param seed: the random seed :param venv_type: the type of vectorized environment to use (if `envpool_factory` is not specified) :param envpool_factory: the factory to use for vectorized environment creation based on envpool; envpool must be installed. :param render_mode_training: the render mode to use for training environments :param render_mode_test: the render mode to use for test environments :param render_mode_watch: the render mode to use for environments that are used to watch agent performance :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`. If envpool is used, the gymnasium parameters will be appropriately translated for use with `envpool.make_gymnasium`. """ super().__init__(venv_type) self.task = task self.envpool_factory = envpool_factory self.render_modes = { EnvMode.TRAINING: render_mode_training, EnvMode.TEST: render_mode_test, EnvMode.WATCH: render_mode_watch, } self.make_kwargs = make_kwargs def __setstate__(self, state: dict) -> None: if "seed" in state: if "test_seed" in state or "training_seed" in state: raise RuntimeError( f"Cannot have both 'seed' and 'test_seed'/'training_seed' in state. " f"Something went wrong during serialization/deserialization: " f"{state=}", ) state["test_seed"] = state["seed"] state["training_seed"] = state["seed"] del state["seed"] if "train_seed" in state: state["training_seed"] = state["train_seed"] del state["train_seed"] setstate(EnvFactoryRegistered, self, state) def _create_kwargs(self, mode: EnvMode) -> dict: """Adapts the keyword arguments for the given mode. 
:param mode: the mode :return: adapted keyword arguments """ kwargs = dict(self.make_kwargs) kwargs["render_mode"] = self.render_modes.get(mode) return kwargs def _create_env(self, mode: EnvMode) -> Env: """Creates a single environment for the given mode. :param mode: the mode :return: an environment """ kwargs = self._create_kwargs(mode) return gymnasium.make(self.task, **kwargs) def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv: if self.envpool_factory is not None: rng = self._create_rng(seed) return self.envpool_factory.create_venv( self.task, num_envs, mode, self._next_seed(rng), self._create_kwargs(mode), ) else: return super().create_venv(num_envs, mode, seed=seed) ================================================ FILE: tianshou/highlevel/experiment.py ================================================ """The experiment module provides high-level interfaces for setting up and running reinforcement learning experiments. The main entry points are: * :class:`ExperimentConfig`: a dataclass for configuring the experiment. The configuration is different from RL specific configuration (such as policy and trainer parameters) and only pertains to configuration that is common to all experiments. * :class:`Experiment`: represents a reinforcement learning experiment. It is composed of configuration and factory objects, is lightweight and serializable. An instance of `Experiment` is usually saved as a pickle file after an experiment is executed. * :class:`ExperimentBuilder`: a helper class for creating experiments. It contains a lot of defaults and allows for easy customization of the experiment setup. * :class:`ExperimentCollection`: a shallow wrapper around a list of experiments providing a simple interface for running them with a launcher. Useful for running multiple experiments in parallel, in particular, for the important case of running experiments that only differ in their random seeds. 
Various implementations of the `ExperimentBuilder` are provided for each of the algorithms supported by Tianshou.
"""

import os
import pickle
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import suppress
from copy import deepcopy
from dataclasses import asdict, dataclass
from pprint import pformat
from typing import TYPE_CHECKING, Generic, Self

import numpy as np
import torch
from sensai.util import logging
from sensai.util.logging import datetime_tag
from sensai.util.string import ToStringMixin

from tianshou.algorithm import Algorithm
from tianshou.data import BaseCollector, Collector, CollectStats, InfoStats
from tianshou.env import BaseVectorEnv
from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher
from tianshou.evaluation.rliable_evaluation import load_and_eval_experiment
from tianshou.highlevel.algorithm import (
    A2CAlgorithmFactory,
    AlgorithmFactory,
    DDPGAlgorithmFactory,
    DiscreteSACAlgorithmFactory,
    DQNAlgorithmFactory,
    IQNAlgorithmFactory,
    NPGAlgorithmFactory,
    PPOAlgorithmFactory,
    REDQAlgorithmFactory,
    ReinforceAlgorithmFactory,
    SACAlgorithmFactory,
    TD3AlgorithmFactory,
    TRPOAlgorithmFactory,
    TTrainingConfig,
)
from tianshou.highlevel.config import (
    OffPolicyTrainingConfig,
    OnPolicyTrainingConfig,
    TrainingConfig,
)
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
from tianshou.highlevel.module.actor import (
    ActorFactory,
    ActorFactoryDefault,
    ActorFactoryTransientStorageDecorator,
    ActorFuture,
    ActorFutureProviderProtocol,
    ContinuousActorType,
    IntermediateModuleFactoryFromActorFactory,
)
from tianshou.highlevel.module.core import (
    TDevice,
)
from tianshou.highlevel.module.critic import (
    CriticEnsembleFactory,
    CriticEnsembleFactoryDefault,
    CriticFactory,
    CriticFactoryDefault,
    CriticFactoryReuseActor,
)
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory
from tianshou.highlevel.params.algorithm_params import (
    A2CParams,
    DDPGParams,
    DiscreteSACParams,
    DQNParams,
    IQNParams,
    NPGParams,
    PPOParams,
    REDQParams,
    ReinforceParams,
    SACParams,
    TD3Params,
    TRPOParams,
)
from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory
from tianshou.highlevel.params.collector import CollectorFactory
from tianshou.highlevel.params.optim import (
    OptimizerFactoryFactory,
    OptimizerFactoryFactoryAdam,
)
from tianshou.highlevel.persistence import (
    PersistenceGroup,
    PolicyPersistence,
)
from tianshou.highlevel.trainer import (
    EpochStopCallback,
    EpochTestCallback,
    EpochTrainCallback,
    TrainerCallbacks,
)
from tianshou.highlevel.world import World
from tianshou.utils import LazyLogger
from tianshou.utils.net.common import ModuleType

if TYPE_CHECKING:
    # NOTE(review): `ExpLauncher` and `RegisteredExpLauncher` are already imported
    # unconditionally above, so this type-checking-only import appears redundant.
    from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher

# `logging` here is sensai's logging module (`from sensai.util import logging` above), not the stdlib one
log = logging.getLogger(__name__)


@dataclass
class ExperimentConfig:
    """Generic config for setting up the experiment, not RL or training specific."""

    seed: int = 42
    """The random seed with which to initialize random number generators."""
    # the default device is determined once, when this class is defined (i.e. at import time)
    device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
    """The torch device to use"""
    policy_restore_directory: str | None = None
    """Directory from which to load the policy neural network parameters (persistence directory of a previous run)"""
    train: bool = True
    """Whether to perform training"""
    watch: bool = True
    """Whether to watch agent performance (after training)"""
    watch_num_episodes: int = 10
    """Number of episodes for which to watch performance (if `watch` is enabled)"""
    watch_render: float = 0.0
    """Milliseconds between rendered frames when watching agent performance (if `watch` is enabled)"""
    persistence_base_dir: str = "log"
    """Base directory in which experiment data is to be stored. Every experiment run will create a subdirectory
    in this directory based on the run's experiment name"""
    persistence_enabled: bool = True
    """Whether persistence is enabled, allowing files to be stored"""
    log_file_enabled: bool = True
    """Whether to write to a log file; has no effect if `persistence_enabled` is False.
    Disable this if you have externally configured log file generation."""
    policy_persistence_mode: PolicyPersistence.Mode = PolicyPersistence.Mode.POLICY
    """Controls the way in which the policy is persisted"""


@dataclass
class ExperimentResult:
    """Contains the results of an experiment."""

    world: World
    """The `World` contains all the essential instances of the experiment.
    Can also be created via `Experiment.create_experiment_world` for more custom setups, see docstring there.

    Note: it is typically not serializable, so it is not stored in the experiment pickle, and shouldn't be
    sent across processes, meaning also that `ExperimentResult` itself is typically not serializable.
    """
    trainer_result: InfoStats | None
    """dataclass of results as returned by the trainer (if any)"""


class Experiment(ToStringMixin):
    """Represents a reinforcement learning experiment.

    An experiment is composed only of configuration and factory objects, which themselves
    should be designed to contain only configuration. Therefore, experiments can easily
    be stored/pickled and later restored without any problems.

    The main entry points are:

    1. :meth:`run`: runs the experiment and returns the results
    2. :meth:`create_experiment_world`: creates the world object for the experiment, which contains all relevant instances.
        Useful for setting up the experiment and running it in a more custom way.

    The methods :meth:`save` and :meth:`from_directory` can be used to store and restore experiments.
    """

    # name of the log file (presumably created in the run's persistence directory; defined outside this view)
    LOG_FILENAME = "log.txt"
    # name of the pickle file written by `save` and read by `from_directory`
    EXPERIMENT_PICKLE_FILENAME = "experiment.pkl"

    def __init__(
        self,
        config: ExperimentConfig,
        env_factory: EnvFactory,
        algorithm_factory: AlgorithmFactory,
        training_config: TrainingConfig,
        name: str,
        logger_factory: LoggerFactory | None = None,
    ):
        """
        :param config: the generic (non-RL-specific) experiment configuration
        :param env_factory: the factory for the creation of environments
        :param algorithm_factory: the factory for the creation of the algorithm
        :param training_config: the training configuration
        :param name: the name of the experiment
        :param logger_factory: the factory for the creation of the logger; if None, a
            `LoggerFactoryDefault` instance is used
        """
        if logger_factory is None:
            logger_factory = LoggerFactoryDefault()
        self.config = config
        self.training_config = training_config
        self.env_factory = env_factory
        self.algorithm_factory = algorithm_factory
        self.logger_factory = logger_factory
        self.name = name

    @classmethod
    def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
        """Restores an experiment from a previously stored pickle.

        Note: this unpickles a file, which can execute arbitrary code; only restore experiments
        from trusted directories.

        :param directory: persistence directory of a previous run, in which a pickled experiment is found
        :param restore_policy: whether the experiment shall be configured to restore the policy that was
            persisted in the given directory
        :return: the restored experiment
        """
        with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
            experiment: Experiment = pickle.load(f)
        if restore_policy:
            # point the restored config at the directory we loaded from
            experiment.config.policy_restore_directory = directory
        return experiment

    @staticmethod
    def seeding_info_str_static(seed: int) -> str:
        """Static method variant of `get_seeding_info_as_str`, which can be used without an `Experiment` instance."""
        return f"exp_seed={seed}"

    def get_seeding_info_as_str(self) -> str:
        """Returns information on the seeds used in the experiment as a string.

        This can be useful for creating unique experiment names based on seeds.
        A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
        """
        return self.seeding_info_str_static(self.config.seed)

    def _set_seed(self) -> None:
        """Seeds NumPy's global random number generator and torch with `config.seed`."""
        # NOTE(review): Python's built-in `random` module is not seeded here
        seed = self.config.seed
        log.info(f"Setting random seed {seed}")
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
        """:return: a dict holding this experiment's pretty-printed string representation under the key "experiment" """
        return {"experiment": self.pprints()}

    def save(self, directory: str) -> None:
        """Pickles this experiment into the given directory, such that it can be restored via
        :meth:`from_directory`.

        :param directory: the directory in which to store the pickle (file `EXPERIMENT_PICKLE_FILENAME`)
        """
        path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
        log.info(
            f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
        )
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def persistence_dir_static(
        persistence_base_dir: str, experiment_name: str, seed: int | None = None
    ) -> str:
        """Static method for constructing the persistence directory for an experiment from the base persistence
        directory and the experiment name.

        Useful for contexts where one wants access to the persistence directory without having access to the
        corresponding `Experiment` instance.

        :param persistence_base_dir: base persistence directory
        :param experiment_name: name of the experiment
        :param seed: optional seed. Experiments are saved within a subdirectory named after the seed,
            but it is often sufficient to know the base directory without the seed subdirectory in user code
            (for example, for restoring logs or performing rliable evaluations)
        :return: the path `persistence_base_dir/experiment_name`, extended by the seed subdirectory if `seed` is given
        """
        result = os.path.join(persistence_base_dir, experiment_name)
        if seed is not None:
            result = os.path.join(result, Experiment.seeding_info_str_static(seed))
        return result

    def create_experiment_world(
        self,
        override_experiment_name: str | None = None,
        logger_run_id: str | None = None,
        raise_error_on_dirname_collision: bool = True,
        reset_collectors: bool = True,
    ) -> World:
        """Creates the world object for the experiment.

        The world object contains all relevant instances for the experiment,
        such as environments, policy, collectors, etc.
        This method is the main entrypoint for users who don't want to use `run` directly.
A common use case is that some configuration or custom logic should happen before the training loop starts, but one still wants to use the convenience of high-level interfaces for setting up the experiment. :param override_experiment_name: pass to override the experiment name in the resulting `World`. Affects the name of the persistence directory and logger configuration. If None, the experiment's name will be used. The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case a nested directory structure will be created. :param logger_run_id: Run identifier to use for logger initialization/resumption. :param raise_error_on_dirname_collision: whether to raise an error on collisions when creating the persistence directory. Only takes effect if persistence is enabled. Set to `False` e.g., when continuing a previously executed experiment with the same `persistence_base_dir` and name. :param reset_collectors: whether to reset the collectors before training starts. Setting to `False` can be useful when continuing training from a previous run with restored collectors, or for adding custom logic before training starts. 
""" if override_experiment_name is not None: exp_name = override_experiment_name else: exp_name = self.name # initialize persistence directory use_persistence = self.config.persistence_enabled persistence_dir = os.path.join( self.config.persistence_base_dir, exp_name, self.get_seeding_info_as_str() ) if use_persistence: os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) with logging.FileLoggerContext( os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence and self.config.log_file_enabled, ): # log initial information log.info(f"Preparing experiment world (name='{exp_name}'):\n{self.pprints()}") log.info(f"Working directory: {os.getcwd()}") self._set_seed() # create environments envs = self.env_factory.create_envs( self.training_config.num_training_envs, self.training_config.num_test_envs, create_watch_env=self.config.watch, seed=self.config.seed, ) log.info(f"Created {envs}") # initialize persistence additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence) policy_persistence = PolicyPersistence( additional_persistence, enabled=use_persistence, mode=self.config.policy_persistence_mode, ) if use_persistence: log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}") self.save(persistence_dir) # initialize logger full_config = self._build_config_dict() full_config.update(envs.info()) full_config["experiment_config"] = asdict(self.config) full_config["training_config_config"] = asdict(self.training_config) with suppress(AttributeError): full_config["policy_params"] = asdict(self.algorithm_factory.params) logger: TLogger if use_persistence: logger = self.logger_factory.create_logger( log_dir=persistence_dir, experiment_name=exp_name, run_id=logger_run_id, config_dict=full_config, ) else: logger = LazyLogger() # create policy and collectors log.info("Creating policy") policy = self.algorithm_factory.create_algorithm(envs, self.config.device) log.info("Creating collectors") 
training_collector: BaseCollector | None = None test_collector: BaseCollector | None = None if self.config.train: ( training_collector, test_collector, ) = self.algorithm_factory.create_train_test_collectors( policy, envs, reset_collectors=reset_collectors, ) # create context object with all relevant instances (except trainer; added later) world = World( envs=envs, algorithm=policy, training_collector=training_collector, test_collector=test_collector, logger=logger, persist_directory=persistence_dir, restore_directory=self.config.policy_restore_directory, ) # restore policy parameters if applicable if self.config.policy_restore_directory: policy_persistence.restore( policy, world, self.config.device, ) if self.config.train: trainer = self.algorithm_factory.create_trainer(world, policy_persistence) world.trainer = trainer return world def run( self, run_name: str | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, ) -> ExperimentResult: """Run the experiment and return the results. :param run_name: Defines a name for this run of the experiment, which determines the subdirectory (within the persistence base directory) where all results will be saved. If None, the experiment's name will be used. The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case a nested directory structure will be created. :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when using wandb, in particular). :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed experiment with the same name. 
:return: """ if run_name is None: run_name = self.name world = self.create_experiment_world( override_experiment_name=run_name, logger_run_id=logger_run_id, raise_error_on_dirname_collision=raise_error_on_dirname_collision, ) persistence_dir = world.persist_directory use_persistence = self.config.persistence_enabled with logging.FileLoggerContext( os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence and self.config.log_file_enabled, ): trainer_result: InfoStats | None = None if self.config.train: assert world.trainer is not None assert world.training_collector is not None assert world.test_collector is not None # prefilling buffers with either random or current agent's actions if self.training_config.start_timesteps > 0: log.info( f"Collecting {self.training_config.start_timesteps} initial environment " f"steps before training (random={self.training_config.start_timesteps_random})", ) world.training_collector.collect( n_step=self.training_config.start_timesteps, random=self.training_config.start_timesteps_random, ) log.info("Starting training") world.trainer.run() if use_persistence: world.logger.finalize() log.info(f"Training result:\n{pformat(trainer_result)}") # watch agent performance if self.config.watch: assert world.envs.watch_env is not None log.info("Watching agent performance") self._watch_agent( self.config.watch_num_episodes, world.algorithm, world.envs.watch_env, self.config.watch_render, ) return ExperimentResult(world=world, trainer_result=trainer_result) @staticmethod def _watch_agent( num_episodes: int, policy: Algorithm, env: BaseVectorEnv, render: float, ) -> None: collector = Collector[CollectStats](policy, env) collector.reset() result = collector.collect(n_episode=num_episodes, render=render) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy log.info( f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}", ) class 
ExperimentCollection:
    """Shallow wrapper around a list of experiments providing a simple interface for running them with a launcher."""

    def __init__(self, experiments: list[Experiment]):
        """:param experiments: the experiments to be run"""
        self.experiments = experiments

    def run(
        self,
        launcher: ExpLauncher | RegisteredExpLauncher | str = RegisteredExpLauncher.SEQUENTIAL,
    ) -> list[InfoStats | None]:
        """Runs all experiments of the collection via the given launcher.

        :param launcher: the launcher to use. A string is upper-cased and resolved to the
            :class:`RegisteredExpLauncher` member of that name; a :class:`RegisteredExpLauncher`
            is turned into a launcher via its `create_launcher` method.
        :return: the list returned by the launcher's `launch` method (presumably one entry per
            experiment, None where no training stats are available -- see the launcher implementation)
        """
        # resolve the launcher specification: str -> enum member -> launcher instance
        if isinstance(launcher, str):
            launcher = RegisteredExpLauncher[launcher.upper()]
        if isinstance(launcher, RegisteredExpLauncher):
            launcher = launcher.create_launcher()
        log.info(
            f"Running {len(self.experiments)} experiments using launcher {launcher.get_name()}"
        )
        return launcher.launch(experiments=self.experiments)


class ExperimentBuilder(ABC, Generic[TTrainingConfig]):
    """A helper class (following the builder pattern) for creating experiments.

    It contains a lot of defaults for the setup which can be adjusted using the
    various `with_` methods. For example, the default optimizer is Adam
    (:class:`OptimizerFactoryFactoryAdam`), but a different default optimizer factory
    can be configured using :meth:`with_optim_default`.
    """

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: TTrainingConfig | None = None,
    ):
        """:param env_factory: controls how environments are to be created.
        :param experiment_config: the configuration for the experiment. If None, will use the default values
            of `ExperimentConfig`.
        :param training_config: the training configuration to use. If None, use default values
            (not recommended), as created by :meth:`_create_training_config`.
        """
        if experiment_config is None:
            experiment_config = ExperimentConfig()
        if training_config is None:
            training_config = self._create_training_config()
        self._config = experiment_config
        self._env_factory = env_factory
        self._training_config = training_config
        self._logger_factory: LoggerFactory | None = None
        self._optim_factory: OptimizerFactoryFactory | None = None
        self._algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None
        self._collector_factory: CollectorFactory | None = None
        self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
        # default name: the class name without "Builder", followed by the result of datetime_tag()
        self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()

    @abstractmethod
    def _create_training_config(self) -> TTrainingConfig:
        """:return: the training configuration to use when none is passed to the constructor"""
        pass

    def copy(self) -> Self:
        """:return: a deep copy of this builder"""
        return deepcopy(self)

    @property
    def experiment_config(self) -> ExperimentConfig:
        """The experiment configuration that is passed on to the experiment(s) created by this builder."""
        return self._config

    @experiment_config.setter
    def experiment_config(self, experiment_config: ExperimentConfig) -> None:
        self._config = experiment_config

    @property
    def training_config(self) -> TrainingConfig:
        """The training configuration that is passed on to the experiment(s) created by this builder."""
        return self._training_config

    @training_config.setter
    def training_config(self, config: TrainingConfig) -> None:
        self._training_config = config

    def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
        """Allows to customize the logger factory to use.

        If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.

        :param logger_factory: the factory to use
        :return: the builder
        """
        self._logger_factory = logger_factory
        return self

    def with_algorithm_wrapper_factory(
        self, algorithm_wrapper_factory: AlgorithmWrapperFactory
    ) -> Self:
        """Allows to define a wrapper around the algorithm that is created, extending the original algorithm.

        :param algorithm_wrapper_factory: the factory for the wrapper
        :return: the builder
        """
        self._algorithm_wrapper_factory = algorithm_wrapper_factory
        return self

    def with_optim_default(self, optim_factory: OptimizerFactoryFactory) -> Self:
        """Allows to customize the default optimizer to use.

        The default optimizer applies when optimizer factory factories are set to None
        in algorithm parameter objects.

        By default, :class:`OptimizerFactoryFactoryAdam` will be used with default parameters.

        :param optim_factory: the optimizer factory
        :return: the builder
        """
        self._optim_factory = optim_factory
        return self

    def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
        """Allows to define a callback function which is called at the beginning of every epoch during training.

        :param callback: the callback
        :return: the builder
        """
        self._trainer_callbacks.epoch_train_callback = callback
        return self

    def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self:
        """Allows to define a callback function which is called at the beginning of testing in each epoch.

        :param callback: the callback
        :return: the builder
        """
        self._trainer_callbacks.epoch_test_callback = callback
        return self

    def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self:
        """Allows to define a callback that decides whether training shall stop early.

        The callback receives the undiscounted returns of the testing result.

        :param callback: the callback
        :return: the builder
        """
        self._trainer_callbacks.epoch_stop_callback = callback
        return self

    def with_name(
        self,
        name: str,
    ) -> Self:
        """Sets the name of the experiment.

        :param name: the name to use for this experiment, which, when the experiment is run,
            will determine the storage sub-folder by default
        :return: the builder
        """
        self._name = name
        return self

    def with_collector_factory(self, collector_factory: CollectorFactory) -> Self:
        """Allows customizing the collector factory to use.

        :param collector_factory: the factory to use for the creation of collectors
        :return: the builder
        """
        self._collector_factory = collector_factory
        return self

    @abstractmethod
    def _create_algorithm_factory(self) -> AlgorithmFactory:
        """:return: the algorithm factory for the experiment, configured from this builder's state"""
        pass

    def _get_optim_factory(self) -> OptimizerFactoryFactory:
        """:return: the configured default optimizer factory, or an Adam factory if none was configured"""
        if self._optim_factory is None:
            return OptimizerFactoryFactoryAdam()
        else:
            return self._optim_factory

    def build(self) -> Experiment:
        """Creates the experiment based on the options specified via this builder.

        :return: the experiment
        """
        algorithm_factory = self._create_algorithm_factory()
        algorithm_factory.set_trainer_callbacks(self._trainer_callbacks)
        # optional customizations are only applied if they were configured
        if self._algorithm_wrapper_factory:
            algorithm_factory.set_policy_wrapper_factory(self._algorithm_wrapper_factory)
        if self._collector_factory:
            algorithm_factory.set_collector_factory(self._collector_factory)
        experiment: Experiment = Experiment(
            config=self._config,
            env_factory=self._env_factory,
            algorithm_factory=algorithm_factory,
            training_config=self._training_config,
            name=self._name,
            logger_factory=self._logger_factory,
        )
        return experiment

    def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection:
        """Creates a collection of experiments with distinct, consecutive seeds.

        The i-th experiment (i = 0, ..., num_experiments - 1) uses the configured seed plus i.
        Useful for performing statistically meaningful evaluations of an algorithm's performance.
        The `rliable` recommendation is to use at least 5 experiments for computing quantities such as the
        interquartile mean and performance profiles. See the usage in example scripts
        like `examples/mujoco/mujoco_ppo_hl_multi.py`.

        All experiments in the collection share the builder's name; they differ only in their seeds, which
        determine the per-seed persistence subdirectory (see :meth:`Experiment.get_seeding_info_as_str`).

        :param num_experiments: the number of experiments to create
        :return: the collection of experiments
        """
        seeded_experiments = []
        for i in range(num_experiments):
            # work on a deep copy, so that this builder's own config (and seed) is not modified
            builder = self.copy()
            builder.experiment_config.seed += i
            experiment = builder.build()
            seeded_experiments.append(experiment)
        return ExperimentCollection(seeded_experiments)

    def build_and_run(
        self,
        num_experiments: int = 1,
        launcher: ExpLauncher | RegisteredExpLauncher | str = RegisteredExpLauncher.SEQUENTIAL,
        perform_rliable_analysis: bool = True,
    ) -> list[InfoStats | None]:
        """Build and run experiments.

        With multiple experiments, the seeds will be non-overlapping and the parallelism is controlled by the
        launcher.

        :param num_experiments: the number of experiments to create and run
        :param launcher: the launcher (or the corresponding enum value) to use for running the experiments
        :param perform_rliable_analysis: whether to perform rliable analysis on the results (only applicable
            if `num_experiments > 1`). This will show plots and store them in the persistence directory.
        :return: list of results, one per experiment
        """
        collection = self.build_seeded_collection(num_experiments)
        successful_experiment_stats = collection.run(launcher)
        num_successful_experiments = len(successful_experiment_stats)
        for i, info_stats in enumerate(successful_experiment_stats, start=1):
            if info_stats is not None:
                log.info(
                    f"Training stats for successful experiment {i}/{num_successful_experiments}:"
                )
                log.info(info_stats.pprints_asdict())
            else:
                log.info(
                    f"No training stats available for successful experiment {i}/{num_successful_experiments}.",
                )

        if perform_rliable_analysis and num_successful_experiments > 1:
            log.info(f"Performing rliable evaluation over {num_successful_experiments} experiments")
            # directory without the seed subdirectory, i.e. the parent of all seeded runs
            persistence_dir = Experiment.persistence_dir_static(
                self._config.persistence_base_dir, self._name
            )
            load_and_eval_experiment(persistence_dir)
        return successful_experiment_stats


class OnPolicyExperimentBuilder(ExperimentBuilder[OnPolicyTrainingConfig], ABC):
    """Base class for builders of experiments with on-policy training configurations."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OnPolicyTrainingConfig | None = None,
    ):
        """
        :param env_factory: controls how environments are to be created.
        :param experiment_config: the configuration for the experiment. If None, will use the default values
            of :class:`ExperimentConfig`.
        :param training_config: the training configuration to use. If None, use default values (not recommended).
        """
        super().__init__(env_factory, experiment_config, training_config)

    def _create_training_config(self) -> OnPolicyTrainingConfig:
        return OnPolicyTrainingConfig()


class OffPolicyExperimentBuilder(ExperimentBuilder[OffPolicyTrainingConfig], ABC):
    """Base class for builders of experiments with off-policy training configurations."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OffPolicyTrainingConfig | None = None,
    ):
        """
        :param env_factory: controls how environments are to be created.
        :param experiment_config: the configuration for the experiment. If None, will use the default values
            of :class:`ExperimentConfig`.
        :param training_config: the training configuration to use. If None, use default values (not recommended).
        """
        super().__init__(env_factory, experiment_config, training_config)

    def _create_training_config(self) -> OffPolicyTrainingConfig:
        return OffPolicyTrainingConfig()


class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
    """Builder mixin which manages the actor factory and the future holding the created actor."""

    def __init__(self, continuous_actor_type: ContinuousActorType):
        self._continuous_actor_type = continuous_actor_type
        self._actor_future = ActorFuture()
        self._actor_factory: ActorFactory | None = None

    def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
        """Allows customizing the actor component via the specification of a factory.

        If this function is not called, a default actor factory (with default parameters) will be used.

        :param actor_factory: the factory to use for the creation of the actor network
        :return: the builder
        """
        self._actor_factory = actor_factory
        return self

    def _with_actor_factory_default(
        self,
        hidden_sizes: Sequence[int],
        hidden_activation: ModuleType = torch.nn.ReLU,
        continuous_unbounded: bool = False,
        continuous_conditioned_sigma: bool = False,
    ) -> Self:
        """Adds a default actor factory with the given parameters.

        :param hidden_sizes: the sequence of hidden dimensions to use in the network structure
        :param hidden_activation: the activation function to use for hidden layers
        :param continuous_unbounded: whether, for continuous action spaces, the actor's final outputs are
            left unbounded (presumably, True means that no tanh squashing is applied to the final logits;
            TODO confirm against :class:`ActorFactoryDefault`)
        :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of
            continuous actions (sigma) shall be computed from the input; if False, sigma is an independent
            parameter.
        :return: the builder
        """
        self._actor_factory = ActorFactoryDefault(
            self._continuous_actor_type,
            hidden_sizes,
            hidden_activation=hidden_activation,
            continuous_unbounded=continuous_unbounded,
            continuous_conditioned_sigma=continuous_conditioned_sigma,
        )
        return self

    def get_actor_future(self) -> ActorFuture:
        """:return: an object, which, in the future, will contain the actor instance that is created for the experiment."""
        return self._actor_future

    def _get_actor_factory(self) -> ActorFactory:
        """:return: the configured (or default) actor factory, decorated such that the actor it creates is
            stored in this builder's actor future
        """
        actor_factory: ActorFactory
        if self._actor_factory is None:
            actor_factory = ActorFactoryDefault(self._continuous_actor_type)
        else:
            actor_factory = self._actor_factory
        return ActorFactoryTransientStorageDecorator(actor_factory, self._actor_future)


class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where, in the continuous case, the actor component outputs Gaussian distribution parameters."""

    def __init__(self) -> None:
        super().__init__(ContinuousActorType.GAUSSIAN)

    def with_actor_factory_default(
        self,
        hidden_sizes: Sequence[int],
        hidden_activation: ModuleType = torch.nn.ReLU,
        continuous_unbounded: bool = False,
        continuous_conditioned_sigma: bool = False,
    ) -> Self:
        """Defines use of the default actor factory, allowing its parameters to be customized.

        The default actor factory uses an MLP-style architecture.

        :param hidden_sizes: dimensions of hidden layers used by the network
        :param hidden_activation: the activation function to use for hidden layers
        :param continuous_unbounded: whether, for continuous action spaces, the actor's final outputs are
            left unbounded (presumably, True means that no tanh squashing is applied to the final logits;
            TODO confirm against :class:`ActorFactoryDefault`)
        :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of
            continuous actions (sigma) shall be computed from the input; if False, sigma is an independent
            parameter.
        :return: the builder
        """
        return super()._with_actor_factory_default(
            hidden_sizes,
            hidden_activation=hidden_activation,
            continuous_unbounded=continuous_unbounded,
            continuous_conditioned_sigma=continuous_conditioned_sigma,
        )


class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

    def __init__(self) -> None:
        super().__init__(ContinuousActorType.DETERMINISTIC)

    def with_actor_factory_default(
        self,
        hidden_sizes: Sequence[int],
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Defines use of the default actor factory, allowing its parameters to be customized.

        The default actor factory uses an MLP-style architecture.

        :param hidden_sizes: dimensions of hidden layers used by the network
        :param hidden_activation: the activation function to use for hidden layers
        :return: the builder
        """
        return super()._with_actor_factory_default(hidden_sizes, hidden_activation)


class _BuilderMixinActorFactory_DiscreteOnly(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where only environments with discrete action spaces are supported."""

    def __init__(self) -> None:
        super().__init__(ContinuousActorType.UNSUPPORTED)

    def with_actor_factory_default(
        self,
        hidden_sizes: Sequence[int],
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Defines use of the default actor factory, allowing its parameters to be customized.

        The default actor factory uses an MLP-style architecture.

        :param hidden_sizes: dimensions of hidden layers used by the network
        :param hidden_activation: the activation function to use for hidden layers
        :return: the builder
        """
        return super()._with_actor_factory_default(hidden_sizes, hidden_activation)


class _BuilderMixinCriticsFactory:
    """Builder mixin which manages a fixed number of critic factories, addressed by index."""

    def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol):
        self._actor_future_provider = actor_future_provider
        # None entries are replaced by a default factory in _get_critic_factory
        self._critic_factories: list[CriticFactory | None] = [None] * num_critics

    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
        """Sets the factory for the critic with the given index."""
        self._critic_factories[idx] = critic_factory
        return self

    def _with_critic_factory_default(
        self,
        idx: int,
        hidden_sizes: Sequence[int],
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Sets a default critic factory with the given parameters for the critic with the given index."""
        self._critic_factories[idx] = CriticFactoryDefault(
            hidden_sizes,
            hidden_activation=hidden_activation,
        )
        return self

    def _with_critic_factory_use_actor(self, idx: int) -> Self:
        """Makes the critic with the given index reuse the actor (obtained via the actor future)."""
        self._critic_factories[idx] = CriticFactoryReuseActor(
            self._actor_future_provider.get_actor_future(),
        )
        return self

    def _get_critic_factory(self, idx: int) -> CriticFactory:
        """:return: the factory for the critic with the given index (a default factory if none was set)"""
        factory = self._critic_factories[idx]
        if factory is None:
            return CriticFactoryDefault()
        else:
            return factory


class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
    """Builder mixin for algorithms with a single critic."""

    def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
        super().__init__(1, actor_future_provider)

    def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
        """Specifies that the given factory shall be used for the critic.

        :param critic_factory: the critic factory
        :return: the builder
        """
        self._with_critic_factory(0, critic_factory)
        return self

    def with_critic_factory_default(
        self,
        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Makes the critic use the default, MLP-style architecture with the given parameters.

        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
        :param hidden_activation: the activation function to use for hidden layers
        :return: the builder
        """
        self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
        return self


class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
    """Builder mixin for algorithms with a single critic, which may share parameters with the actor."""

    def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
        super().__init__(actor_future_provider)

    def with_critic_factory_use_actor(self) -> Self:
        """Makes the critic reuse the actor's preprocessing network (parameter sharing)."""
        return self._with_critic_factory_use_actor(0)


class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
    """Builder mixin for algorithms with two critics."""

    def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
        super().__init__(2, actor_future_provider)

    def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
        """Specifies that the given factory shall be used for both critics.

        :param critic_factory: the critic factory
        :return: the builder
        """
        for i in range(len(self._critic_factories)):
            self._with_critic_factory(i, critic_factory)
        return self

    def with_common_critic_factory_default(
        self,
        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Makes both critics use the default, MLP-style architecture with the given parameters.

        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
        :param hidden_activation: the activation function to use for hidden layers
        :return: the builder
        """
        for i in range(len(self._critic_factories)):
            self._with_critic_factory_default(i, hidden_sizes, hidden_activation)
        return self

    def with_common_critic_factory_use_actor(self) -> Self:
        """Makes both critics reuse the actor's preprocessing network (parameter sharing)."""
        for i in range(len(self._critic_factories)):
            self._with_critic_factory_use_actor(i)
        return self

    def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
        """Specifies that the given factory shall be used for the first critic.

        :param critic_factory: the critic factory
        :return: the builder
        """
        self._with_critic_factory(0, critic_factory)
        return self

    def with_critic1_factory_default(
        self,
        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Makes the first critic use the default, MLP-style architecture with the given parameters.

        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
        :param hidden_activation: the activation function to use for hidden layers
        :return: the builder
        """
        self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
        return self

    def with_critic1_factory_use_actor(self) -> Self:
        """Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
        return self._with_critic_factory_use_actor(0)

    def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
        """Specifies that the given factory shall be used for the second critic.

        :param critic_factory: the critic factory
        :return: the builder
        """
        self._with_critic_factory(1, critic_factory)
        return self

    def with_critic2_factory_default(
        self,
        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Makes the second critic use the default, MLP-style architecture with the given parameters.

        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
        :param hidden_activation: the activation function to use for hidden layers
        :return: the builder
        """
        self._with_critic_factory_default(1, hidden_sizes, hidden_activation)
        return self

    def with_critic2_factory_use_actor(self) -> Self:
        """Makes the second critic reuse the actor's preprocessing network (parameter sharing)."""
        return self._with_critic_factory_use_actor(1)


class _BuilderMixinCriticEnsembleFactory:
    """Builder mixin which manages the factory for a critic ensemble."""

    def __init__(self) -> None:
        self.critic_ensemble_factory: CriticEnsembleFactory | None = None

    def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
        """Specifies that the given factory shall be used for the critic ensemble.

        If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.

        :param factory: the critic ensemble factory
        :return: the builder
        """
        self.critic_ensemble_factory = factory
        return self

    def with_critic_ensemble_factory_default(
        self,
        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
    ) -> Self:
        """Allows to customize the parameters of the default critic ensemble factory.

        :param hidden_sizes: the sequence of sizes of hidden layers in the network architecture
        :return: the builder
        """
        self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
        return self

    def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory:
        """:return: the configured critic ensemble factory, or a default one if none was configured"""
        if self.critic_ensemble_factory is None:
            return CriticEnsembleFactoryDefault()
        else:
            return self.critic_ensemble_factory


class ReinforceExperimentBuilder(
    OnPolicyExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
):
    """Builder for experiments using :class:`ReinforceAlgorithmFactory`."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OnPolicyTrainingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, training_config)
        # ExperimentBuilder.__init__ does not call super().__init__, so the mixin is initialized explicitly
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        self._params: ReinforceParams = ReinforceParams()
        # NOTE(review): not read anywhere in the visible code; possibly a leftover
        self._env_config = None

    def with_reinforce_params(self, params: ReinforceParams) -> Self:
        """:param params: the algorithm parameters to use
        :return: the builder
        """
        self._params = params
        return self

    def _create_algorithm_factory(self) -> AlgorithmFactory:
        return ReinforceAlgorithmFactory(
            self._params,
            self._training_config,
            self._get_actor_factory(),
            self._get_optim_factory(),
        )


class A2CExperimentBuilder(
    OnPolicyExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinSingleCriticCanUseActorFactory,
):
    """Builder for experiments using :class:`A2CAlgorithmFactory`."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OnPolicyTrainingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, training_config)
        # ExperimentBuilder.__init__ does not call super().__init__, so the mixins are initialized explicitly;
        # the builder itself serves as the actor future provider for the critic mixin
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: A2CParams = A2CParams()
        # NOTE(review): not read anywhere in the visible code; possibly a leftover
        self._env_config = None

    def with_a2c_params(self, params: A2CParams) -> Self:
        """:param params: the algorithm parameters to use
        :return: the builder
        """
        self._params = params
        return self

    def _create_algorithm_factory(self) -> AlgorithmFactory:
        return A2CAlgorithmFactory(
            self._params,
            self._training_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
        )


class PPOExperimentBuilder(
    OnPolicyExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinSingleCriticCanUseActorFactory,
):
    """Builder for experiments using :class:`PPOAlgorithmFactory`."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OnPolicyTrainingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, training_config)
        # mixins initialized explicitly; the builder itself is the actor future provider for the critic mixin
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: PPOParams = PPOParams()

    def with_ppo_params(self, params: PPOParams) -> Self:
        """:param params: the algorithm parameters to use
        :return: the builder
        """
        self._params = params
        return self

    def _create_algorithm_factory(self) -> AlgorithmFactory:
        return PPOAlgorithmFactory(
            self._params,
            self._training_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
        )


class NPGExperimentBuilder(
    OnPolicyExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinSingleCriticCanUseActorFactory,
):
    """Builder for experiments using :class:`NPGAlgorithmFactory`."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OnPolicyTrainingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, training_config)
        # mixins initialized explicitly; the builder itself is the actor future provider for the critic mixin
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: NPGParams = NPGParams()

    def with_npg_params(self, params: NPGParams) -> Self:
        """:param params: the algorithm parameters to use
        :return: the builder
        """
        self._params = params
        return self

    def _create_algorithm_factory(self) -> AlgorithmFactory:
        return NPGAlgorithmFactory(
            self._params,
            self._training_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
        )


class TRPOExperimentBuilder(
    OnPolicyExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinSingleCriticCanUseActorFactory,
):
    """Builder for experiments using :class:`TRPOAlgorithmFactory`."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OnPolicyTrainingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, training_config)
        # mixins initialized explicitly; the builder itself is the actor future provider for the critic mixin
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: TRPOParams = TRPOParams()

    def with_trpo_params(self, params: TRPOParams) -> Self:
        """:param params: the algorithm parameters to use
        :return: the builder
        """
        self._params = params
        return self

    def _create_algorithm_factory(self) -> AlgorithmFactory:
        return TRPOAlgorithmFactory(
            self._params,
            self._training_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
        )


class DQNExperimentBuilder(
    OffPolicyExperimentBuilder,
):
    """Builder for experiments using :class:`DQNAlgorithmFactory`."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OffPolicyTrainingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, training_config)
        self._params: DQNParams = DQNParams()
        # default Q-model: default actor network without softmax, so that raw outputs serve as Q-values
        self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory(
            ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False),
        )

    def with_dqn_params(self, params: DQNParams) -> Self:
        """:param params: the algorithm parameters to use
        :return: the builder
        """
        self._params = params
        return self

    def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self:
        """Allows customizing the factory for the model of the Q function.

        :param module_factory: factory for a module which maps environment observations to a vector of Q-values (one for each action)
        :return: the builder
        """
        self._model_factory = module_factory
        return self

    def with_model_factory_default(
        self,
        hidden_sizes: Sequence[int],
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Allows to configure the default factory for the model of the Q function, which maps environment observations to a vector of
        Q-values (one for each action). The default model is a multi-layer perceptron.

        :param hidden_sizes: the sequence of dimensions used for hidden layers
        :param hidden_activation: the activation function to use for hidden layers (not used for the output layer)
        :return: the builder
        """
        self._model_factory = IntermediateModuleFactoryFromActorFactory(
            ActorFactoryDefault(
                ContinuousActorType.UNSUPPORTED,
                hidden_sizes=hidden_sizes,
                hidden_activation=hidden_activation,
                discrete_softmax=False,
            ),
        )
        return self

    def _create_algorithm_factory(self) -> AlgorithmFactory:
        return DQNAlgorithmFactory(
            self._params,
            self._training_config,
            self._model_factory,
            self._get_optim_factory(),
        )


class IQNExperimentBuilder(OffPolicyExperimentBuilder):
    """Builder for experiments using :class:`IQNAlgorithmFactory`."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config: OffPolicyTrainingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, training_config)
        self._params: IQNParams = IQNParams()
        self._preprocess_network_factory: IntermediateModuleFactory = (
            IntermediateModuleFactoryFromActorFactory(
                ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False),
            )
        )

    def with_iqn_params(self, params: IQNParams) -> Self:
        """:param params: the algorithm parameters to use
        :return: the builder
        """
        self._params = params
        return self

    def with_preprocess_network_factory(self, module_factory: IntermediateModuleFactory) -> Self:
        """:param module_factory: the factory for the preprocessing network, which is passed on to
            :class:`ImplicitQuantileNetworkFactory`
        :return: the builder
        """
        self._preprocess_network_factory = module_factory
        return self

    def _create_algorithm_factory(self) -> AlgorithmFactory:
        # the quantile network's own hidden sizes and number of cosines are taken from the IQN parameters
        model_factory = ImplicitQuantileNetworkFactory(
            self._preprocess_network_factory,
            hidden_sizes=self._params.hidden_sizes,
            num_cosines=self._params.num_cosines,
        )
        return IQNAlgorithmFactory(
            self._params,
            self._training_config,
            model_factory,
            self._get_optim_factory(),
        )


class DDPGExperimentBuilder(
    OffPolicyExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousDeterministic,
    _BuilderMixinSingleCriticCanUseActorFactory,
):
    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        training_config:
OffPolicyTrainingConfig | None = None, ): super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: DDPGParams = DDPGParams() def with_ddpg_params(self, params: DDPGParams) -> Self: self._params = params return self def _create_algorithm_factory(self) -> AlgorithmFactory: return DDPGAlgorithmFactory( self._params, self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), ) class REDQExperimentBuilder( OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinCriticEnsembleFactory, ): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, training_config: OffPolicyTrainingConfig | None = None, ): super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinCriticEnsembleFactory.__init__(self) self._params: REDQParams = REDQParams() def with_redq_params(self, params: REDQParams) -> Self: self._params = params return self def _create_algorithm_factory(self) -> AlgorithmFactory: return REDQAlgorithmFactory( self._params, self._training_config, self._get_actor_factory(), self._get_critic_ensemble_factory(), self._get_optim_factory(), ) class SACExperimentBuilder( OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinDualCriticFactory, ): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, training_config: OffPolicyTrainingConfig | None = None, ): super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: SACParams = SACParams() def with_sac_params(self, params: SACParams) -> Self: self._params = params return self def 
_create_algorithm_factory(self) -> AlgorithmFactory: return SACAlgorithmFactory( self._params, self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory(), ) class DiscreteSACExperimentBuilder( OffPolicyExperimentBuilder, _BuilderMixinActorFactory_DiscreteOnly, _BuilderMixinDualCriticFactory, ): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, training_config: OffPolicyTrainingConfig | None = None, ): super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_DiscreteOnly.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: DiscreteSACParams = DiscreteSACParams() def with_sac_params(self, params: DiscreteSACParams) -> Self: self._params = params return self def _create_algorithm_factory(self) -> AlgorithmFactory: return DiscreteSACAlgorithmFactory( self._params, self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory(), ) class TD3ExperimentBuilder( OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinDualCriticFactory, ): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, training_config: OffPolicyTrainingConfig | None = None, ): super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: TD3Params = TD3Params() def with_td3_params(self, params: TD3Params) -> Self: self._params = params return self def _create_algorithm_factory(self) -> AlgorithmFactory: return TD3AlgorithmFactory( self._params, self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory(), ) ================================================ FILE: 
tianshou/highlevel/logger.py ================================================ import os from abc import ABC, abstractmethod from typing import Literal, TypeAlias from sensai.util.string import ToStringMixin from torch.utils.tensorboard import SummaryWriter from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger TLogger: TypeAlias = BaseLogger class LoggerFactory(ToStringMixin, ABC): """Base class for factories which create the logger used by an experiment.""" @abstractmethod def create_logger( self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict | None = None, ) -> TLogger: """Creates the logger. :param log_dir: path to the directory in which log data is to be stored :param experiment_name: the name of the job, which may contain `os.path.sep` :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger :param config_dict: a dictionary with data that is to be logged :return: the logger """ @abstractmethod def get_logger_class(self) -> type[TLogger]: """Returns the class of the logger that is to be created.""" class LoggerFactoryDefault(LoggerFactory): """Default logger factory, which creates either a tensorboard logger or a Weights & Biases (wandb) logger; the latter is additionally given a tensorboard writer. :param logger_type: the type of logger to create. NOTE: "pandas" is admitted by the type annotation, but `create_logger` and `get_logger_class` currently raise a ValueError for it. :param wand_entity: the wandb entity (only used if `logger_type` is "wandb") :param wandb_project: the wandb project; must be provided if `logger_type` is "wandb" :param group: the wandb group (only used if `logger_type` is "wandb") :param job_type: the wandb job type (only used if `logger_type` is "wandb") :param save_interval: the interval size (in env steps) after which the checkpoint and end of epoch related logs will be saved. 
""" def __init__( self, logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard", wand_entity: str | None = None, wandb_project: str | None = None, group: str | None = None, job_type: str | None = None, save_interval: int | None = None, ): if logger_type == "wandb" and wandb_project is None: raise ValueError("Must provide 'wandb_project'") self.logger_type = logger_type self.wandb_entity = wand_entity self.wandb_project = wandb_project self.group = group self.job_type = job_type self.save_interval = save_interval def create_logger( self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict | None = None, ) -> TLogger: match self.logger_type: case "wandb": logger = WandbLogger( save_interval=self.save_interval, name=experiment_name.replace(os.path.sep, "__"), run_id=run_id, config=config_dict, entity=self.wandb_entity, project=self.wandb_project, group=self.group, job_type=self.job_type, log_dir=log_dir, ) writer = self._create_writer(log_dir) # writer has to be created after wandb.init! 
logger.load(writer) return logger case "tensorboard": writer = self._create_writer(log_dir) return TensorboardLogger(writer, save_interval=self.save_interval) case _: raise ValueError(f"Unknown logger type '{self.logger_type}'") def _create_writer(self, log_dir: str) -> SummaryWriter: """Creates a tensorboard writer and adds a text artifact.""" writer = SummaryWriter(log_dir) writer.add_text( "args", str( dict( log_dir=log_dir, logger_type=self.logger_type, wandb_project=self.wandb_project, ), ), ) return writer def get_logger_class(self) -> type[TLogger]: match self.logger_type: case "wandb": return WandbLogger case "tensorboard": return TensorboardLogger case _: raise ValueError(f"Unknown logger type '{self.logger_type}'") ================================================ FILE: tianshou/highlevel/module/__init__.py ================================================ ================================================ FILE: tianshou/highlevel/module/actor.py ================================================ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass from enum import Enum from typing import Protocol import torch from sensai.util.string import ToStringMixin from torch import nn from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.core import ( ModuleFactory, TDevice, init_linear_orthogonal, ) from tianshou.highlevel.module.intermediate import ( IntermediateModule, IntermediateModuleFactory, ) from tianshou.highlevel.params.dist_fn import ( DistributionFunctionFactoryCategorical, DistributionFunctionFactoryIndependentGaussians, ) from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import ( Actor, ModuleType, ModuleWithVectorOutput, Net, ) class ContinuousActorType(Enum): GAUSSIAN = "gaussian" DETERMINISTIC = "deterministic" UNSUPPORTED = "unsupported" @dataclass class 
ActorFuture: """Container, which, in the future, will hold an actor instance.""" actor: Actor | nn.Module | None = None class ActorFutureProviderProtocol(Protocol): def get_actor_future(self) -> ActorFuture: pass class ActorFactory(ModuleFactory, ToStringMixin, ABC): @abstractmethod def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: pass @abstractmethod def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: """ :param envs: the environments :return: the distribution function, which converts the actor's output into a distribution, or None if the actor does not output distribution parameters """ @staticmethod def _init_linear(actor: torch.nn.Module) -> None: """Initializes linear layers of an actor module using default mechanisms. :param actor: the actor module. """ init_linear_orthogonal(actor) if hasattr(actor, "mu"): # For continuous action spaces with Gaussian policies # do last policy layer scaling, this will make initial actions have (close to) # 0 mean and std, and will help boost performances, # see https://arxiv.org/abs/2006.05990, Fig.24 for details for m in actor.mu.modules(): if isinstance(m, torch.nn.Linear): m.weight.data.copy_(0.01 * m.weight.data) class ActorFactoryDefault(ActorFactory): """An actor factory which, depending on the type of environment, creates a suitable MLP-based policy.""" DEFAULT_HIDDEN_SIZES = (64, 64) def __init__( self, continuous_actor_type: ContinuousActorType, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, hidden_activation: ModuleType = nn.ReLU, continuous_unbounded: bool = False, continuous_conditioned_sigma: bool = False, discrete_softmax: bool = True, ): self.continuous_actor_type = continuous_actor_type self.continuous_unbounded = continuous_unbounded self.continuous_conditioned_sigma = continuous_conditioned_sigma self.hidden_sizes = hidden_sizes self.hidden_activation = hidden_activation self.discrete_softmax = discrete_softmax def _create_factory(self, envs: 
Environments) -> ActorFactory: env_type = envs.get_type() factory: ( ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet ) if env_type == EnvType.CONTINUOUS: match self.continuous_actor_type: case ContinuousActorType.GAUSSIAN: factory = ActorFactoryContinuousGaussianNet( self.hidden_sizes, activation=self.hidden_activation, unbounded=self.continuous_unbounded, conditioned_sigma=self.continuous_conditioned_sigma, ) case ContinuousActorType.DETERMINISTIC: factory = ActorFactoryContinuousDeterministicNet( self.hidden_sizes, activation=self.hidden_activation, ) case ContinuousActorType.UNSUPPORTED: raise ValueError("Continuous action spaces are not supported by the algorithm") case _: raise ValueError(self.continuous_actor_type) elif env_type == EnvType.DISCRETE: factory = ActorFactoryDiscreteNet( self.hidden_sizes, activation=self.hidden_activation, softmax_output=self.discrete_softmax, ) else: raise ValueError(f"{env_type} not supported") return factory def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: factory = self._create_factory(envs) return factory.create_module(envs, device) def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: factory = self._create_factory(envs) return factory.create_dist_fn(envs) class ActorFactoryContinuous(ActorFactory, ABC): """Serves as a type bound for actor factories that are suitable for continuous action spaces.""" class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous): def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): self.hidden_sizes = hidden_sizes self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, ) return continuous.ContinuousActorDeterministic( preprocess_net=net_a, action_shape=envs.get_action_shape(), hidden_sizes=(), 
).to(device) def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: return None class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): def __init__( self, hidden_sizes: Sequence[int], unbounded: bool = True, conditioned_sigma: bool = False, activation: ModuleType = nn.ReLU, ): """For actors with Gaussian policies. :param hidden_sizes: the sequence of hidden dimensions to use in the network structure :param unbounded: if True, the actor's mean output is left unbounded; if False, it is bounded (NOTE(review): presumably by applying tanh to the final-layer output inside `ContinuousActorProbabilistic` -- confirm there) :param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the input; if False, sigma is an independent parameter :param activation: the activation function passed to the preprocessing network (`Net`) """ self.hidden_sizes = hidden_sizes self.unbounded = unbounded self.conditioned_sigma = conditioned_sigma self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, ) actor = continuous.ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=envs.get_action_shape(), unbounded=self.unbounded, conditioned_sigma=self.conditioned_sigma, ).to(device) # init params if not self.conditioned_sigma: torch.nn.init.constant_(actor.sigma_param, -0.5) self._init_linear(actor) return actor def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs) class ActorFactoryDiscreteNet(ActorFactory): """Actor factory for discrete action spaces, which creates a `DiscreteActor` on top of an MLP (`Net`) preprocessing network.""" def __init__( self, hidden_sizes: Sequence[int], softmax_output: bool = True, activation: ModuleType = nn.ReLU, ): self.hidden_sizes = hidden_sizes self.softmax_output = softmax_output self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, ) return discrete.DiscreteActor( preprocess_net=net_a, action_shape=envs.get_action_shape(), 
hidden_sizes=(), softmax_output=self.softmax_output, ).to(device) def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: return DistributionFunctionFactoryCategorical( is_probs_input=self.softmax_output, ).create_dist_fn(envs) class ActorFactoryTransientStorageDecorator(ActorFactory): """Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved.""" def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture): self.actor_factory = actor_factory self._actor_future = actor_future def __getstate__(self) -> dict: d = dict(self.__dict__) del d["_actor_future"] return d def __setstate__(self, state: dict) -> None: self.__dict__ = state self._actor_future = ActorFuture() def _tostring_excludes(self) -> list[str]: return [*super()._tostring_excludes(), "_actor_future"] def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: module = self.actor_factory.create_module(envs, device) self._actor_future.actor = module return module def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: return self.actor_factory.create_dist_fn(envs) class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory): def __init__(self, actor_factory: ActorFactory): self.actor_factory = actor_factory def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: actor = self.actor_factory.create_module(envs, device) assert isinstance(actor, ModuleWithVectorOutput), ( "Actor factory must produce an actor with known vector output dimension" ) return IntermediateModule(actor, actor.get_output_dim()) ================================================ FILE: tianshou/highlevel/module/core.py ================================================ from abc import ABC, abstractmethod from typing import TypeAlias import numpy as np import torch from tianshou.highlevel.env import Environments TDevice: TypeAlias = str | torch.device def 
init_linear_orthogonal(module: torch.nn.Module) -> None: """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0. :param module: the module whose submodules are to be processed """ for m in module.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) class ModuleFactory(ABC): """Represents a factory for the creation of a torch module given an environment and target device.""" @abstractmethod def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module: pass ================================================ FILE: tianshou/highlevel/module/critic.py ================================================ from abc import ABC, abstractmethod from collections.abc import Sequence import numpy as np from sensai.util.string import ToStringMixin from torch import nn from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.actor import ActorFuture from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.utils.net import continuous from tianshou.utils.net.common import Actor, EnsembleLinear, ModuleType, Net from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic class CriticFactory(ToStringMixin, ABC): """Represents a factory for the generation of a critic module.""" @abstractmethod def create_module( self, envs: Environments, device: TDevice, use_action: bool, discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: """Creates the critic module. 
:param envs: the environments :param device: the torch device :param use_action: whether to expect the action as an additional input (in addition to the observations) :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape :return: the module """ class CriticFactoryDefault(CriticFactory): """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic.""" DEFAULT_HIDDEN_SIZES = (64, 64) def __init__( self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, hidden_activation: ModuleType = nn.ReLU, ): self.hidden_sizes = hidden_sizes self.hidden_activation = hidden_activation def create_module( self, envs: Environments, device: TDevice, use_action: bool, discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: factory: CriticFactory env_type = envs.get_type() match env_type: case EnvType.CONTINUOUS: factory = CriticFactoryContinuousNet( self.hidden_sizes, activation=self.hidden_activation, ) case EnvType.DISCRETE: factory = CriticFactoryDiscreteNet( self.hidden_sizes, activation=self.hidden_activation, ) case _: raise ValueError(f"{env_type} not supported") return factory.create_module( envs, device, use_action, discrete_last_size_use_action_shape=discrete_last_size_use_action_shape, ) class CriticFactoryContinuousNet(CriticFactory): def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): self.hidden_sizes = hidden_sizes self.activation = activation def create_module( self, envs: Environments, device: TDevice, use_action: bool, discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, activation=self.activation, ) critic = continuous.ContinuousCritic(preprocess_net=net_c).to(device) init_linear_orthogonal(critic) return critic 
class CriticFactoryDiscreteNet(CriticFactory): def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): self.hidden_sizes = hidden_sizes self.activation = activation def create_module( self, envs: Environments, device: TDevice, use_action: bool, discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, activation=self.activation, ) last_size = ( int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 ) critic = DiscreteCritic(preprocess_net=net_c, last_size=last_size).to(device) init_linear_orthogonal(critic) return critic class CriticFactoryReuseActor(CriticFactory): """A critic factory which reuses the actor's preprocessing component. This class is for internal use in experiment builders only. Reuse of the actor network is supported through the concept of an actor future (:class:`ActorFuture`). When the user declares that he wants to reuse the actor for the critic, we use this factory to support this, but the actor does not exist yet. So the factory instead receives the future, which will eventually be filled when the actor factory is called. When the creation method of this factory is eventually called, it can use the then-filled actor to create the critic. 
""" def __init__(self, actor_future: ActorFuture): """:param actor_future: the object, which will hold the actor instance later when the critic is to be created""" self.actor_future = actor_future def _tostring_excludes(self) -> list[str]: return ["actor_future"] def create_module( self, envs: Environments, device: TDevice, use_action: bool, discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: actor = self.actor_future.actor if not isinstance(actor, Actor): raise ValueError( f"Option critic_use_action can only be used if actor is of type {Actor.__class__.__name__}", ) if envs.get_type().is_discrete(): # TODO get rid of this prod pattern here and elsewhere last_size = ( int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 ) return DiscreteCritic( preprocess_net=actor.get_preprocess_net(), last_size=last_size, ).to(device) elif envs.get_type().is_continuous(): return ContinuousCritic( preprocess_net=actor.get_preprocess_net(), apply_preprocess_net_to_obs_only=True, ).to(device) else: raise ValueError class CriticEnsembleFactory: @abstractmethod def create_module( self, envs: Environments, device: TDevice, ensemble_size: int, use_action: bool, ) -> nn.Module: pass class CriticEnsembleFactoryDefault(CriticEnsembleFactory): """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic.""" DEFAULT_HIDDEN_SIZES = (64, 64) def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES): self.hidden_sizes = hidden_sizes def create_module( self, envs: Environments, device: TDevice, ensemble_size: int, use_action: bool, ) -> nn.Module: env_type = envs.get_type() factory: CriticEnsembleFactory match env_type: case EnvType.CONTINUOUS: factory = CriticEnsembleFactoryContinuousNet(self.hidden_sizes) case EnvType.DISCRETE: raise NotImplementedError("No default is implemented for the discrete case") case _: raise ValueError(f"{env_type} not supported") return factory.create_module( 
envs, device, ensemble_size, use_action, ) class CriticEnsembleFactoryContinuousNet(CriticEnsembleFactory): def __init__(self, hidden_sizes: Sequence[int]): self.hidden_sizes = hidden_sizes def create_module( self, envs: Environments, device: TDevice, ensemble_size: int, use_action: bool, ) -> nn.Module: def linear_layer(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(ensemble_size, x, y) action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, activation=nn.Tanh, linear_layer=linear_layer, ) critic = continuous.ContinuousCritic( preprocess_net=net_c, linear_layer=linear_layer, flatten_input=False, ).to(device) init_linear_orthogonal(critic) return critic ================================================ FILE: tianshou/highlevel/module/intermediate.py ================================================ from abc import ABC, abstractmethod from dataclasses import dataclass import torch from sensai.util.string import ToStringMixin from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import ModuleFactory, TDevice from tianshou.utils.net.common import ModuleWithVectorOutput @dataclass class IntermediateModule: """Container for a module which computes an intermediate representation (with a known dimension).""" module: torch.nn.Module output_dim: int def get_module_with_vector_output(self) -> ModuleWithVectorOutput: if isinstance(self.module, ModuleWithVectorOutput): return self.module else: return ModuleWithVectorOutput.from_module(self.module, self.output_dim) class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC): """Factory for the generation of a module which computes an intermediate representation.""" @abstractmethod def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: pass def create_module(self, envs: Environments, device: TDevice) -> 
torch.nn.Module: return self.create_intermediate_module(envs, device).module ================================================ FILE: tianshou/highlevel/module/special.py ================================================ from collections.abc import Sequence from sensai.util.string import ToStringMixin from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import ModuleFactory, TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.utils.net.discrete import ImplicitQuantileNetwork class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin): def __init__( self, preprocess_net_factory: IntermediateModuleFactory, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, ): self.preprocess_net_factory = preprocess_net_factory self.hidden_sizes = hidden_sizes self.num_cosines = num_cosines def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork: preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device) return ImplicitQuantileNetwork( preprocess_net=preprocess_net.get_module_with_vector_output(), action_shape=envs.get_action_shape(), hidden_sizes=self.hidden_sizes, num_cosines=self.num_cosines, ).to(device) ================================================ FILE: tianshou/highlevel/params/__init__.py ================================================ ================================================ FILE: tianshou/highlevel/params/algorithm_params.py ================================================ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from dataclasses import asdict, dataclass from typing import Any, Literal, Protocol from sensai.util.string import ToStringMixin from tianshou.exploration import BaseNoise from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.params.alpha import AutoAlphaFactory from 
tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactory from tianshou.highlevel.params.noise import NoiseFactory from tianshou.highlevel.params.optim import OptimizerFactoryFactory @dataclass(kw_only=True) class ParamTransformerData: """Holds data that can be used by `ParamTransformer` instances to perform their transformation. The representation contains the superset of all data items that are required by different types of agent factories. An agent factory is expected to set only the attributes that are relevant to its parameters. """ envs: Environments device: TDevice optim_factory_default: OptimizerFactoryFactory class ParamTransformer(ABC): """Base class for parameter transformations from high to low-level API. Transforms one or more parameters from the representation used by the high-level API to the representation required by the (low-level) policy implementation. It operates directly on a dictionary of keyword arguments, which is initially generated from the parameter dataclass (subclass of `Params`). 
""" @abstractmethod def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: pass @staticmethod def get( d: dict[str, Any], key: str, drop: bool = False, default_factory: Callable[[], Any] | None = None, ) -> Any: try: value = d[key] except KeyError as e: raise Exception(f"Key not found: '{key}'; available keys: {list(d.keys())}") from e if value is None and default_factory is not None: value = default_factory() if drop: del d[key] return value class ParamTransformerDrop(ParamTransformer): def __init__(self, *keys: str): self.keys = keys def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: for k in self.keys: del kwargs[k] class ParamTransformerRename(ParamTransformer): def __init__(self, renamed_params: dict[str, str]): self.renamed_params = renamed_params def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: for old_name, new_name in self.renamed_params.items(): v = kwargs[old_name] del kwargs[old_name] kwargs[new_name] = v class ParamTransformerChangeValue(ParamTransformer): def __init__(self, key: str): self.key = key def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: params[self.key] = self.change_value(params[self.key], data) @abstractmethod def change_value(self, value: Any, data: ParamTransformerData) -> Any: pass class ParamTransformerOptimFactory(ParamTransformer): """Transformer for learning rate scheduler params. Transforms a key containing a learning rate scheduler factory (removed) into a key containing a learning rate scheduler (added) for the data member `optim`. 
""" def __init__( self, key_optim_factory_factory: str, key_lr: str, key_lr_scheduler_factory_factory: str, key_optim_output: str, ): self.key_optim_factory_factory = key_optim_factory_factory self.key_lr = key_lr self.key_scheduler_factory = key_lr_scheduler_factory_factory self.key_optim_output = key_optim_output def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: optim_factory_factory: OptimizerFactoryFactory = self.get( params, self.key_optim_factory_factory, drop=True, default_factory=lambda: data.optim_factory_default, ) lr_scheduler_factory_factory: LRSchedulerFactoryFactory | None = self.get( params, self.key_scheduler_factory, drop=True ) lr: float = self.get(params, self.key_lr, drop=True) optim_factory = optim_factory_factory.create_optimizer_factory(lr) if lr_scheduler_factory_factory is not None: optim_factory.with_lr_scheduler_factory( lr_scheduler_factory_factory.create_lr_scheduler_factory() ) params[self.key_optim_output] = optim_factory class ParamTransformerAutoAlpha(ParamTransformer): def __init__(self, key: str): self.key = key def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: alpha = self.get(kwargs, self.key) if isinstance(alpha, AutoAlphaFactory): kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.device) class ParamTransformerNoiseFactory(ParamTransformerChangeValue): def change_value(self, value: Any, data: ParamTransformerData) -> Any: if isinstance(value, NoiseFactory): value = value.create_noise(data.envs) return value class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue): def change_value(self, value: Any, data: ParamTransformerData) -> Any: if isinstance(value, EnvValueFactory): value = value.create_value(data.envs) return value class ParamTransformerActionScaling(ParamTransformerChangeValue): def change_value(self, value: Any, data: ParamTransformerData) -> Any: if value == "default": return data.envs.get_type().is_continuous() else: return value 
class GetParamTransformersProtocol(Protocol): def _get_param_transformers(self) -> list[ParamTransformer]: pass @dataclass(kw_only=True) class Params(GetParamTransformersProtocol, ToStringMixin): def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: params = asdict(self) for transformer in self._get_param_transformers(): transformer.transform(params, data) return params def _get_param_transformers(self) -> list[ParamTransformer]: return [] @dataclass(kw_only=True) class ParamsMixinSingleModel(GetParamTransformersProtocol): optim: OptimizerFactoryFactory | None = None """the factory for the creation of the model's optimizer; if None, use default""" lr: float = 1e-3 """the learning rate to use in the gradient-based optimizer""" lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler""" def _get_param_transformers(self) -> list[ParamTransformer]: return [ ParamTransformerOptimFactory("optim", "lr", "lr_scheduler", "optim"), ] @dataclass(kw_only=True) class ParamsMixinActorAndCritic(GetParamTransformersProtocol): actor_optim: OptimizerFactoryFactory | None = None """the factory for the creation of the actor's optimizer; if None, use default""" critic_optim: OptimizerFactoryFactory | None = None """the factory for the creation of the critic's optimizer; if None, use default""" actor_lr: float = 1e-3 """the learning rate to use for the actor network""" critic_lr: float = 1e-3 """the learning rate to use for the critic network""" actor_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" critic_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the critic network (if any)""" def _get_param_transformers(self) -> list[ParamTransformer]: return [ ParamTransformerOptimFactory( "actor_optim", "actor_lr", "actor_lr_scheduler", 
"policy_optim" ), ParamTransformerOptimFactory( "critic_optim", "critic_lr", "critic_lr_scheduler", "critic_optim" ), ] @dataclass(kw_only=True) class ParamsMixinActionScaling(GetParamTransformersProtocol): action_scaling: bool | Literal["default"] = "default" """ flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. This applies to continuous action spaces only (gym.spaces.Box) and has no effect for discrete spaces. When enabled, policy outputs are expected to be in the normalized range [-1, 1] (after bounding), and are then linearly transformed to the actual required range. This improves neural network training stability, allows the same algorithm to work across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. """ def _get_param_transformers(self) -> list[ParamTransformer]: return [ParamTransformerActionScaling("action_scaling")] @dataclass(kw_only=True) class ParamsMixinActionScalingAndBounding(ParamsMixinActionScaling): action_bound_method: Literal["clip", "tanh"] | None = "clip" """ the method used for bounding actions in continuous action spaces to the range [-1, 1] before scaling them to the environment's action space (provided that `action_scaling` is enabled). This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None for discrete spaces. When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly constrains outputs to [-1, 1] while preserving gradients. The choice of bounding method affects both training dynamics and exploration behavior. 
Clipping provides hard boundaries but may create plateau regions in the gradient landscape, while tanh provides smoother transitions but can compress sensitivity near the boundaries. Should be set to None if the actor model inherently produces bounded outputs. Typically used together with `action_scaling=True`. """ @dataclass(kw_only=True) class ParamsMixinExplorationNoise(GetParamTransformersProtocol): exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None """ If not None, add noise to actions for exploration. This is useful when solving "hard exploration" problems. It can either be a distribution, a factory for the creation of a distribution or "default". When set to "default", use Gaussian noise with standard deviation 0.1. """ def _get_param_transformers(self) -> list[ParamTransformer]: return [ParamTransformerNoiseFactory("exploration_noise")] @dataclass(kw_only=True) class ParamsMixinNStepReturnHorizon: n_step_return_horizon: int = 1 """ the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: Higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. """ @dataclass(kw_only=True) class ParamsMixinGamma: gamma: float = 0.99 """ the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks """ @dataclass(kw_only=True) class ParamsMixinTau: tau: float = 0.005 """ the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's parameters, a weighted average is used: target = tau * source + (1 - tau) * target. Smaller values (closer to 0) create more stable but slower learning as target networks change more gradually. Higher values (closer to 1) allow faster learning but may reduce stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. """ @dataclass(kw_only=True) class ParamsMixinDeterministicEval: deterministic_eval: bool = False """ flag indicating whether the policy should use deterministic actions (using the mode of the action distribution) instead of stochastic ones (using random sampling) during evaluation. When enabled, the policy will always select the most probable action according to the learned distribution during evaluation phases, while still using stochastic sampling during training. This creates a clear distinction between exploration (training) and exploitation (evaluation) behaviors. Deterministic actions are generally preferred for final deployment and reproducible evaluation as they provide consistent behavior, reduce variance in performance metrics, and are more interpretable for human observers. Note that this parameter only affects behavior when the policy is not within a training step. When collecting rollouts for training, actions remain stochastic regardless of this setting to maintain proper exploration behaviour. 
""" class OnPolicyAlgorithmParams( Params, ParamsMixinGamma, ParamsMixinActionScalingAndBounding, ParamsMixinSingleModel, ParamsMixinDeterministicEval, ): def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) transformers.extend(ParamsMixinSingleModel._get_param_transformers(self)) return transformers @dataclass(kw_only=True) class ReinforceParams(OnPolicyAlgorithmParams): return_standardization: bool = False """ whether to standardize episode returns by subtracting the running mean and dividing by the running standard deviation. Note that this is known to be detrimental to performance in many cases! """ @dataclass(kw_only=True) class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): gae_lambda: float = 0.95 """ the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values (closer to 1) reduce bias but increase variance by giving more weight to longer trajectories, while lower values (closer to 0) reduce variance but increase bias by relying more on the immediate TD error and value function estimates. At λ=0, GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. """ max_batchsize: int = 256 """the maximum number of samples to process at once when computing generalized advantage estimation (GAE) and value function predictions. Controls memory usage by breaking large batches into smaller chunks processed sequentially. 
Higher values may increase speed but require more GPU/CPU memory; lower values reduce memory requirements but may increase computation time. Should be adjusted based on available hardware resources and total batch size of your training data.""" def _get_param_transformers(self) -> list[ParamTransformer]: return [] @dataclass(kw_only=True) class ActorCriticOnPolicyParams(OnPolicyAlgorithmParams): return_scaling: bool = False """ flag indicating whether to enable scaling of estimated returns by dividing them by their running standard deviation without centering the mean. This reduces the magnitude variation of advantages across different episodes while preserving their signs and relative ordering. The use of running statistics (rather than batch-specific scaling) means that early training experiences may be scaled differently than later ones as the statistics evolve. When enabled, this improves training stability in environments with highly variable reward scales and makes the algorithm less sensitive to learning rate settings. However, it may reduce the algorithm's ability to distinguish between episodes with different absolute return magnitudes. Best used in environments where the relative ordering of actions is more important than the absolute scale of returns. """ @dataclass(kw_only=True) class A2CParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation): vf_coef: float = 0.5 """ coefficient that weights the value loss relative to the actor loss in the overall loss function. Higher values prioritize accurate value function estimation over policy improvement. Controls the trade-off between policy optimization and value function fitting. Typically set between 0.5 and 1.0 for most actor-critic implementations. """ ent_coef: float = 0.01 """ coefficient that weights the entropy bonus relative to the actor loss. Controls the exploration-exploitation trade-off by encouraging policy entropy. 
Higher values promote more exploration by encouraging a more uniform action distribution. Lower values focus more on exploitation of the current policy's knowledge. Typically set between 0.01 and 0.05 for most actor-critic implementations. """ max_grad_norm: float | None = None """ the maximum L2 norm threshold for gradient clipping. When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. This prevents exploding gradients and stabilizes training by limiting the magnitude of parameter updates. Set to None to disable gradient clipping. """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self)) return transformers @dataclass(kw_only=True) class PPOParams(A2CParams): eps_clip: float = 0.2 """ determines the range of allowed change in the policy during a policy update: The ratio of action probabilities indicated by the new and old policy is constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. Small values thus force the new policy to stay close to the old policy. Typical values range between 0.1 and 0.3, the value of 0.2 is recommended in the original PPO paper. The optimal value depends on the environment; more stochastic environments may need larger values. """ dual_clip: float | None = None """ a clipping parameter (denoted as c in the literature) that prevents excessive pessimism in policy updates for negative-advantage actions. Excessive pessimism occurs when the policy update too strongly reduces the probability of selecting actions that led to negative advantages, potentially eliminating useful actions based on limited negative experiences. When enabled (c > 1), the objective for negative advantages becomes: max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) is the original single-clipping objective determined by `eps_clip`. 
This creates a floor on negative policy gradients, maintaining some probability of exploring actions despite initial negative outcomes. Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer to 1.0 provide less protection against pessimistic updates. Set to None to disable dual clipping. """ value_clip: bool = False """ flag indicating whether to enable clipping for value function updates. When enabled, restricts how much the value function estimate can change from its previous prediction, using the same clipping range as the policy updates (eps_clip). This stabilizes training by preventing large fluctuations in value estimates, particularly useful in environments with high reward variance. The clipped value loss uses a pessimistic approach, taking the maximum of the original and clipped value errors: max((returns - value)², (returns - v_clipped)²) Setting to True often improves training stability but may slow convergence. Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. """ advantage_normalization: bool = True """whether to apply per mini-batch advantage normalization.""" recompute_advantage: bool = False """ whether to recompute advantage every update repeat as described in https://arxiv.org/pdf/2006.05990.pdf, Sec. 3.5. The original PPO implementation splits the data in each policy iteration step into individual transitions and then randomly assigns them to minibatches. This makes it impossible to compute advantages as the temporal structure is broken. Therefore, the advantages are computed once at the beginning of each policy iteration step and then used in minibatch policy and value function optimization. This results in higher diversity of data in each minibatch at the cost of using slightly stale advantage estimations. Enabling this option will, as a remedy to this problem, recompute the advantages at the beginning of each pass over the data instead of just once per iteration. 
""" @dataclass(kw_only=True) class NPGParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation): optim_critic_iters: int = 5 """ the number of optimization steps performed on the critic network for each policy (actor) update. Controls the learning rate balance between critic and actor. Higher values prioritize critic accuracy by training the value function more extensively before each policy update, which can improve stability but slow down training. Lower values maintain a more even learning pace between policy and value function but may lead to less reliable advantage estimates. Typically set between 1 and 10, depending on the complexity of the value function. """ trust_region_size: float = 0.5 """ the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. The mathematical meaning is the trust region size, which is the maximum KL divergence allowed between the old and new policy distributions. Controls how far the policy parameters move in the calculated direction during each update. Higher values allow for faster learning but may cause instability or policy deterioration; lower values provide more stable but slower learning. Unlike regular policy gradients, natural gradients already account for the local geometry of the parameter space, making this step size more robust to different parameterizations. Typically set between 0.1 and 1.0 for most reinforcement learning tasks. """ advantage_normalization: bool = True """whether to do per mini-batch advantage normalization.""" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self)) return transformers @dataclass(kw_only=True) class TRPOParams(NPGParams): max_kl: float = 0.01 """ maximum KL divergence, used to constrain each actor network update. 
""" backtrack_coeff: float = 0.8 """ coefficient with which to reduce the step size when constraints are not met. """ max_backtracks: int = 10 """maximum number of times to backtrack in line search when the constraints are not met.""" @dataclass(kw_only=True) class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol): actor_optim: OptimizerFactoryFactory | None = None """the factory for the creation of the actor's optimizer; if None, use default""" critic1_optim: OptimizerFactoryFactory | None = None """the factory for the creation of the first critic's optimizer; if None, use default""" critic2_optim: OptimizerFactoryFactory | None = None """the factory for the creation of the second critic's optimizer; if None, use default""" actor_lr: float = 1e-3 """the learning rate to use for the actor network""" critic1_lr: float = 1e-3 """the learning rate to use for the first critic network""" critic2_lr: float = 1e-3 """the learning rate to use for the second critic network""" actor_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" critic1_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the first critic network (if any)""" critic2_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the second critic network (if any)""" def _get_param_transformers(self) -> list[ParamTransformer]: return [ ParamTransformerOptimFactory( "actor_optim", "actor_lr", "actor_lr_scheduler", "policy_optim" ), ParamTransformerOptimFactory( "critic1_optim", "critic1_lr", "critic1_lr_scheduler", "critic_optim" ), ParamTransformerOptimFactory( "critic2_optim", "critic2_lr", "critic2_lr_scheduler", "critic2_optim" ), ] @dataclass(kw_only=True) class ParamsMixinAlpha(GetParamTransformersProtocol): alpha: float | AutoAlphaFactory = 0.2 """ the entropy 
regularization coefficient, which balances exploration and exploitation. This coefficient controls how much the agent values randomness in its policy versus pursuing higher rewards. Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent for maintaining diverse action choices, even if this means selecting some lower-value actions. Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become more focused on the highest-value actions. A value of 0 would completely remove entropy regularization, potentially leading to premature convergence to suboptimal deterministic policies. Can be provided as a fixed float (0.2 is a reasonable default) or via a factory to support automatic tuning during training. """ def _get_param_transformers(self) -> list[ParamTransformer]: return [ParamTransformerAutoAlpha("alpha")] @dataclass(kw_only=True) class _SACParams( Params, ParamsMixinGamma, ParamsMixinActorAndDualCritics, ParamsMixinNStepReturnHorizon, ParamsMixinTau, ParamsMixinDeterministicEval, ParamsMixinAlpha, ): def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) transformers.extend(ParamsMixinAlpha._get_param_transformers(self)) return transformers @dataclass(kw_only=True) class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling): def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) return transformers @dataclass(kw_only=True) class DiscreteSACParams(_SACParams): pass @dataclass(kw_only=True) class QLearningOffPolicyParams( Params, ParamsMixinGamma, ParamsMixinSingleModel, ParamsMixinNStepReturnHorizon ): target_update_freq: int = 0 """ the 
number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current Q-network values. A value of 0 disables the target network entirely, using only a single network for both action selection and bootstrap targets. Higher values provide more stable learning targets but slow down the propagation of new value estimates. Lower positive values allow faster learning but may lead to instability due to rapidly changing targets. Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. """ eps_training: float = 0.0 """ the epsilon value for epsilon-greedy exploration during training. When collecting data for training, this is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). """ eps_inference: float = 0.0 """ the epsilon value for epsilon-greedy exploration during inference, i.e. non-training cases (such as evaluation during test steps). The epsilon value is the probability of choosing a random action instead of the action chosen by the policy. A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full exploration (fully random). """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinSingleModel._get_param_transformers(self)) return transformers @dataclass(kw_only=True) class DQNParams(QLearningOffPolicyParams): is_double: bool = True """ flag indicating whether to use the Double DQN algorithm for target value computation. If True, the algorithm uses the online network to select actions and the target network to evaluate their Q-values. This approach helps reduce the overestimation bias in Q-learning by decoupling action selection from action evaluation. 
If False, the algorithm follows the vanilla DQN method that directly takes the maximum Q-value from the target network. Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). """ huber_loss_delta: float | None = None """ controls whether to use the Huber loss instead of the MSE loss for the TD error and the threshold for the Huber loss. If None, the MSE loss is used. If not None, uses the Huber loss as described in the Nature DQN paper (nature14236) with the given delta, which limits the influence of outliers. Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber loss causes the gradients to plateau at a constant value for large errors, providing more stable training. NOTE: The magnitude of delta should depend on the scale of the returns obtained in the environment. """ def _get_param_transformers(self) -> list[ParamTransformer]: return super()._get_param_transformers() @dataclass(kw_only=True) class IQNParams(QLearningOffPolicyParams): sample_size: int = 32 """the number of samples for policy evaluation""" online_sample_size: int = 8 """the number of samples for online model in training""" target_sample_size: int = 8 """the number of samples for target model in training.""" num_quantiles: int = 200 """the number of quantile midpoints in the inverse cumulative distribution function of the value""" hidden_sizes: Sequence[int] = () """hidden dimensions to use in the IQN network""" num_cosines: int = 64 """number of cosines to use in the IQN network""" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines")) return transformers @dataclass(kw_only=True) class DDPGParams( Params, ParamsMixinGamma, ParamsMixinActorAndCritic, ParamsMixinExplorationNoise, ParamsMixinActionScalingAndBounding, ParamsMixinNStepReturnHorizon, ParamsMixinTau, ): def 
_get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self)) transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) return transformers @dataclass(kw_only=True) class REDQParams(DDPGParams, ParamsMixinDeterministicEval, ParamsMixinAlpha): ensemble_size: int = 10 """ the total number of critic networks in the ensemble. This parameter implements the randomized ensemble approach described in REDQ. The algorithm maintains `ensemble_size` different critic networks that all share the same architecture. During target value computation, a random subset of these networks (determined by `subset_size`) is used. Larger values increase the diversity of the ensemble but require more memory and computation. The original paper recommends a value of 10 for most tasks, balancing performance and computational efficiency. """ subset_size: int = 2 """ the number of critic networks randomly selected from the ensemble for computing target Q-values. During each update, the algorithm samples `subset_size` networks from the ensemble of `ensemble_size` networks without replacement. The target Q-value is then calculated as either the minimum or mean (based on target_mode) of the predictions from this subset. Smaller values increase randomization and sample efficiency but may introduce more variance. Larger values provide more stable estimates but reduce the benefits of randomization. The REDQ paper recommends a value of 2 for optimal sample efficiency. Must satisfy 0 < subset_size <= ensemble_size. """ actor_delay: int = 20 """ the number of critic updates performed before each actor update. The actor network is only updated once for every actor_delay critic updates, implementing a delayed policy update strategy similar to TD3. 
Larger values stabilize training by allowing critics to become more accurate before policy updates. Smaller values allow the policy to adapt more quickly but may lead to less stable learning. The REDQ paper recommends a value of 20 for most tasks. """ target_mode: Literal["mean", "min"] = "min" """ the method used to aggregate Q-values from the subset of critic networks. Can be either "min" or "mean". If "min", uses the minimum Q-value across the selected subset of critics for each state-action pair. If "mean", uses the average Q-value across the selected subset of critics. Using "min" helps prevent overestimation bias but may lead to more conservative value estimates. Using "mean" provides more optimistic value estimates but may suffer from overestimation bias. Default is "min" following the conservative value estimation approach common in recent Q-learning algorithms. """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinAlpha._get_param_transformers(self)) return transformers @dataclass(kw_only=True) class TD3Params( Params, ParamsMixinGamma, ParamsMixinActorAndDualCritics, ParamsMixinExplorationNoise, ParamsMixinActionScalingAndBounding, ParamsMixinNStepReturnHorizon, ParamsMixinTau, ): policy_noise: float | FloatEnvValueFactory = 0.2 """ scaling factor for the Gaussian noise added to target policy actions. This parameter implements target policy smoothing, a regularization technique described in the TD3 paper. The noise is sampled from a normal distribution and multiplied by this value before being added to actions. Higher values increase exploration in the target policy, helping to address function approximation error. The added noise is optionally clipped to a range determined by the noise_clip parameter. Typically set between 0.1 and 0.5 relative to the action scale of the environment. 
""" noise_clip: float | FloatEnvValueFactory = 0.5 """ defines the maximum absolute value of the noise added to target policy actions, i.e. noise values are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise via `policy_noise`). This parameter implements bounded target policy smoothing as described in the TD3 paper. It prevents extreme noise values from causing unrealistic target values during training. Setting it 0.0 (or a negative value) disables clipping entirely. It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2). """ update_actor_freq: int = 2 """ the frequency of actor network updates relative to critic network updates (the actor network is only updated once for every `update_actor_freq` critic updates). This implements the "delayed" policy updates from the TD3 algorithm, where the actor is updated less frequently than the critics. Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more accurate before updating the policy. The default value of 2 follows the original TD3 paper's recommendation of updating the policy at half the rate of the Q-functions. 
""" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise")) transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip")) return transformers ================================================ FILE: tianshou/highlevel/params/algorithm_wrapper.py ================================================ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Generic, TypeVar from sensai.util.string import ToStringMixin from tianshou.algorithm import Algorithm, ICMOffPolicyWrapper from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, OnPolicyAlgorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.params.optim import OptimizerFactoryFactory from tianshou.utils.net.discrete import IntrinsicCuriosityModule TAlgorithmOut = TypeVar("TAlgorithmOut", bound=Algorithm) class AlgorithmWrapperFactory(Generic[TAlgorithmOut], ToStringMixin, ABC): @abstractmethod def create_wrapped_algorithm( self, policy: Algorithm, envs: Environments, optim_factory: OptimizerFactoryFactory, device: TDevice, ) -> TAlgorithmOut: pass class AlgorithmWrapperFactoryIntrinsicCuriosity( AlgorithmWrapperFactory[ICMOffPolicyWrapper | ICMOnPolicyWrapper], ): def __init__( self, *, feature_net_factory: IntermediateModuleFactory, hidden_sizes: Sequence[int], lr: float, lr_scale: float, reward_scale: float, forward_loss_weight: float, optim: OptimizerFactoryFactory | None = 
None, ): self.feature_net_factory = feature_net_factory self.hidden_sizes = hidden_sizes self.lr = lr self.lr_scale = lr_scale self.reward_scale = reward_scale self.forward_loss_weight = forward_loss_weight self.optim_factory = optim def create_wrapped_algorithm( self, algorithm: Algorithm, envs: Environments, optim_factory_default: OptimizerFactoryFactory, device: TDevice, ) -> ICMOffPolicyWrapper | ICMOnPolicyWrapper: feature_net = self.feature_net_factory.create_intermediate_module(envs, device) action_dim = envs.get_action_shape() if not isinstance(action_dim, int): raise ValueError(f"Environment action shape must be an integer, got {action_dim}") feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( feature_net=feature_net.module, feature_dim=feature_dim, action_dim=action_dim, hidden_sizes=self.hidden_sizes, ) optim_factory = self.optim_factory or optim_factory_default icm_optim = optim_factory.create_optimizer_factory(lr=self.lr) if isinstance(algorithm, OffPolicyAlgorithm): return ICMOffPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=self.lr_scale, reward_scale=self.reward_scale, forward_loss_weight=self.forward_loss_weight, ).to(device) elif isinstance(algorithm, OnPolicyAlgorithm): return ICMOnPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=self.lr_scale, reward_scale=self.reward_scale, forward_loss_weight=self.forward_loss_weight, ).to(device) else: raise ValueError(f"{algorithm} is not supported by ICM") ================================================ FILE: tianshou/highlevel/params/alpha.py ================================================ from abc import ABC, abstractmethod import numpy as np from sensai.util.string import ToStringMixin from tianshou.algorithm.modelfree.sac import Alpha, AutoAlpha from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.params.optim import 
OptimizerFactoryFactory class AutoAlphaFactory(ToStringMixin, ABC): @abstractmethod def create_auto_alpha( self, envs: Environments, device: TDevice, ) -> Alpha: pass class AutoAlphaFactoryDefault(AutoAlphaFactory): def __init__( self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0, log_alpha: float = 0.0, optim: OptimizerFactoryFactory | None = None, ) -> None: """ :param lr: the learning rate for the optimizer of the alpha parameter :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; The base value being scaled is `dim(A)` for continuous action spaces and `log(|A|)` for discrete action spaces, i.e. with the default coefficient -1, we obtain `-dim(A)` and `-log(dim(A))` for continuous and discrete action spaces respectively, which gives a reasonable trade-off between exploration and exploitation. For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); 1.0 would give full entropy exploration. :param log_alpha: the (initial) value of the log of the entropy regularization coefficient alpha. 
:param optim: the optimizer factory to use; if None, use default """ self.lr = lr self.target_entropy_coefficient = target_entropy_coefficient self.log_alpha = log_alpha self.optimizer_factory_factory = optim or OptimizerFactoryFactory.default() def create_auto_alpha( self, envs: Environments, device: TDevice, ) -> AutoAlpha: action_dim = np.prod(envs.get_action_shape()) if envs.get_type().is_continuous(): target_entropy = self.target_entropy_coefficient * float(action_dim) else: target_entropy = self.target_entropy_coefficient * np.log(action_dim) optim_factory = self.optimizer_factory_factory.create_optimizer_factory(lr=self.lr) return AutoAlpha(target_entropy, self.log_alpha, optim_factory) ================================================ FILE: tianshou/highlevel/params/collector.py ================================================ from abc import ABC, abstractmethod from tianshou.algorithm import Algorithm from tianshou.data import BaseCollector, Collector, ReplayBuffer from tianshou.env import BaseVectorEnv class CollectorFactory(ABC): @abstractmethod def create_collector( self, algorithm: Algorithm, vector_env: BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> BaseCollector: """ Creates a collector for the given algorithm and vectorized environment. 
:param algorithm: the algorithm :param vector_env: the vectorized environment :param buffer: the replay buffer to be used by the collector; if None, a new buffer will be created with default parameters :param exploration_noise: whether action shall be modified using the policy's exploration noise :return: the collector """ class CollectorFactoryDefault(CollectorFactory): def create_collector( self, algorithm: Algorithm, vector_env: BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> BaseCollector: return Collector( algorithm.policy, vector_env, buffer=buffer, exploration_noise=exploration_noise ) ================================================ FILE: tianshou/highlevel/params/dist_fn.py ================================================ from abc import ABC, abstractmethod from collections.abc import Callable from typing import Any import torch from sensai.util.string import ToStringMixin from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrete, TDistFnDiscrOrCont from tianshou.highlevel.env import Environments class DistributionFunctionFactory(ToStringMixin, ABC): # True return type defined in subclasses @abstractmethod def create_dist_fn( self, envs: Environments, ) -> Callable[[Any], torch.distributions.Distribution]: pass class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): def __init__(self, is_probs_input: bool = True): """ :param is_probs_input: If True, the distribution function shall create a categorical distribution from a tensor containing probabilities; otherwise the tensor is assumed to contain logits. 
""" self.is_probs_input = is_probs_input def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete: envs.get_type().assert_discrete(self) if self.is_probs_input: return self._dist_fn_probs else: return self._dist_fn # NOTE: Do not move/rename because a reference to the function can appear in persisted policies @staticmethod def _dist_fn(logits: torch.Tensor) -> torch.distributions.Categorical: return torch.distributions.Categorical(logits=logits) # NOTE: Do not move/rename because a reference to the function can appear in persisted policies @staticmethod def _dist_fn_probs(probs: torch.Tensor) -> torch.distributions.Categorical: return torch.distributions.Categorical(probs=probs) class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory): def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: envs.get_type().assert_continuous(self) return self._dist_fn # NOTE: Do not move/rename because a reference to the function can appear in persisted policies @staticmethod def _dist_fn( loc_scale: tuple[torch.Tensor, torch.Tensor], ) -> torch.distributions.Distribution: loc, scale = loc_scale return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1) ================================================ FILE: tianshou/highlevel/params/env_param.py ================================================ """Factories for the generation of environment-dependent parameters.""" from abc import ABC, abstractmethod from typing import Generic, TypeVar from sensai.util.string import ToStringMixin from tianshou.highlevel.env import ContinuousEnvironments, Environments TValue = TypeVar("TValue") TEnvs = TypeVar("TEnvs", bound=Environments) class EnvValueFactory(Generic[TValue, TEnvs], ToStringMixin, ABC): @abstractmethod def create_value(self, envs: TEnvs) -> TValue: pass class FloatEnvValueFactory(EnvValueFactory[float, TEnvs], Generic[TEnvs], ABC): """Serves as a type bound for float value factories.""" class 
FloatEnvValueFactoryMaxActionScaled(FloatEnvValueFactory[ContinuousEnvironments]): def __init__(self, value: float): """:param value: value with which to scale the max action value""" self.value = value def create_value(self, envs: ContinuousEnvironments) -> float: envs.get_type().assert_continuous(self) return envs.max_action * self.value class MaxActionScaled(FloatEnvValueFactoryMaxActionScaled): pass ================================================ FILE: tianshou/highlevel/params/lr_scheduler.py ================================================ from abc import ABC, abstractmethod from sensai.util.string import ToStringMixin from tianshou.algorithm.optim import LRSchedulerFactory, LRSchedulerFactoryLinear from tianshou.highlevel.config import TrainingConfig class LRSchedulerFactoryFactory(ToStringMixin, ABC): """Factory for the creation of a learning rate scheduler factory.""" @abstractmethod def create_lr_scheduler_factory(self) -> LRSchedulerFactory: pass class LRSchedulerFactoryFactoryLinear(LRSchedulerFactoryFactory): def __init__(self, training_config: TrainingConfig): self.training_config = training_config def create_lr_scheduler_factory(self) -> LRSchedulerFactory: if ( self.training_config.epoch_num_steps is None or self.training_config.collection_step_num_env_steps is None ): raise ValueError( f"{self.__class__.__name__} requires epoch_num_steps and collection_step_num_env_steps to be set " f"in order for the scheduling to be well-defined." 
) return LRSchedulerFactoryLinear( max_epochs=self.training_config.max_epochs, epoch_num_steps=self.training_config.epoch_num_steps, collection_step_num_env_steps=self.training_config.collection_step_num_env_steps, ) ================================================ FILE: tianshou/highlevel/params/noise.py ================================================ from abc import ABC, abstractmethod from sensai.util.string import ToStringMixin from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.highlevel.env import ContinuousEnvironments, Environments class NoiseFactory(ToStringMixin, ABC): @abstractmethod def create_noise(self, envs: Environments) -> BaseNoise: pass class NoiseFactoryMaxActionScaledGaussian(NoiseFactory): def __init__(self, std_fraction: float): """Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value. This factory can only be applied to continuous action spaces. :param std_fraction: fraction (between 0 and 1) of the maximum action value that shall be used as the standard deviation """ self.std_fraction = std_fraction def create_noise(self, envs: Environments) -> GaussianNoise: envs.get_type().assert_continuous(self) envs: ContinuousEnvironments return GaussianNoise(sigma=envs.max_action * self.std_fraction) class MaxActionScaledGaussian(NoiseFactoryMaxActionScaledGaussian): pass ================================================ FILE: tianshou/highlevel/params/optim.py ================================================ from abc import ABC, abstractmethod from collections.abc import Iterable from typing import Any, Protocol, TypeAlias import torch from sensai.util.string import ToStringMixin from tianshou.algorithm.optim import ( AdamOptimizerFactory, OptimizerFactory, RMSpropOptimizerFactory, TorchOptimizerFactory, ) TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] class OptimizerWithLearningRateProtocol(Protocol): def __call__(self, parameters: Any, lr: float, **kwargs: Any) 
-> torch.optim.Optimizer: pass class OptimizerFactoryFactory(ABC, ToStringMixin): @staticmethod def default() -> "OptimizerFactoryFactory": return OptimizerFactoryFactoryAdam() @abstractmethod def create_optimizer_factory(self, lr: float) -> OptimizerFactory: pass class OptimizerFactoryFactoryTorch(OptimizerFactoryFactory): def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any): """Factory for torch optimizers. :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), which will be passed the module parameters, the learning rate as `lr` and the kwargs provided. :param kwargs: keyword arguments to provide at optimizer construction """ self.optim_class = optim_class self.kwargs = kwargs def create_optimizer_factory(self, lr: float) -> OptimizerFactory: return TorchOptimizerFactory(optim_class=self.optim_class, lr=lr) class OptimizerFactoryFactoryAdam(OptimizerFactoryFactory): def __init__( self, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, ): self.weight_decay = weight_decay self.eps = eps self.betas = betas def create_optimizer_factory(self, lr: float) -> AdamOptimizerFactory: return AdamOptimizerFactory( lr=lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, ) class OptimizerFactoryFactoryRMSprop(OptimizerFactoryFactory): def __init__( self, alpha: float = 0.99, eps: float = 1e-08, weight_decay: float = 0, momentum: float = 0, centered: bool = False, ): self.alpha = alpha self.momentum = momentum self.centered = centered self.weight_decay = weight_decay self.eps = eps def create_optimizer_factory(self, lr: float) -> RMSpropOptimizerFactory: return RMSpropOptimizerFactory( lr=lr, alpha=self.alpha, eps=self.eps, weight_decay=self.weight_decay, momentum=self.momentum, centered=self.centered, ) ================================================ FILE: tianshou/highlevel/persistence.py ================================================ import logging from abc 
import ABC, abstractmethod from collections.abc import Callable from enum import Enum from pathlib import Path from typing import TYPE_CHECKING import torch from tianshou.highlevel.world import World if TYPE_CHECKING: from tianshou.highlevel.module.core import TDevice log = logging.getLogger(__name__) class PersistEvent(Enum): """Enumeration of persistence events that Persistence objects can react to.""" PERSIST_POLICY = "persist_policy" """Policy neural network is persisted (new best found)""" class RestoreEvent(Enum): """Enumeration of restoration events that Persistence objects can react to.""" RESTORE_POLICY = "restore_policy" """Policy neural network parameters are restored""" class Persistence(ABC): @abstractmethod def persist(self, event: PersistEvent, world: World) -> None: pass @abstractmethod def restore(self, event: RestoreEvent, world: World) -> None: pass class PersistenceGroup(Persistence): """Groups persistence handler such that they can be applied collectively.""" def __init__(self, *p: Persistence, enabled: bool = True): self.items = p self.enabled = enabled def persist(self, event: PersistEvent, world: World) -> None: if not self.enabled: return for item in self.items: item.persist(event, world) def restore(self, event: RestoreEvent, world: World) -> None: for item in self.items: item.restore(event, world) class PolicyPersistence: class Mode(Enum): """Mode of persistence.""" POLICY_STATE_DICT = "policy_state_dict" """Persist only the policy's state dictionary. Note that for a policy to be restored from such a dictionary, it is necessary to first create a structurally equivalent object which can accept the respective state.""" POLICY = "policy" """Persist the entire policy. This is larger but has the advantage of the policy being loadable without requiring an environment to be instantiated. It has the potential disadvantage that upon breaking code changes in the policy implementation (e.g. renamed/moved class), it will no longer be loadable. 
Note that a precondition is that the policy be picklable in its entirety. """ def get_filename(self) -> str: return self.value + ".pt" def __init__( self, additional_persistence: Persistence | None = None, enabled: bool = True, mode: Mode = Mode.POLICY, ): """Handles persistence of the policy. :param additional_persistence: a persistence instance which is to be invoked whenever this object is used to persist/restore data :param enabled: whether persistence is enabled (restoration is always enabled) :param mode: the persistence mode """ self.additional_persistence = additional_persistence self.enabled = enabled self.mode = mode def persist(self, policy: torch.nn.Module, world: World) -> None: if not self.enabled: return path = world.persist_path(self.mode.get_filename()) match self.mode: case self.Mode.POLICY_STATE_DICT: log.info(f"Saving policy state dictionary in {path}") torch.save(policy.state_dict(), path) case self.Mode.POLICY: log.info(f"Saving policy object in {path}") torch.save(policy, path) case _: raise NotImplementedError if self.additional_persistence is not None: self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world) def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None: path = world.restore_path(self.mode.get_filename()) log.info(f"Restoring policy from {path}") match self.mode: case self.Mode.POLICY_STATE_DICT: state_dict = torch.load(path, map_location=device) case self.Mode.POLICY: loaded_policy: torch.nn.Module = torch.load(path, map_location=device) state_dict = loaded_policy.state_dict() case _: raise NotImplementedError policy.load_state_dict(state_dict) if self.additional_persistence is not None: self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world) def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]: def save_best_fn(pol: torch.nn.Module) -> None: self.persist(pol, world) return save_best_fn def get_save_checkpoint_fn(self, world: World) -> 
Callable[[int, int, int], str] | None: if not self.enabled: return None def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: path = Path(self.mode.get_filename()) path_with_epoch = path.with_stem(f"{path.stem}_epoch_{epoch}") path = world.persist_path(path_with_epoch.name) match self.mode: case self.Mode.POLICY_STATE_DICT: log.info(f"Saving policy state dictionary in {path}") torch.save(world.algorithm.state_dict(), path) case self.Mode.POLICY: log.info(f"Saving policy object in {path}") torch.save(world.algorithm, path) case _: raise NotImplementedError if self.additional_persistence is not None: self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world) return path return save_checkpoint_fn ================================================ FILE: tianshou/highlevel/trainer.py ================================================ import logging from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass from typing import TypeVar, cast from sensai.util.string import ToStringMixin from tianshou.algorithm import DQN, Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) log = logging.getLogger(__name__) class TrainingContext: def __init__(self, algorithm: TAlgorithm, envs: Environments, logger: TLogger): self.algorithm = algorithm self.envs = envs self.logger = logger class EpochTrainCallback(ToStringMixin, ABC): """Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase of each epoch. 
""" @abstractmethod def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: pass def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]: def fn(epoch: int, env_step: int) -> None: return self.callback(epoch, env_step, context) return fn class EpochTestCallback(ToStringMixin, ABC): """Callback which is called at the beginning of the test phase of each epoch.""" @abstractmethod def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: pass def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]: def fn(epoch: int, env_step: int | None) -> None: return self.callback(epoch, env_step, context) return fn class EpochStopCallback(ToStringMixin, ABC): """Callback which is called after the test phase of each epoch in order to determine whether training should stop early. """ @abstractmethod def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: """Determines whether training should stop. :param mean_rewards: the average undiscounted returns of the testing result :param context: the training context :return: True if the goal has been reached and training should stop, False otherwise """ def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]: def fn(mean_rewards: float) -> bool: return self.should_stop(mean_rewards, context) return fn @dataclass class TrainerCallbacks: """Container for callbacks used during training.""" epoch_train_callback: EpochTrainCallback | None = None epoch_test_callback: EpochTestCallback | None = None epoch_stop_callback: EpochStopCallback | None = None class EpochTrainCallbackDQNSetEps(EpochTrainCallback): """Sets the epsilon value for DQN-based policies at the beginning of the training stage in each epoch. 
""" def __init__(self, eps: float): self.eps = eps def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: algorithm = cast(DQN, context.algorithm) policy: DiscreteQLearningPolicy = algorithm.policy policy.set_eps_training(self.eps) class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback): """Sets the epsilon value for DQN-based policies at the beginning of the training stage in each epoch, using a linear decay in the first `decay_steps` steps. """ def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000): self.eps_train = eps_train self.eps_train_final = eps_train_final self.decay_steps = decay_steps def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: algorithm = cast(DQN, context.algorithm) policy: DiscreteQLearningPolicy = algorithm.policy logger = context.logger if env_step <= self.decay_steps: eps = self.eps_train - env_step / self.decay_steps * ( self.eps_train - self.eps_train_final ) else: eps = self.eps_train_final policy.set_eps_training(eps) logger.write("train/env_step", env_step, {"train/eps": eps}) class EpochTestCallbackDQNSetEps(EpochTestCallback): """Sets the epsilon value for DQN-based policies at the beginning of the test stage in each epoch. """ def __init__(self, eps: float): self.eps = eps def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: algorithm = cast(DQN, context.algorithm) policy: DiscreteQLearningPolicy = algorithm.policy policy.set_eps_inference(self.eps) class EpochStopCallbackRewardThreshold(EpochStopCallback): """Stops training once the mean rewards exceed the given reward threshold or the threshold that is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`). """ def __init__(self, threshold: float | None = None): """ :param threshold: the reward threshold beyond which to stop training. If it is None, will use threshold specified by the environment, i.e. `env.spec.reward_threshold`. 
""" self.threshold = threshold def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: threshold = self.threshold if threshold is None: threshold = context.envs.env.spec.reward_threshold # type: ignore assert threshold is not None is_reached = mean_rewards >= threshold if is_reached: log.info(f"Reward threshold ({threshold}) exceeded") return is_reached ================================================ FILE: tianshou/highlevel/world.py ================================================ import os from dataclasses import dataclass from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from tianshou.algorithm import Algorithm from tianshou.data import BaseCollector from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger from tianshou.trainer import Trainer @dataclass(kw_only=True) class World: """Container for instances and configuration items that are relevant to an experiment.""" envs: "Environments" algorithm: "Algorithm" training_collector: Optional["BaseCollector"] = None test_collector: Optional["BaseCollector"] = None logger: "TLogger" persist_directory: str restore_directory: str | None trainer: Optional["Trainer"] = None def persist_path(self, filename: str) -> str: return os.path.abspath(os.path.join(self.persist_directory, filename)) def restore_path(self, filename: str) -> str: if self.restore_directory is None: raise ValueError( "Path cannot be formed because no directory for restoration was provided", ) return os.path.join(self.restore_directory, filename) ================================================ FILE: tianshou/py.typed ================================================ ================================================ FILE: tianshou/trainer.py ================================================ """ This module contains Tianshou's trainer classes, which orchestrate the training and call upon an RL algorithm's specific network updating logic to perform the actual gradient updates. 
Training is structured as follows (hierarchical glossary): - **epoch**: the outermost iteration level of the training loop. Each epoch consists of a number of training steps and one test step (see :attr:`TrainerParams.max_epochs` for a detailed explanation). - **training step**: a training step performs the steps necessary in order to apply a single update of the neural network components as defined by the underlying RL algorithm (:class:`Algorithm`). This involves the following sub-steps: - for online learning algorithms: - **collection step**: collecting environment steps/transitions to be used for training. - (Potentially) a test step (see below) if the early stopping criterion is satisfied based on the data collected (see :attr:`OnlineTrainerParams.test_in_training`). - **update step**: applying the actual gradient updates using the RL algorithm. The update is based on either: - data from only the preceding collection step (on-policy learning), - data from the collection step and previously collected data (off-policy learning), or - data from the user-provided replay buffer (offline learning). For offline learning algorithms, a training step is thus equivalent to an update step. - **test step**: collects test episodes from dedicated test environments which are used to evaluate the performance of the policy. Optionally, the performance result can be used to determine whether training shall stop early (see :attr:`TrainerParams.stop_fn`).
""" import logging import time from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Callable from dataclasses import asdict, dataclass from functools import partial from typing import Generic, TypeVar import numpy as np import torch import tqdm from sensai.util.helper import count_none from sensai.util.pickle import setstate from sensai.util.string import ToStringMixin from tianshou.algorithm.algorithm_base import ( Algorithm, OfflineAlgorithm, OffPolicyAlgorithm, OnPolicyAlgorithm, TrainingStats, ) from tianshou.data import ( AsyncCollector, CollectStats, EpochStats, InfoStats, ReplayBuffer, SequenceSummaryStats, TimingStats, ) from tianshou.data.buffer.buffer_base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase from tianshou.utils import ( BaseLogger, LazyLogger, MovAvg, ) from tianshou.utils.determinism import TraceLogger, torch_param_hash from tianshou.utils.logging import set_numerical_fields_to_precision from tianshou.utils.torch_utils import policy_within_training_step log = logging.getLogger(__name__) @dataclass(kw_only=True) class TrainerParams(ToStringMixin): max_epochs: int = 100 """ the (maximum) number of epochs to run training for. An **epoch** is the outermost iteration level and each epoch consists of a number of training steps and one test step, where each training step * [for the online case] collects environment steps/transitions (**collection step**), adding them to the (replay) buffer (see :attr:`collection_step_num_env_steps` and :attr:`collection_step_num_episodes`) * performs an **update step** via the RL algorithm being used, which can involve one or more actual gradient updates, depending on the algorithm and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate agent performance. Training may be stopped early if the stop criterion is met (see :attr:`stop_fn`). 
For online training, the number of training steps in each epoch is indirectly determined by :attr:`epoch_num_steps`: As many training steps will be performed as are required in order to reach :attr:`epoch_num_steps` total steps in the training environments. Specifically, if the number of transitions collected per step is `c` (see :attr:`collection_step_num_env_steps`) and :attr:`epoch_num_steps` is set to `s`, then the number of training steps per epoch is `ceil(s / c)`. Therefore, if `max_epochs = e`, the total number of environment steps taken during training can be computed as `e * ceil(s / c) * c`. For offline training, the number of training steps per epoch is equal to :attr:`epoch_num_steps`. """ epoch_num_steps: int = 30000 """ For an online algorithm, this is the total number of environment steps to be collected per epoch, and, for an offline algorithm, it is the total number of training steps to take per epoch. See :attr:`max_epochs` for an explanation of epoch semantics. """ test_collector: BaseCollector | None = None """ the collector to use for test episode collection (test steps); if None, perform no test steps. """ test_step_num_episodes: int = 1 """the number of episodes to collect in each test step. """ training_fn: Callable[[int, int], None] | None = None """ a callback function which is called at the beginning of each training step. It can be used to perform custom additional operations, with the signature ``f(num_epoch: int, step_idx: int) -> None``. """ test_fn: Callable[[int, int | None], None] | None = None """ a callback function to be called at the beginning of each test step. It can be used to perform custom additional operations, with the signature ``f(num_epoch: int, step_idx: int) -> None``. """ stop_fn: Callable[[float], bool] | None = None """ a callback function with signature ``f(score: float) -> bool``, which is used to decide whether training shall be stopped early based on the score achieved in a test step. 
The score it receives is computed by the :attr:`compute_score_fn` callback (which defaults to the mean reward if the function is not provided). Requires test steps to be activated and thus :attr:`test_collector` to be set. Note: The function is also used when :attr:`test_in_train` is activated (see docstring). """ compute_score_fn: Callable[[CollectStats], float] | None = None """ the callback function to use in order to compute the test batch performance score, which is used to determine what the best model is (score is maximized); if None, use the mean reward. """ save_best_fn: Callable[["Algorithm"], None] | None = None """ the callback function to call in order to save the best model whenever a new best score (see :attr:`compute_score_fn`) is achieved in a test step. It should have the signature ``f(algorithm: Algorithm) -> None``. """ save_checkpoint_fn: Callable[[int, int, int], str] | None = None """ the callback function with which to save checkpoint data after each training step, which can save whatever data is desired to a file and returns the path of the file. Signature: ``f(epoch: int, env_step: int, gradient_step: int) -> str``. """ resume_from_log: bool = False """ whether to load env_step/gradient_step and other metadata from the existing log, which is given in :attr:`logger`. """ multi_agent_return_reduction: Callable[[np.ndarray], np.ndarray] | None = None """ a function with signature ``f(returns: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, which is used in multi-agent RL. We need to return a single scalar for each episode's return to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, e.g., the return achieved by agent 1 or the average return over all agents. """ logger: BaseLogger | None = None """ the logger with which to log statistics during training/testing/updating. To not log anything, use None. 
Relevant step types for logger update intervals: * `update_interval`: update step * `training_interval`: env step * `test_interval`: env step """ verbose: bool = True """ whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the `logging` Python module). """ show_progress: bool = True """ whether to display a progress bars during training. """ def __setstate__(self, state: dict) -> None: setstate(TrainerParams, self, state, renamed_properties={"train_fn": "training_fn"}) def __post_init__(self) -> None: if self.resume_from_log and self.logger is None: raise ValueError("Cannot resume from log without a logger being provided") if self.test_collector is None: if self.stop_fn is not None: raise ValueError( "stop_fn cannot be activated without test steps being enabled (test_collector being set)" ) if self.test_fn is not None: raise ValueError( "test_fn is set while test steps are disabled (test_collector is None)" ) if self.save_best_fn is not None: raise ValueError( "save_best_fn is set while test steps are disabled (test_collector is None)" ) else: if self.test_step_num_episodes < 1: raise ValueError( "test_step_num_episodes must be positive if test steps are enabled " "(test_collector not None)" ) @dataclass(kw_only=True) class OnlineTrainerParams(TrainerParams): training_collector: BaseCollector """ the collector with which to gather new data for training in each training step """ collection_step_num_env_steps: int | None = 2048 """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same (non-zero) number of transitions. 
Specifically, if this is set to `n` and `m` training environments are used, then the total number of transitions collected per collection step is `ceil(n / m) * m =: c`. See :attr:`max_epochs` for information on the total number of environment steps being collected during training. """ collection_step_num_episodes: int | None = None """ the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected. This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. """ test_in_training: bool = False """ Whether to apply a test step within a training step depending on the early stopping criterion (given by :attr:`stop_fn`) being satisfied based on the data collected within the training step. Specifically, after each collect step, we check whether the early stopping criterion (:attr:`stop_fn`) would be satisfied by data we collected (provided that at least one episode was indeed completed, such that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step (collecting :attr:`test_step_num_episodes` episodes in order to evaluate performance), and if the early stopping criterion is also satisfied based on the test data, we stop training early. 
""" def __setstate__(self, state: dict) -> None: setstate( OnlineTrainerParams, self, state, renamed_properties={ "test_in_train": "test_in_training", "training_collector": "training_collector", }, ) def __post_init__(self) -> None: super().__post_init__() if count_none(self.collection_step_num_env_steps, self.collection_step_num_episodes) != 1: raise ValueError( "Exactly one of {collection_step_num_env_steps, collection_step_num_episodes} must be set" ) if self.test_in_training and (self.test_collector is None or self.stop_fn is None): raise ValueError("test_in_training requires test_collector and stop_fn to be set") @dataclass(kw_only=True) class OnPolicyTrainerParams(OnlineTrainerParams): batch_size: int | None = 64 """ Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, a form of regularization). Set ``batch_size=None`` for the full buffer that was collected within the training step to be used for the gradient update (no mini-batching). """ update_step_num_repetitions: int = 1 """ controls, within one update step of an on-policy algorithm, the number of times the full collected data is applied for gradient updates, i.e. if the parameter is 5, then the collected data shall be used five times to update the policy within the same update step. """ @dataclass(kw_only=True) class OffPolicyTrainerParams(OnlineTrainerParams): batch_size: int = 64 """ the the number of environment steps/transitions to sample from the buffer for a gradient update. """ update_step_num_gradient_steps_per_sample: float = 1.0 """ the number of gradient steps to perform per sample collected (see :attr:`collection_step_num_env_steps`). Specifically, if this is set to `u` and the number of samples collected in the preceding collection step is `n`, then `round(u * n)` gradient steps will be performed. 
""" @dataclass(kw_only=True) class OfflineTrainerParams(TrainerParams): buffer: ReplayBuffer """ the replay buffer with environment steps to use as training data for offline learning. This buffer will be pre-processed using the RL algorithm's pre-processing function (if any) before training. """ batch_size: int = 64 """ the number of environment steps/transitions to sample from the buffer for a gradient update. """ TTrainerParams = TypeVar("TTrainerParams", bound=TrainerParams) TOnlineTrainerParams = TypeVar("TOnlineTrainerParams", bound=OnlineTrainerParams) TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) class Trainer(Generic[TAlgorithm, TTrainerParams], ABC): """ Base class for trainers in Tianshou, which orchestrate the training process and call upon an RL algorithm's specific network updating logic to perform the actual gradient updates. The base class already implements the fundamental epoch logic and fully implements the test step logic, which is common to all trainers. The training step logic is left to be implemented by subclasses. """ def __init__( self, algorithm: TAlgorithm, params: TTrainerParams, ): self.algorithm = algorithm self.params = params self._logger = params.logger or LazyLogger() self._start_time = time.time() self._stat: defaultdict[str, MovAvg] = defaultdict(MovAvg) self._start_epoch = 0 self._epoch = self._start_epoch # initialize stats on the best model found during a test step # NOTE: The values don't matter, as in the first test step (which is taken in reset() # at the beginning of the training process), these will all be updated self._best_score = 0.0 self._best_reward = 0.0 self._best_reward_std = 0.0 self._best_epoch = self._start_epoch self._current_update_step = 0 """ the current (1-based) update step/training step number (to be incremented before the actual step is taken) """ self._env_step = 0 """ the step counter which is used to track progress of the training process. For online learning (i.e. 
on-policy and off-policy learning), this is the total number of environment steps collected, and for offline training, it is the total number of environment steps that have been sampled from the replay buffer to perform gradient updates. """ self._policy_update_time = 0.0 self._compute_score_fn: Callable[[CollectStats], float] = ( params.compute_score_fn or self._compute_score_fn_default ) self._stop_fn_flag = False @staticmethod def _compute_score_fn_default(stat: CollectStats) -> float: """ The default score function, which returns the mean return/reward. :param stat: the collection stats :return: the mean return """ assert stat.returns_stat is not None # for mypy return stat.returns_stat.mean @property def _pbar(self) -> Callable[..., tqdm.tqdm]: """Use as context manager or iterator, i.e., `with self._pbar(...) as t:` or `for _ in self._pbar(...):`.""" return partial( tqdm.tqdm, dynamic_ncols=True, ascii=True, disable=not self.params.show_progress, ) def _reset_collectors(self, reset_buffer: bool = False) -> None: if self.params.test_collector is not None: self.params.test_collector.reset(reset_buffer=reset_buffer) def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None: """Initializes the training process. :param reset_collectors: whether to reset the collectors prior to starting the training process. Specifically, this will reset the environments in the collectors (starting new episodes), and the statistics stored in the collector. Whether the contained buffers will be reset/cleared is determined by the `reset_buffer` parameter. :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the contained buffers as well. This has no effect if `reset_collectors` is False. 
""" TraceLogger.log(log, lambda: "Trainer reset") self._env_step = 0 self._current_update_step = 0 if self.params.resume_from_log: ( self._start_epoch, self._env_step, self._current_update_step, ) = self._logger.restore_data() self._epoch = self._start_epoch self._start_time = time.time() if reset_collectors: self._reset_collectors(reset_buffer=reset_collector_buffers) # make an initial test step to determine the initial best model if self.params.test_collector is not None: assert self.params.test_step_num_episodes is not None assert not isinstance(self.params.test_collector, AsyncCollector) # Issue 700 self._test_step(force_update_best=True, log_msg_prefix="Initial test step") self._stop_fn_flag = False self._log_params(self.algorithm) def _log_params(self, module: torch.nn.Module) -> None: """Logs the parameters of the module to the trace logger by subcomponent (if the trace logger is enabled).""" if not TraceLogger.is_enabled: return def module_has_params(m: torch.nn.Module) -> bool: return any(p.requires_grad for p in m.parameters()) relevant_modules = {} def gather_modules(m: torch.nn.Module) -> None: for name, submodule in m.named_children(): if name == "policy": gather_modules(submodule) else: if module_has_params(submodule): relevant_modules[name] = submodule gather_modules(module) for name, module in sorted(relevant_modules.items()): TraceLogger.log( log, lambda: f"Params[{name}]: {torch_param_hash(module)}", ) class _TrainingStepResult(ABC): @abstractmethod def get_steps_in_epoch_advancement(self) -> int: """ :return: the number of steps that were done within the epoch, where the concrete semantics of what a step is depend on the type of algorithm. See docstring of `TrainerParams.epoch_num_steps`. 
""" @abstractmethod def get_collect_stats(self) -> CollectStats | None: pass @abstractmethod def get_training_stats(self) -> TrainingStats | None: pass @abstractmethod def is_training_done(self) -> bool: """:return: whether the early stopping criterion is satisfied and training shall stop.""" @abstractmethod def get_env_step_advancement(self) -> int: """ :return: the number of steps by which to advance the env_step counter in the trainer (see docstring of trainer attribute). The semantics depend on the type of the algorithm. """ @abstractmethod def _create_epoch_pbar_data_dict( self, training_step_result: _TrainingStepResult ) -> dict[str, str]: pass def _create_info_stats( self, ) -> InfoStats: test_collector = self.params.test_collector if isinstance(self.params, OnlineTrainerParams): training_collector = self.params.training_collector else: training_collector = None duration = max(0.0, time.time() - self._start_time) test_time = 0.0 update_speed = 0.0 train_time_collect = 0.0 if test_collector is not None: test_time = test_collector.collect_time if training_collector is not None: train_time_collect = training_collector.collect_time update_speed = training_collector.collect_step / (duration - test_time) timing_stat = TimingStats( total_time=duration, train_time=duration - test_time, train_time_collect=train_time_collect, train_time_update=self._policy_update_time, test_time=test_time, update_speed=update_speed, ) return InfoStats( update_step=self._current_update_step, best_score=self._best_score, best_reward=self._best_reward, best_reward_std=self._best_reward_std, train_step=training_collector.collect_step if training_collector is not None else 0, train_episode=training_collector.collect_episode if training_collector is not None else 0, test_step=test_collector.collect_step if test_collector is not None else 0, test_episode=test_collector.collect_episode if test_collector is not None else 0, timing=timing_stat, ) def execute_epoch(self) -> EpochStats: 
self._epoch += 1 TraceLogger.log(log, lambda: f"Epoch #{self._epoch} start") # perform the required number of steps for the epoch (`epoch_num_steps`) steps_done_in_this_epoch = 0 train_collect_stats, training_stats = None, None with self._pbar( total=self.params.epoch_num_steps, desc=f"Epoch #{self._epoch}", position=1 ) as t: while steps_done_in_this_epoch < self.params.epoch_num_steps and not self._stop_fn_flag: # perform a training step and update progress TraceLogger.log(log, lambda: "Training step") self._current_update_step += 1 training_step_result = self._training_step() steps_done_in_this_epoch += training_step_result.get_steps_in_epoch_advancement() t.update(training_step_result.get_steps_in_epoch_advancement()) self._stop_fn_flag = training_step_result.is_training_done() self._env_step += training_step_result.get_env_step_advancement() training_stats = training_step_result.get_training_stats() TraceLogger.log( log, lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict() if training_stats is not None else None}", ) self._log_params(self.algorithm) collect_stats = training_step_result.get_collect_stats() if collect_stats is not None: self._logger.log_training_data(asdict(collect_stats), self._env_step) pbar_data_dict = self._create_epoch_pbar_data_dict(training_step_result) pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) pbar_data_dict["update_step"] = str(self._current_update_step) t.set_postfix(**pbar_data_dict) test_collect_stats = None if not self._stop_fn_flag: self._logger.save_data( self._epoch, self._env_step, self._current_update_step, self.params.save_checkpoint_fn, ) # test step if self.params.test_collector is not None: test_collect_stats, self._stop_fn_flag = self._test_step() info_stats = self._create_info_stats() self._logger.log_info_data(asdict(info_stats), self._epoch) return EpochStats( epoch=self._epoch, train_collect_stat=train_collect_stats, test_collect_stat=test_collect_stats, 
training_stat=training_stats, info_stat=info_stats, ) def _should_stop_training_early( self, *, score: float | None = None, collect_stats: CollectStats | None = None ) -> bool: """ Determine whether, given the early stopping criterion stop_fn, training shall be stopped early based on the score achieved or the collection stats (from which the score could be computed). """ # If no stop criterion is defined, we can never stop training early if self.params.stop_fn is None: return False if score is None: if collect_stats is None: raise ValueError("Must provide collect_stats if score is not given") # If no episodes were collected, we have no episode returns and thus cannot compute a score if collect_stats.n_collected_episodes == 0: return False score = self._compute_score_fn(collect_stats) return self.params.stop_fn(score) def _collect_test_episodes( self, ) -> CollectStats: assert self.params.test_collector is not None collector = self.params.test_collector collector.reset(reset_stats=False) if self.params.test_fn: self.params.test_fn(self._epoch, self._env_step) result = collector.collect(n_episode=self.params.test_step_num_episodes) if self.params.multi_agent_return_reduction: rew = self.params.multi_agent_return_reduction(result.returns) result.returns = rew result.returns_stat = SequenceSummaryStats.from_sequence(rew) if self._logger and self._env_step is not None: assert result.n_collected_episodes > 0 self._logger.log_test_data(asdict(result), self._env_step) return result def _test_step( self, force_update_best: bool = False, log_msg_prefix: str | None = None ) -> tuple[CollectStats, bool]: """Performs one test step. :param log_msg_prefix: a prefix to prepend to the log message, which is to establish the context within which the test step is being carried out :param force_update_best: whether to force updating of the best model stats (best score, reward, etc.) 
and call the `save_best_fn` callback """ assert self.params.test_step_num_episodes is not None assert self.params.test_collector is not None # collect test episodes test_stat = self._collect_test_episodes() assert test_stat.returns_stat is not None # for mypy # check whether we have a new best score and, if so, update stats and save the model # (or if forced) rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std score = self._compute_score_fn(test_stat) if score > self._best_score or force_update_best: self._best_score = score self._best_epoch = self._epoch self._best_reward = float(rew) self._best_reward_std = rew_std if self.params.save_best_fn: self.params.save_best_fn(self.algorithm) # log results cur_info, best_info = "", "" if score != rew: cur_info, best_info = ( f", score: {score: .6f}", f", best_score: {self._best_score:.6f}", ) if log_msg_prefix is None: log_msg_prefix = f"Epoch #{self._epoch}" log_msg = ( f"{log_msg_prefix}: test_reward: {rew:.6f} ± {rew_std:.6f},{cur_info}" f" best_reward: {self._best_reward:.6f} ± " f"{self._best_reward_std:.6f}{best_info} in #{self._best_epoch}" ) log.info(log_msg) if self.params.verbose: print(log_msg, flush=True) # determine whether training shall be stopped early stop_fn_flag = self._should_stop_training_early(score=self._best_score) return test_stat, stop_fn_flag @abstractmethod def _training_step(self) -> _TrainingStepResult: """Performs one training step.""" def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None: """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat.""" cur_losses_dict = update_stat.get_loss_stats_dict() update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data( cur_losses_dict, ) self._logger.log_update_data(asdict(update_stat), self._current_update_step) # TODO: seems convoluted, there should be a better way of dealing with the moving average stats def 
_update_moving_avg_stats_and_get_averaged_data( self, data: dict[str, float], ) -> dict[str, float]: """Add entries to the moving average object in the trainer and retrieve the averaged results. :param data: any entries to be tracked in the moving average object. :return: A dictionary containing the averaged values of the tracked entries. """ smoothed_data = {} for key, loss_item in data.items(): self._stat[key].add(loss_item) smoothed_data[key] = self._stat[key].get() return smoothed_data def run( self, reset_collectors: bool = True, reset_collector_buffers: bool = False ) -> InfoStats: """Runs the training process with the configuration given at construction. :param reset_collectors: whether to reset the collectors prior to starting the training process. Specifically, this will reset the environments in the collectors (starting new episodes), and the statistics stored in the collector. Whether the contained buffers will be reset/cleared is determined by the `reset_buffer` parameter. :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the contained buffers as well. This has no effect if `reset_collectors` is False. """ self.reset( reset_collectors=reset_collectors, reset_collector_buffers=reset_collector_buffers, ) while self._epoch < self.params.max_epochs and not self._stop_fn_flag: self.execute_epoch() return self._create_info_stats() class OfflineTrainer(Trainer[OfflineAlgorithm, OfflineTrainerParams]): """An offline trainer, which samples mini-batches from a given buffer and passes them to the algorithm's update function. 
""" def __init__( self, algorithm: OfflineAlgorithm, params: OfflineTrainerParams, ): super().__init__(algorithm, params) self._buffer = algorithm.process_buffer(self.params.buffer) class _TrainingStepResult(Trainer._TrainingStepResult): def __init__(self, training_stats: TrainingStats, env_step_advancement: int): self._training_stats = training_stats self._env_step_advancement = env_step_advancement def get_steps_in_epoch_advancement(self) -> int: return 1 def get_collect_stats(self) -> None: return None def get_training_stats(self) -> TrainingStats: return self._training_stats def is_training_done(self) -> bool: return False def get_env_step_advancement(self) -> int: return self._env_step_advancement def _training_step(self) -> _TrainingStepResult: with policy_within_training_step(self.algorithm.policy): # Note: since sample_size=batch_size, this will perform # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. training_stats = self.algorithm.update( sample_size=self.params.batch_size, buffer=self._buffer ) self._update_moving_avg_stats_and_log_update_data(training_stats) self._policy_update_time += training_stats.train_time return self._TrainingStepResult( training_stats=training_stats, env_step_advancement=self.params.batch_size, ) def _create_epoch_pbar_data_dict( self, training_step_result: Trainer._TrainingStepResult ) -> dict[str, str]: return {} class OnlineTrainer( Trainer[TAlgorithm, TOnlineTrainerParams], Generic[TAlgorithm, TOnlineTrainerParams], ABC, ): """ An online trainer, which collects data from the environment in each training step and uses the collected data to perform an update step, the nature of which is to be defined in subclasses. 
""" def __init__( self, algorithm: TAlgorithm, params: TOnlineTrainerParams, ): super().__init__(algorithm, params) self._env_episode = 0 """ the total number of episodes collected in the environment """ def _reset_collectors(self, reset_buffer: bool = False) -> None: super()._reset_collectors(reset_buffer=reset_buffer) self.params.training_collector.reset(reset_buffer=reset_buffer) def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None: super().reset( reset_collectors=reset_collectors, reset_collector_buffers=reset_collector_buffers, ) if ( self.params.test_in_training and self.params.training_collector.policy is not self.algorithm.policy ): log.warning( "The training data collector's policy is not the same as the one being trained, " "yet test_in_training is enabled. This may lead to unexpected results." ) self._env_episode = 0 class _TrainingStepResult(Trainer._TrainingStepResult): def __init__( self, collect_stats: CollectStats, training_stats: TrainingStats | None, is_training_done: bool, ): self._collect_stats = collect_stats self._training_stats = training_stats self._is_training_done = is_training_done def get_steps_in_epoch_advancement(self) -> int: return self.get_env_step_advancement() def get_collect_stats(self) -> CollectStats: return self._collect_stats def get_training_stats(self) -> TrainingStats | None: return self._training_stats def is_training_done(self) -> bool: return self._is_training_done def get_env_step_advancement(self) -> int: return self._collect_stats.n_collected_steps def _training_step(self) -> _TrainingStepResult: """Perform one training step. 
For an online algorithm, a training step involves: * collecting data * for the case where `test_in_train` is activated, determining whether the stop condition has been reached (and returning without performing any actual training if so) * performing a gradient update step """ with policy_within_training_step(self.algorithm.policy): # collect data collect_stats = self._collect_training_data() # determine whether we should stop training based on the data collected should_stop_training = False if self.params.test_in_training: should_stop_training = self._test_in_train(collect_stats) # perform gradient update step (if not already done) training_stats: TrainingStats | None = None if not should_stop_training: training_stats = self._update_step(collect_stats) return self._TrainingStepResult( collect_stats=collect_stats, training_stats=training_stats, is_training_done=should_stop_training, ) def _collect_training_data(self) -> CollectStats: """Performs training data collection. :return: the data collection stats """ assert self.params.test_step_num_episodes is not None assert self.params.training_collector is not None if self.params.training_fn: self.params.training_fn(self._epoch, self._env_step) collect_stats = self.params.training_collector.collect( n_step=self.params.collection_step_num_env_steps, n_episode=self.params.collection_step_num_episodes, ) TraceLogger.log( log, lambda: f"Collected {collect_stats.n_collected_steps} steps, {collect_stats.n_collected_episodes} episodes", ) if self.params.training_collector.buffer.hasnull(): from tianshou.data.collector import EpisodeRolloutHook from tianshou.env import DummyVectorEnv raise MalformedBufferError( f"Encountered NaNs in buffer after {self._env_step} steps." f"Such errors are usually caused by either a bug in the environment or by " f"problematic implementations {EpisodeRolloutHook.__class__.__name__}. 
" f"For debugging such issues it is recommended to run the training in a single process, " f"e.g., by using {DummyVectorEnv.__class__.__name__}.", ) if collect_stats.n_collected_episodes > 0: assert collect_stats.returns_stat is not None # for mypy assert collect_stats.lens_stat is not None # for mypy if self.params.multi_agent_return_reduction: rew = self.params.multi_agent_return_reduction(collect_stats.returns) collect_stats.returns = rew collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) # update collection stats specific to this specialization self._env_episode += collect_stats.n_collected_episodes return collect_stats def _test_in_train( self, train_collect_stats: CollectStats, ) -> bool: """ Performs a test step if the data collected in the current training step suggests that performance is good enough to stop training early. If the test step confirms that performance is indeed good enough, returns True, and False otherwise. Specifically, applies the early stopping criterion to the data collected in the current training step, and if the criterion is satisfied, performs a test step which returns the relevant result. 
:param train_collect_stats: the data collection stats from the preceding collection step :return: flag indicating whether to stop training early """ should_stop_training = False # check whether the stop criterion is satisfied based on the data collected in the training step # (if any full episodes were indeed collected) if train_collect_stats.n_collected_episodes > 0 and self._should_stop_training_early( collect_stats=train_collect_stats ): # apply a test step, temporarily switching out of "is_training_step" semantics such that the policy can # be evaluated, in order to determine whether we should stop training with policy_within_training_step(self.algorithm.policy, enabled=False): _, should_stop_training = self._test_step( log_msg_prefix=f"Test step triggered by train stats (env_step={self._env_step})" ) return should_stop_training @abstractmethod def _update_step( self, collect_stats: CollectStatsBase, ) -> TrainingStats: """Performs a gradient update step, calling the algorithm's update method accordingly. :param collect_stats: provides info about the preceding data collection step. 
""" def _create_epoch_pbar_data_dict( self, training_step_result: Trainer._TrainingStepResult ) -> dict[str, str]: collect_stats = training_step_result.get_collect_stats() assert collect_stats is not None result = { "env_step": str(self._env_step), "env_episode": str(self._env_episode), "n_ep": str(collect_stats.n_collected_episodes), "n_st": str(collect_stats.n_collected_steps), } # return and episode length info is only available if at least one episode was completed if collect_stats.n_collected_episodes > 0: assert collect_stats.returns_stat is not None assert collect_stats.lens_stat is not None result.update( { "rew": f"{collect_stats.returns_stat.mean:.2f}", "len": str(int(collect_stats.lens_stat.mean)), } ) return result class OffPolicyTrainer(OnlineTrainer[OffPolicyAlgorithm, OffPolicyTrainerParams]): """An off-policy trainer, which samples mini-batches from the buffer of collected data and passes them to algorithm's `update` function. The algorithm's `update` method is expected to not perform additional mini-batching but just update model parameters from the received mini-batch. """ def _update_step( self, collect_stats: CollectStatsBase, ) -> TrainingStats: """Perform `update_step_num_gradient_steps_per_sample * n_collected_steps` gradient steps by sampling mini-batches from the buffer. :param collect_stats: the :class:`~TrainingStats` instance returned by the last gradient step. Some values in it will be replaced by their moving averages. 
""" assert self.params.training_collector is not None n_collected_steps = collect_stats.n_collected_steps n_gradient_steps = round( self.params.update_step_num_gradient_steps_per_sample * n_collected_steps ) if n_gradient_steps == 0: raise ValueError( f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " f"update_step_num_gradient_steps_per_sample={self.params.update_step_num_gradient_steps_per_sample}", ) update_stat = None disable_pbar = n_gradient_steps < 20 # only show progress bar if there are many steps for _ in self._pbar( range(n_gradient_steps), desc="Offpolicy gradient update", position=0, leave=False, disable=disable_pbar, ): update_stat = self._sample_and_update(self.params.training_collector.buffer) self._policy_update_time += update_stat.train_time # TODO: only the last update_stat is returned, should be improved assert update_stat is not None return update_stat def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" # Note: since sample_size=batch_size, this will perform # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. update_stat = self.algorithm.update(sample_size=self.params.batch_size, buffer=buffer) self._update_moving_avg_stats_and_log_update_data(update_stat) return update_stat class OnPolicyTrainer(OnlineTrainer[OnPolicyAlgorithm, OnPolicyTrainerParams]): """An on-policy trainer, which passes the entire buffer to the algorithm's `update` methods and resets the buffer thereafter. Note that it is expected that the update method of the algorithm will perform batching when using this trainer. 
""" def _update_step( self, collect_stats: CollectStatsBase | None = None, ) -> TrainingStats: """Perform one on-policy update by passing the entire buffer to the algorithm's update method.""" assert self.params.training_collector is not None log.debug( f"Performing on-policy update on buffer of length {len(self.params.training_collector.buffer)}", ) training_stat = self.algorithm.update( buffer=self.params.training_collector.buffer, batch_size=self.params.batch_size, repeat=self.params.update_step_num_repetitions, ) # just for logging, no functional role self._policy_update_time += training_stat.train_time # Note 2: in the policy-update we modify the buffer, which is not very clean. # currently the modification will erase previous samples but keep things like # _ep_rew and _ep_len (b/c keep_statistics=True). This is needed since the collection might have stopped # in the middle of an episode and in the next collect iteration we need these numbers to compute correct # return and episode length values. 
With the current code structure, this means that after an update and buffer reset # such quantities can no longer be computed # from samples still contained in the buffer, which is also not clean self.params.training_collector.reset_buffer(keep_statistics=True) # The step is the number of mini-batches used for the update, so essentially self._update_moving_avg_stats_and_log_update_data(training_stat) return training_stat ================================================ FILE: tianshou/utils/__init__.py ================================================ """Utils package.""" from tianshou.utils.logger.logger_base import BaseLogger, LazyLogger from tianshou.utils.logger.tensorboard import TensorboardLogger from tianshou.utils.logger.wandb import WandbLogger from tianshou.utils.progress_bar import DummyTqdm, tqdm_config from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.warning import deprecation __all__ = [ "BaseLogger", "DummyTqdm", "LazyLogger", "MovAvg", "RunningMeanStd", "TensorboardLogger", "WandbLogger", "deprecation", "tqdm_config", ] ================================================ FILE: tianshou/utils/conversion.py ================================================ from typing import overload import torch @overload def to_optional_float(x: torch.Tensor) -> float: ... @overload def to_optional_float(x: float) -> float: ... @overload def to_optional_float(x: None) -> None: ... 
def to_optional_float(x: torch.Tensor | float | None) -> float | None: """For the common case where one needs to extract a float from a scalar Tensor, which may be None.""" if isinstance(x, torch.Tensor): return x.item() return x ================================================ FILE: tianshou/utils/determinism.py ================================================ import difflib import inspect import os import re import time from collections.abc import Callable, Sequence from dataclasses import dataclass from io import StringIO from pathlib import Path from typing import Self import torch from sensai.util import logging from sensai.util.git import GitStatus, git_status from sensai.util.pickle import dump_pickle, load_pickle def format_log_message( logger: logging.Logger, level: int, msg: str, formatter: logging.Formatter, stacklevel: int = 1, ) -> str: """ Formats a log message as it would have been created by `logger.log(level, msg)` with the given formatter. :param logger: the logger :param level: the log level :param msg: the message :param formatter: the formatter :param stacklevel: the stack level of the function to report as the generator :return: the formatted log message (not including trailing newline) """ frame_info = inspect.stack()[stacklevel] pathname = frame_info.filename lineno = frame_info.lineno func = frame_info.function record = logger.makeRecord( name=logger.name, level=level, fn=pathname, lno=lineno, msg=msg, args=(), exc_info=None, func=func, extra=None, ) record.created = time.time() record.asctime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created)) return formatter.format(record) class TraceLogger: """Supports the collection of behavioural trace logs, which can, in particular, be used for determinism tests.""" is_enabled = False """ whether the trace logger is enabled. NOTE: The preferred way to enable this is via the context manager. """ verbose = False """ whether to print trace log messages to stdout. 
""" MESSAGE_TAG = "[TRACE]" """ a tag which is added at the beginning of log messages generated by this logger """ LOG_LEVEL = logging.DEBUG log_buffer: StringIO | None = None log_formatter: logging.Formatter | None = None @classmethod def log(cls, logger: logging.Logger, message_generator: Callable[[], str]) -> None: """ Logs a message intended for tracing agent-env interaction, which is enabled via `TraceAgentEnvLoggerContext`. :param logger: the logger to use for the actual logging :param message_generator: function which generates the log message (which may be expensive); if logging is disabled, the function will not be called. """ if not cls.is_enabled: return msg = message_generator() msg = cls.MESSAGE_TAG + " " + msg # Log with caller's frame info logger.log(logging.DEBUG, msg, stacklevel=2) # If a dedicated memory buffer is configured, also store the message there if cls.log_buffer is not None: msg_formatted = format_log_message( logger, logging.DEBUG, msg, cls.log_formatter, stacklevel=2, ) cls.log_buffer.write(msg_formatted + "\n") if cls.verbose: print(msg_formatted) @dataclass class TraceLog: log_lines: list[str] def save_log(self, path: str) -> None: with open(path, "w") as f: for line in self.log_lines: f.write(line + "\n") def print_log(self) -> None: for line in self.log_lines: print(line) def get_full_log(self) -> str: return "\n".join(self.log_lines) def reduce_log_to_messages(self) -> "TraceLog": """ Removes logger names and function names from the log entries, such that each log message contains only the main text message itself (starting with the content after the logger's tag). 
:return: the result with reduced log messages """ lines = [] tag = re.escape(TraceLogger.MESSAGE_TAG) for line in self.log_lines: lines.append(re.sub(r".*" + tag, "", line)) return TraceLog(lines) def filter_messages( self, required_messages: Sequence[str] = (), optional_messages: Sequence[str] = (), ignored_messages: Sequence[str] = (), ) -> "TraceLog": """ Applies inclusion and or exclusion filtering to the log messages. If either `required_messages` or `optional_messages` is empty, inclusion filtering is applied. If `ignored_messages` is empty, exclusion filtering is applied. If both inclusion and exclusion filtering are applied, the exclusion filtering takes precedence. :param required_messages: required message substrings to filter for; each message is required to appear at least once (triggering exception otherwise) :param optional_messages: additional messages fragments to filter for; these are not required :param ignored_messages: message fragments that result in exclusion; takes precedence over `required_messages` and `optional_messages` :return: the result with reduced log messages """ import numpy as np required_message_counters = np.zeros(len(required_messages)) def retain_line(line: str) -> bool: for ignored_message in ignored_messages: if ignored_message in line: return False if required_messages or optional_messages: for i, main_message in enumerate(required_messages): if main_message in line: required_message_counters[i] += 1 return True return any(add_message in line for add_message in optional_messages) else: return True lines = [] for line in self.log_lines: if retain_line(line): lines.append(line) assert np.all( required_message_counters > 0, ), "Not all types of required messages were found in the trace. Were log messages changed?" return TraceLog(lines) class TraceLoggerContext: """ A context manager which enables the trace logger. 
Apart from enabling the logging, it can optionally create a memory log buffer, such that getting the trace log is not strictly dependent on the logging system. """ def __init__( self, enable_log_buffer: bool = True, log_format: str = "%(name)s:%(funcName)s - %(message)s", ) -> None: """ :param enable_log_buffer: whether to enable the dedicated log buffer for trace logs, whose contents can, within the context of this manager, be accessed via method `get_log`. :param log_format: the logger format string to use for the dedicated log buffer """ self._enable_log_buffer = enable_log_buffer self._log_format: str = log_format self._log_buffer: StringIO | None = None def __enter__(self) -> Self: TraceLogger.is_enabled = True if self._enable_log_buffer: TraceLogger.log_buffer = StringIO() TraceLogger.log_formatter = logging.Formatter(self._log_format) self._log_buffer = TraceLogger.log_buffer return self def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore TraceLogger.is_enabled = False TraceLogger.log_buffer = None TraceLogger.log_formatter = None def get_log(self) -> TraceLog: """:return: the full trace log that was captured if `enable_log_buffer` was enabled at construction""" if self._log_buffer is None: raise Exception( "This method is only supported if the log buffer is enabled at construction", ) return TraceLog(log_lines=self._log_buffer.getvalue().split("\n")) def torch_param_hash(module: torch.nn.Module) -> str: """ Computes a hash of the parameters of the given module; parameters not requiring gradients are ignored. 
:param module: a torch module :return: a hex digest of the parameters of the module """ import hashlib hasher = hashlib.sha1() for param in module.parameters(): if param.requires_grad: np_array = param.detach().cpu().numpy() hasher.update(np_array.tobytes()) return hasher.hexdigest() class TraceDeterminismTest: def __init__( self, base_path: Path, core_messages: Sequence[str] = (), ignored_messages: Sequence[str] = (), log_filename: str | None = None, ) -> None: """ :param base_path: the directory where the reference results are stored (will be created if necessary) :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core :param ignored_messages: message fragments to ignore in the trace log (if any); takes precedence over `core_messages` :param log_filename: the name of the log file to which results are to be written (if any) """ base_path.mkdir(parents=True, exist_ok=True) self.base_path = base_path self.core_messages = core_messages self.ignored_messages = ignored_messages self.log_filename = log_filename @dataclass(kw_only=True) class Result: git_status: GitStatus log: TraceLog def check( self, current_log: TraceLog, name: str, create_reference_result: bool = False, pass_if_core_messages_unchanged: bool = False, ) -> None: """ Checks the given log against the reference result for the given name. :param current_log: the result to check :param name: the name of the reference result; must be unique among all tests! 
:param create_reference_result: whether update the reference result with the given result """ import pytest reference_result_path = self.base_path / f"{name}.pkl.bz2" current_git_status = git_status() if create_reference_result: current_result = self.Result(git_status=current_git_status, log=current_log) dump_pickle(current_result, reference_result_path) reference_result: TraceDeterminismTest.Result = load_pickle( reference_result_path, ) reference_log = reference_result.log current_log_reduced = current_log.reduce_log_to_messages().filter_messages( ignored_messages=self.ignored_messages, ) reference_log_reduced = reference_log.reduce_log_to_messages().filter_messages( ignored_messages=self.ignored_messages, ) results: list[tuple[TraceLog, str]] = [ (reference_log_reduced, "expected"), (current_log_reduced, "current"), (reference_log, "expected_full"), (current_log, "current_full"), ] if self.core_messages: result_main_messages = current_log_reduced.filter_messages( required_messages=self.core_messages, ) reference_result_main_messages = reference_log_reduced.filter_messages( required_messages=self.core_messages, ) results.extend( [ (reference_result_main_messages, "expected_core"), (result_main_messages, "current_core"), ], ) else: result_main_messages = current_log_reduced reference_result_main_messages = reference_log_reduced logs_equivalent = current_log_reduced.get_full_log() == reference_log_reduced.get_full_log() if logs_equivalent: status_passed = True status_message = "OK" else: core_messages_unchanged = ( len(self.core_messages) > 0 and result_main_messages.get_full_log() == reference_result_main_messages.get_full_log() ) status_passed = core_messages_unchanged and pass_if_core_messages_unchanged if status_passed: status_message = "OK (core messages unchanged)" else: # save files for comparison files = [] for r, suffix in results: path = os.path.abspath(f"determinism_{name}_{suffix}.txt") r.save_log(path) files.append(path) paths_str = "\n".join(files) 
main_message = ( f"Please inspect the changes by diffing the log files:\n{paths_str}\n" f"If the changes are OK, enable the `create_reference_result` flag temporarily, " "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n" ) # compute diff and add to message num_diff_lines_to_show = 30 for i, line in enumerate( difflib.unified_diff( reference_log_reduced.log_lines, current_log_reduced.log_lines, fromfile="expected.txt", tofile="current.txt", lineterm="", ), ): if i == num_diff_lines_to_show: break main_message += line + "\n" if core_messages_unchanged: status_message = ( "The behaviour log has changed, but the core messages are still the same (so this " f"probably isn't an issue). {main_message}" ) else: status_message = f"The behaviour log has changed; even the core messages are different. {main_message}" # write log message if self.log_filename: with open(self.log_filename, "a") as f: hr = "-" * 100 f.write(f"\n\n{hr}\nName: {name}\n") f.write(f"Reference state: {reference_result.git_status}\n") f.write(f"Current state: {current_git_status}\n") f.write(f"Test result: {status_message}\n") if not status_passed: pytest.fail(status_message) ================================================ FILE: tianshou/utils/lagged_network.py ================================================ from copy import deepcopy from dataclasses import dataclass from typing import Self import torch def polyak_parameter_update(tgt: torch.nn.Module, src: torch.nn.Module, tau: float) -> None: """Softly updates the parameters of a target network `tgt` with the parameters of a source network `src` using Polyak averaging: `tau * src + (1 - tau) * tgt`. :param tgt: the target network that receives the parameter update :param src: the source network whose parameters are used for the update :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being the fraction with which to retain the target network's parameters. 
""" for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) class EvalModeModuleWrapper(torch.nn.Module): """ A wrapper around a torch.nn.Module that forces the module to eval mode. The wrapped module supports only the forward method, attribute access is not supported. **NOTE**: It is *not* recommended to support attribute/method access beyond this via `__getattr__`, because torch.nn.Module already heavily relies on `__getattr__` to provides its own attribute access. Overriding it naively will cause problems! But it's also not necessary for our use cases; forward is enough. """ def __init__(self, m: torch.nn.Module): super().__init__() m.eval() self.module = m def forward(self, *args, **kwargs): # type: ignore self.module.eval() return self.module(*args, **kwargs) def train(self, mode: bool = True) -> Self: super().train(mode=mode) self.module.eval() # force eval mode return self @dataclass class LaggedNetworkPair: target: torch.nn.Module source: torch.nn.Module class LaggedNetworkCollection: def __init__(self) -> None: self._lagged_network_pairs: list[LaggedNetworkPair] = [] def add_lagged_network(self, source: torch.nn.Module) -> EvalModeModuleWrapper: """ Adds a lagged network to the collection, returning the target network, which is forced to eval mode. The target network is a copy of the source network, which, however, supports only the forward method (hence the type torch.nn.Module); attribute access is not supported. 
:param source: the source network whose parameters are to be copied to the target network :return: the target network, which supports only the forward method and is forced to eval mode """ target = deepcopy(source) self._lagged_network_pairs.append(LaggedNetworkPair(target, source)) return EvalModeModuleWrapper(target) def polyak_parameter_update(self, tau: float) -> None: """Softly updates the parameters of each target network `tgt` with the parameters of a source network `src` using Polyak averaging: `tau * src + (1 - tau) * tgt`. :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being the fraction with which to retain the target network's parameters. """ for pair in self._lagged_network_pairs: polyak_parameter_update(pair.target, pair.source, tau) def full_parameter_update(self) -> None: """Fully updates the target networks with the source networks' parameters (exact copy).""" for pair in self._lagged_network_pairs: for tgt_param, src_param in zip( pair.target.parameters(), pair.source.parameters(), strict=True ): tgt_param.data.copy_(src_param.data) ================================================ FILE: tianshou/utils/logger/__init__.py ================================================ ================================================ FILE: tianshou/utils/logger/logger_base.py ================================================ import typing from abc import ABC, abstractmethod from collections.abc import Callable from enum import StrEnum from numbers import Number import numpy as np VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float # It's unfortunate, but we can't use Union type in isinstance, hence we resort to this VALID_LOG_VALS = typing.get_args(VALID_LOG_VALS_TYPE) TRestoredData = dict[str, np.ndarray | dict[str, "TRestoredData"]] class DataScope(StrEnum): TRAINING = "training" TEST = "test" UPDATE = "update" INFO = "info" class BaseLogger(ABC): """The base class for any logger which is compatible 
with trainer. :param training_interval: the interval size (in env steps) after which log_training_data() will be called. :param test_interval: the interval size (in env steps) after which log_test_data() will be called. :param update_interval: the interval size (in env steps) after which log_update_data() will be called. :param info_interval: the interval size (in env steps) after which the method log_info() will be called. :param save_interval: the interval size (in env steps) after which the checkpoint and end of epoch related logs will be saved. """ def __init__( self, training_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, info_interval: int = 1, save_interval: int | None = None, exclude_arrays: bool = True, ) -> None: super().__init__() self.training_interval = training_interval self.test_interval = test_interval self.update_interval = update_interval self.info_interval = info_interval self.save_interval = save_interval self.exclude_arrays = exclude_arrays self.last_log_training_step = -1 self.last_log_test_step = -1 self.last_log_update_step = -1 self.last_log_info_step = -1 @abstractmethod def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: """Specify how the writer is used to log data. :param str step_type: namespace which the data dict belongs to. :param step: stands for the ordinate of the data dict. :param data: the data to write with format ``{key: value}``. """ @abstractmethod def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]: """Prepare the dict for logging by filtering out invalid data types. If necessary, reformulate the dict to be compatible with the writer. :param log_data: the dict to be prepared for logging. :return: the prepared dict. 
""" @abstractmethod def finalize(self) -> None: """Finalize the logger, e.g., close writers and connections.""" def log_training_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during training. :param log_data: a dict containing the information returned by the collector during the train step. :param step: stands for the timestep the collector result is logged. """ # TODO: move interval check to calling method if step - self.last_log_training_step >= self.training_interval: log_data = self.prepare_dict_for_logging(log_data) self.write(f"{DataScope.TRAINING}/env_step", step, log_data) self.last_log_training_step = step def log_test_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during evaluating. :param log_data:a dict containing the information returned by the collector during the evaluation step. :param step: stands for the timestep the collector result is logged. """ # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer) if step - self.last_log_test_step >= self.test_interval: log_data = self.prepare_dict_for_logging(log_data) self.write(f"{DataScope.TEST}/env_step", step, log_data) self.last_log_test_step = step def log_update_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during updating. :param log_data:a dict containing the information returned during the policy update step. :param step: stands for the timestep the policy training data is logged. """ # TODO: move interval check to calling method if step - self.last_log_update_step >= self.update_interval: log_data = self.prepare_dict_for_logging(log_data) self.write(f"{DataScope.UPDATE}/update_step", step, log_data) self.last_log_update_step = step def log_info_data(self, log_data: dict, step: int) -> None: """Use writer to log global statistics. 
:param log_data: a dict containing information of data collected at the end of an epoch. :param step: stands for the timestep the training info is logged. """ if ( step - self.last_log_info_step >= self.info_interval ): # TODO: move interval check to calling method log_data = self.prepare_dict_for_logging(log_data) self.write(f"{DataScope.INFO}/epoch", step, log_data) self.last_log_info_step = step @abstractmethod def save_data( self, epoch: int, env_step: int, update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. :param epoch: the epoch in trainer. :param env_step: the env_step in trainer. :param update_step: the update step count in the trainer. :param function save_checkpoint_fn: a hook defined by user, see trainer documentation for detail. """ @abstractmethod def restore_data(self) -> tuple[int, int, int]: """Restore internal data if present and return the metadata from existing log for continuation of training. If it finds nothing or an error occurs during the recover process, it will return the default parameters. :return: epoch, env_step, update_step. """ @staticmethod @abstractmethod def restore_logged_data( log_path: str, ) -> TRestoredData: """Load the logged data from disk for post-processing. :return: a dict containing the logged data. """ class LazyLogger(BaseLogger): """A logger that does nothing. 
Used as the placeholder in trainer.""" def __init__(self) -> None: super().__init__() def prepare_dict_for_logging( self, data: dict[str, VALID_LOG_VALS_TYPE], ) -> dict[str, VALID_LOG_VALS_TYPE]: return data def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: """The LazyLogger writes nothing.""" def finalize(self) -> None: pass def save_data( self, epoch: int, env_step: int, update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: pass def restore_data(self) -> tuple[int, int, int]: return 0, 0, 0 @staticmethod def restore_logged_data(log_path: str) -> dict: return {} ================================================ FILE: tianshou/utils/logger/tensorboard.py ================================================ from collections.abc import Callable from typing import Any import numpy as np from matplotlib.figure import Figure from tensorboard.backend.event_processing import event_accumulator from torch.utils.tensorboard import SummaryWriter from tianshou.utils.logger.logger_base import ( VALID_LOG_VALS, VALID_LOG_VALS_TYPE, BaseLogger, TRestoredData, ) class TensorboardLogger(BaseLogger): """A logger that relies on tensorboard SummaryWriter by default to visualize and log statistics. :param SummaryWriter writer: the writer to log data. :param training_interval: the interval size (in env steps) after which log_training_data() will be called. :param test_interval: the interval size (in env steps) after which log_test_data() will be called. :param update_interval: the interval size (in env steps) after which log_update_data() will be called. :param info_interval: the interval size (in env steps) after which the method log_info() will be called. :param save_interval: the interval size (in env steps) after which the checkpoint and end of epoch related logs will be saved. :param write_flush: whether to flush tensorboard result after each add_scalar operation. Default to True. 
""" def __init__( self, writer: SummaryWriter, training_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, info_interval: int = 1, save_interval: int | None = None, write_flush: bool = True, ) -> None: super().__init__( training_interval, test_interval, update_interval, info_interval, save_interval ) self.write_flush = write_flush self.last_save_step = -1 self.writer = writer def prepare_dict_for_logging( self, input_dict: dict[str, Any], parent_key: str = "", delimiter: str = "/", exclude_arrays: bool = True, ) -> dict[str, VALID_LOG_VALS_TYPE]: """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys. Filtering is performed with respect to valid logging data types. :param input_dict: The nested dictionary to be flattened and filtered. :param parent_key: The parent key used as a prefix before the input_dict keys. :param delimiter: The delimiter used to separate the keys. :param exclude_arrays: Whether to exclude numpy arrays from the output. :return: A flattened dictionary where the keys are compressed and values are filtered. 
""" result = {} def add_to_result( cur_dict: dict, prefix: str = "", ) -> None: for key, value in cur_dict.items(): if exclude_arrays and isinstance(value, np.ndarray): continue new_key = prefix + delimiter + str(key) new_key = new_key.lstrip(delimiter) if isinstance(value, dict): add_to_result( value, new_key, ) elif isinstance(value, VALID_LOG_VALS): result[new_key] = value add_to_result(input_dict, prefix=parent_key) return result def write(self, step_type: str, step: int, data: dict[str, Any]) -> None: scope, step_name = step_type.split("/") self.writer.add_scalar(step_type, step, global_step=step) for k, v in data.items(): scope_key = f"{scope}/{k}" if isinstance(v, np.ndarray): self.writer.add_histogram(scope_key, v, global_step=step, bins="auto") elif isinstance(v, Figure): self.writer.add_figure(scope_key, v, global_step=step) else: self.writer.add_scalar(scope_key, v, global_step=step) if self.write_flush: # issue 580 self.writer.flush() # issue #482 def finalize(self) -> None: self.writer.close() def save_data( self, epoch: int, env_step: int, update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: if ( self.save_interval is not None and save_checkpoint_fn is not None and epoch - self.last_save_step >= self.save_interval ): self.last_save_step = epoch save_checkpoint_fn(epoch, env_step, update_step) self.write("save/epoch", epoch, {"save/epoch": epoch}) self.write("save/env_step", env_step, {"save/env_step": env_step}) self.write( "save/gradient_step", update_step, {"save/gradient_step": update_step}, ) def restore_data(self) -> tuple[int, int, int]: ea = event_accumulator.EventAccumulator(self.writer.log_dir) ea.Reload() try: # epoch / gradient_step epoch = ea.scalars.Items("save/epoch")[-1].step self.last_save_step = self.last_log_test_step = epoch gradient_step = ea.scalars.Items("save/gradient_step")[-1].step self.last_log_update_step = gradient_step except KeyError: epoch, gradient_step = 0, 0 try: # offline 
trainer doesn't have env_step env_step = ea.scalars.Items("save/env_step")[-1].step self.last_log_train_step = env_step except KeyError: env_step = 0 return epoch, env_step, gradient_step @staticmethod def restore_logged_data( log_path: str, ) -> TRestoredData: """Restores the logged data from the tensorboard log directory. The result is a nested dictionary where the keys are the tensorboard keys and the values are the corresponding numpy arrays. The keys in each level form a nested structure, where the hierarchy is represented by the slashes in the tensorboard key-strings. """ ea = event_accumulator.EventAccumulator(log_path) ea.Reload() def add_value_to_innermost_nested_dict( data_dict: dict[str, Any], key_string: str, value: Any, ) -> None: """A particular logic, walking through the keys in the `key_string` and adding the value to the `data_dict` in a nested manner, creating nested dictionaries on the fly if necessary, or updating existing ones. The value is added only to the innermost-nested dictionary. 
Example: ------- >>> data_dict = {} >>> add_value_to_innermost_nested_dict(data_dict, "a/b/c", 1) >>> data_dict {"a": {"b": {"c": 1}}} """ keys = key_string.split("/") cur_nested_dict = data_dict # walk through the intermediate keys to reach the innermost-nested dict, # creating nested dictionaries on the fly if necessary for k in keys[:-1]: cur_nested_dict = cur_nested_dict.setdefault(k, {}) # After the loop above, # this is the innermost-nested dict, where the value is finally set # for the last key in the key_string cur_nested_dict[keys[-1]] = value restored_data: dict[str, np.ndarray | dict] = {} for key_string in ea.scalars.Keys(): add_value_to_innermost_nested_dict( restored_data, key_string, np.array([s.value for s in ea.scalars.Items(key_string)]), ) return restored_data ================================================ FILE: tianshou/utils/logger/wandb.py ================================================ import argparse import logging import os from collections.abc import Callable from torch.utils.tensorboard import SummaryWriter from tianshou.utils import BaseLogger, TensorboardLogger from tianshou.utils.logger.logger_base import VALID_LOG_VALS_TYPE, TRestoredData log = logging.getLogger(__name__) class WandbLogger(BaseLogger): """Weights and Biases logger that sends data to https://wandb.ai/. This logger creates three panels with plots: train, test, and update. Make sure to select the correct access for each panel in weights and biases: Example of usage: :: logger = WandbLogger() logger.load(SummaryWriter(log_path)) :param training_interval: the log interval in log_training_data(). :param test_interval: the log interval in log_test_data(). :param update_interval: the log interval in log_update_data(). :param info_interval: the log interval in log_info_data(). :param save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). :param write_flush: whether to flush tensorboard result after each add_scalar operation. 
Default to True. :param str project: W&B project name. Default to "tianshou". :param str name: W&B run name. Default to None. If None, random name is assigned. :param str entity: W&B team/organization name. Default to None. :param str run_id: run id of W&B run to be resumed. Default to None. :param argparse.Namespace config: experiment configurations. Default to None. """ def __init__( self, training_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, info_interval: int = 1, save_interval: int | None = None, write_flush: bool = True, project: str | None = None, name: str | None = None, entity: str | None = None, run_id: str | None = None, group: str | None = None, job_type: str | None = None, config: argparse.Namespace | dict | None = None, monitor_gym: bool = True, disable_stats: bool = False, log_dir: str | None = None, ) -> None: import wandb super().__init__( training_interval, test_interval, update_interval, info_interval, save_interval ) self.last_save_step = -1 self.write_flush = write_flush self.restored = False if project is None: project = os.getenv("WANDB_PROJECT", "tianshou") wandb_run = ( wandb.init( project=project, group=group, job_type=job_type, name=name, id=run_id, resume="allow", entity=entity, sync_tensorboard=True, # monitor_gym=monitor_gym, # currently disabled until gymnasium version is bumped to >1.0.0 https://github.com/wandb/wandb/issues/7047 dir=log_dir, config=config, # type: ignore settings=wandb.Settings(x_disable_stats=disable_stats), ) if not wandb.run else wandb.run ) assert wandb_run is not None self.wandb_run = wandb_run self.wandb_run._label(repo="tianshou") self.tensorboard_logger: TensorboardLogger | None = None self.writer: SummaryWriter | None = None def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]: if self.tensorboard_logger is None: raise Exception( "`logger` needs to load the Tensorboard Writer before " "preparing data for logging. 
Try `logger.load(SummaryWriter(log_path))`", ) return self.tensorboard_logger.prepare_dict_for_logging(log_data) def load(self, writer: SummaryWriter) -> None: self.writer = writer self.tensorboard_logger = TensorboardLogger( writer, self.training_interval, self.test_interval, self.update_interval, self.info_interval, self.save_interval, self.write_flush, ) def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: if self.tensorboard_logger is None: raise RuntimeError( "`logger` needs to load the Tensorboard Writer before " "writing data. Try `logger.load(SummaryWriter(log_path))`", ) self.tensorboard_logger.write(step_type, step, data) def finalize(self) -> None: if self.wandb_run is not None: self.wandb_run.finish() if self.tensorboard_logger is not None: self.tensorboard_logger.finalize() def save_data( self, epoch: int, env_step: int, update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. :param epoch: the epoch in trainer. :param env_step: the env_step in trainer. :param update_step: the gradient_step in trainer. :param function save_checkpoint_fn: a hook defined by user, see trainer documentation for detail. 
""" import wandb if ( self.save_interval is not None and save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval ): self.last_save_step = epoch checkpoint_path = save_checkpoint_fn(epoch, env_step, update_step) checkpoint_artifact = wandb.Artifact( "run_" + self.wandb_run.id + "_checkpoint", type="model", metadata={ "save/epoch": epoch, "save/env_step": env_step, "save/gradient_step": update_step, "checkpoint_path": str(checkpoint_path), }, ) checkpoint_artifact.add_file(str(checkpoint_path)) self.wandb_run.log_artifact(checkpoint_artifact) def restore_data(self) -> tuple[int, int, int]: checkpoint_artifact = self.wandb_run.use_artifact( f"run_{self.wandb_run.id}_checkpoint:latest", ) assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist" checkpoint_artifact.download( os.path.dirname(checkpoint_artifact.metadata["checkpoint_path"]), ) try: # epoch / gradient_step epoch = checkpoint_artifact.metadata["save/epoch"] self.last_save_step = self.last_log_test_step = epoch gradient_step = checkpoint_artifact.metadata["save/gradient_step"] self.last_log_update_step = gradient_step except KeyError: epoch, gradient_step = 0, 0 try: # offline trainer doesn't have env_step env_step = checkpoint_artifact.metadata["save/env_step"] self.last_log_train_step = env_step except KeyError: env_step = 0 return epoch, env_step, gradient_step @staticmethod def restore_logged_data(log_path: str) -> TRestoredData: log.warning( "Logging data directly from W&B is not yet implemented, will use the " "TensorboardLogger to restore it from disc instead.", ) return TensorboardLogger.restore_logged_data(log_path) ================================================ FILE: tianshou/utils/logging.py ================================================ from typing import Any def set_numerical_fields_to_precision(data: dict[str, Any], precision: int = 3) -> dict[str, Any]: """Returns a copy of the given dictionary with all numerical values rounded to the given precision. 
Note: does not recurse into nested dictionaries. :param data: a dictionary :param precision: the precision to be used """ result = {} for k, v in data.items(): if isinstance(v, float): v = round(v, precision) result[k] = v return result ================================================ FILE: tianshou/utils/net/__init__.py ================================================ ================================================ FILE: tianshou/utils/net/common.py ================================================ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from typing import Any, Generic, TypeAlias, TypeVar, cast, no_type_check import numpy as np import torch from gymnasium import spaces from torch import nn from tianshou.data.batch import Batch from tianshou.data.types import RecurrentStateBatch, TObs from tianshou.utils.space_info import ActionSpaceInfo from tianshou.utils.torch_utils import torch_device ModuleType = type[nn.Module] ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]] TActionShape: TypeAlias = Sequence[int] | int | np.int64 TLinearLayer: TypeAlias = Callable[[int, int], nn.Module] T = TypeVar("T") def miniblock( input_size: int, output_size: int = 0, norm_layer: ModuleType | None = None, norm_args: tuple[Any, ...] | dict[Any, Any] | None = None, activation: ModuleType | None = None, act_args: tuple[Any, ...] 
| dict[Any, Any] | None = None, linear_layer: TLinearLayer = nn.Linear, ) -> list[nn.Module]: """Construct a miniblock with given input/output-size, norm layer and activation.""" layers: list[nn.Module] = [linear_layer(input_size, output_size)] if norm_layer is not None: if isinstance(norm_args, tuple): layers += [norm_layer(output_size, *norm_args)] elif isinstance(norm_args, dict): layers += [norm_layer(output_size, **norm_args)] else: layers += [norm_layer(output_size)] if activation is not None: if isinstance(act_args, tuple): layers += [activation(*act_args)] elif isinstance(act_args, dict): layers += [activation(**act_args)] else: layers += [activation()] return layers class ModuleWithVectorOutput(nn.Module): """ A module that outputs a vector of a known size. Use `from_module` to adapt a module to this interface. """ def __init__(self, output_dim: int) -> None: """:param output_dim: the dimension of the output vector.""" super().__init__() self.output_dim = output_dim @staticmethod def from_module(module: nn.Module, output_dim: int) -> "ModuleWithVectorOutput": """ :param module: the module to adapt. :param output_dim: dimension of the output vector produced by the module. """ return ModuleWithVectorOutputAdapter(module, output_dim) def get_output_dim(self) -> int: """:return: the dimension of the output vector.""" return self.output_dim class ModuleWithVectorOutputAdapter(ModuleWithVectorOutput): """Adapts a module with vector output to provide the :class:`ModuleWithVectorOutput` interface.""" def __init__(self, module: nn.Module, output_dim: int) -> None: """ :param module: the module to adapt. :param output_dim: the dimension of the output vector produced by the module. 
""" super().__init__(output_dim) self.module = module def forward(self, *args: Any, **kwargs: Any) -> Any: return self.module(*args, **kwargs) class MLP(ModuleWithVectorOutput): """Simple MLP backbone.""" def __init__( self, *, input_dim: int, output_dim: int = 0, hidden_sizes: Sequence[int] = (), norm_layer: ModuleType | Sequence[ModuleType] | None = None, norm_args: ArgsType | None = None, activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU, act_args: ArgsType | None = None, linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, ) -> None: """ :param input_dim: dimension of the input vector. :param output_dim: dimension of the output vector. If set to 0, there is no explicit final linear layer and the output dimension is the last hidden layer's dimension. :param hidden_sizes: shape of MLP passed in as a list, not including input_dim and output_dim. :param norm_layer: use which normalization before activation, e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. You can also pass a list of normalization modules with the same length of hidden_sizes, to use different normalization module in different layers. Default to no normalization. :param activation: which activation to use after each layer, can be both the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. :param linear_layer: use this module as linear layer. Default to nn.Linear. :param flatten_input: whether to flatten input data. Default to True. 
""" if norm_layer: if isinstance(norm_layer, list): assert len(norm_layer) == len(hidden_sizes) norm_layer_list = norm_layer if isinstance(norm_args, list): assert len(norm_args) == len(hidden_sizes) norm_args_list = norm_args else: norm_args_list = [norm_args for _ in range(len(hidden_sizes))] else: norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))] norm_args_list = [norm_args for _ in range(len(hidden_sizes))] else: norm_layer_list = [None] * len(hidden_sizes) norm_args_list = [None] * len(hidden_sizes) if activation: if isinstance(activation, list): assert len(activation) == len(hidden_sizes) activation_list = activation if isinstance(act_args, list): assert len(act_args) == len(hidden_sizes) act_args_list = act_args else: act_args_list = [act_args for _ in range(len(hidden_sizes))] else: activation_list = [activation for _ in range(len(hidden_sizes))] act_args_list = [act_args for _ in range(len(hidden_sizes))] else: activation_list = [None] * len(hidden_sizes) act_args_list = [None] * len(hidden_sizes) hidden_sizes = [input_dim, *list(hidden_sizes)] model = [] for in_dim, out_dim, norm, norm_args, activ, act_args in zip( hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, norm_args_list, activation_list, act_args_list, strict=True, ): model += miniblock(in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer) if output_dim > 0: model += [linear_layer(hidden_sizes[-1], output_dim)] super().__init__(output_dim or hidden_sizes[-1]) self.model = nn.Sequential(*model) self.flatten_input = flatten_input @no_type_check def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: device = torch_device(self) obs = torch.as_tensor(obs, device=device, dtype=torch.float32) if self.flatten_input: obs = obs.flatten(1) return self.model(obs) TRecurrentState = TypeVar("TRecurrentState", bound=Any) class ActionReprNet(Generic[TRecurrentState], nn.Module, ABC): """Abstract base class for neural networks used to compute action-related 
representations from environment observations, which defines the
    signature of the forward method.

    An action-related representation can be a number of things, including:

      * a distribution over actions in a discrete action space in the form of a vector of
        unnormalized log probabilities (called "logits" in PyTorch jargon)
      * the Q-values of all actions in a discrete action space
      * the parameters of a distribution (e.g., mean and std. dev. for a Gaussian distribution)
        over actions in a continuous action space
    """

    @abstractmethod
    def forward(
        self,
        obs: TObs,
        state: TRecurrentState | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor | Sequence[torch.Tensor], TRecurrentState | None]:
        """
        The main method for tianshou to compute action representations (such as actions, inputs of distributions,
        Q-values, etc) from env observations.

        Implementations that have a pre-processing network (see :meth:`Actor.get_preprocess_net`)
        typically apply it as the first processing step.

        :param obs: the observations from the environment as retrieved from `ObsBatchProtocol.obs`.
            If the environment is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your
            env returns tensors).
        :param state: the hidden state of the RNN, if applicable
        :param info: the info object from the environment step
        :return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or
            a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution),
            and hidden_state is the new hidden state of the RNN, if applicable.
        """


class ActionReprNetWithVectorOutput(Generic[T], ActionReprNet[T], ModuleWithVectorOutput):
    """A neural network for the computation of action-related representations which outputs
    a vector of a known size.
    """

    def __init__(self, output_dim: int) -> None:
        # ActionReprNet defines no __init__, so output_dim travels along the MRO and
        # presumably lands in ModuleWithVectorOutput.__init__ (re-check if the bases change).
        super().__init__(output_dim)


class Actor(Generic[T], ActionReprNetWithVectorOutput[T], ABC):
    """An action-representation network with a vector output whose pre-processing
    component can be retrieved via :meth:`get_preprocess_net` (e.g. for sharing it with critics).
    """

    @abstractmethod
    def get_preprocess_net(self) -> ModuleWithVectorOutput:
        """Returns the network component that is used for pre-processing, i.e.
        the component which produces a latent representation, which then is transformed
        into the final output.
        This is, therefore, the first part of the network which processes the input.
        For example, a CNN is often used in Atari examples.

        We need this method to be able to share latent representation computations with
        other networks (e.g. critics) within an algorithm.

        Actors that do not have a pre-processing stage can return nn.Identity()
        (see :class:`RandomActor` for an example).
        """


class Net(ActionReprNetWithVectorOutput[Any]):
    """A multi-layer perceptron which outputs an action-related representation.

    :param state_shape: int or a sequence of int of the shape of state.
    :param action_shape: int or a sequence of int of the shape of action.
    :param hidden_sizes: shape of MLP passed in as a list.
    :param norm_layer: use which normalization before activation, e.g.,
        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
        You can also pass a list of normalization modules with the same length
        of hidden_sizes, to use different normalization module in different
        layers. Default to no normalization.
    :param activation: which activation to use after each layer, can be both
        the same activation for all layers if passed in nn.Module, or different
        activation for different Modules if passed in a list. Default to
        nn.ReLU.
    :param softmax: whether to apply a softmax (over the last dimension) to the
        network's output.
    :param concat: whether the input shape is concatenated by state_shape
        and action_shape. If it is True, ``action_shape`` is not the output
        shape, but affects the input shape only.
    :param num_atoms: in order to expand to the net of distributional RL.
        Default to 1 (not use).
:param dueling_param: whether to use dueling network to calculate Q
        values (for Dueling DQN). If you want to use dueling option, you should
        pass a tuple of two dict (first for Q and second for V) stating
        self-defined arguments as stated in
        :class:`~tianshou.utils.net.common.MLP`. Default to None.
    :param linear_layer: use this module constructor, which takes the input
        and output dimension as input, as linear layer. Default to nn.Linear.
    :param norm_args: arguments for the normalization layer(s), see
        :class:`~tianshou.utils.net.common.MLP`.
    :param act_args: arguments for the activation(s), see
        :class:`~tianshou.utils.net.common.MLP`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.MLP` for more
        detailed explanation on the usage of activation, norm_layer, etc.

        You can also refer to
        :class:`~tianshou.utils.net.continuous.ContinuousActorDeterministic`,
        :class:`~tianshou.utils.net.continuous.ContinuousCritic`, etc, to see how it's
        suggested be used.
    """

    def __init__(
        self,
        *,
        state_shape: int | Sequence[int],
        action_shape: TActionShape = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: ModuleType | Sequence[ModuleType] | None = None,
        norm_args: ArgsType | None = None,
        activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU,
        act_args: ArgsType | None = None,
        softmax: bool = False,
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None,
        linear_layer: TLinearLayer = nn.Linear,
    ) -> None:
        input_dim = int(np.prod(state_shape))
        # One output per (action, atom) pair.
        action_dim = int(np.prod(action_shape)) * num_atoms
        if concat:
            input_dim += action_dim
        use_dueling = dueling_param is not None
        # The trunk only gets a final output layer in the plain (non-dueling, non-concat)
        # case; otherwise its output is the last hidden layer.
        model = MLP(
            input_dim=input_dim,
            output_dim=action_dim if not use_dueling and not concat else 0,
            hidden_sizes=hidden_sizes,
            norm_layer=norm_layer,
            norm_args=norm_args,
            activation=activation,
            act_args=act_args,
            linear_layer=linear_layer,
        )
        Q: MLP | None = None
        V: MLP | None = None
        if use_dueling:  # dueling DQN
            assert dueling_param is not None
            # Both heads consume the trunk's output.
            kwargs_update = {
                "input_dim": model.output_dim,
            }
            # Important: don't change the original dict (e.g., don't use .update())
            q_kwargs = {**dueling_param[0], **kwargs_update}
            v_kwargs = {**dueling_param[1], **kwargs_update}

            # Q head: one value per (action, atom); V head: one value per atom.
            q_kwargs["output_dim"] = 0 if concat else action_dim
            v_kwargs["output_dim"] = 0 if concat else num_atoms
            Q, V = MLP(**q_kwargs), MLP(**v_kwargs)
            output_dim = Q.output_dim
        else:
            output_dim = model.output_dim

        super().__init__(output_dim)
        self.use_dueling = use_dueling
        self.softmax = softmax
        self.num_atoms = num_atoms
        self.model = model
        self.Q = Q
        self.V = V

    def forward(
        self,
        obs: TObs,
        state: T | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, T | Any]:
        """Mapping: obs -> flatten (inside MLP)-> logits.

        :param obs: the observations (converted and flattened inside the MLP).
        :param state: unused and returned as is
        :param info: unused
        :return: ``(logits, state)``; when ``num_atoms > 1`` the logits are reshaped to
            ``(batch_size, -1, num_atoms)``.
        """
        logits = self.model(obs)
        batch_size = logits.shape[0]
        if self.use_dueling:  # Dueling DQN
            assert self.Q is not None
            assert self.V is not None
            q, v = self.Q(logits), self.V(logits)
            if self.num_atoms > 1:
                q = q.view(batch_size, -1, self.num_atoms)
                v = v.view(batch_size, -1, self.num_atoms)
            # Dueling aggregation: V + (Q - mean over actions of Q).
            logits = q - q.mean(dim=1, keepdim=True) + v
        elif self.num_atoms > 1:
            logits = logits.view(batch_size, -1, self.num_atoms)
        if self.softmax:
            logits = torch.softmax(logits, dim=-1)
        return logits, state


class Recurrent(ActionReprNetWithVectorOutput[RecurrentStateBatch]):
    """Simple Recurrent network based on LSTM.

    Structure: ``fc1`` (state -> hidden) -> ``nn.LSTM`` (batch_first) -> ``fc2`` applied
    to the last time step (hidden -> action).
    """

    def __init__(
        self,
        *,
        layer_num: int,
        state_shape: int | Sequence[int],
        action_shape: TActionShape,
        hidden_layer_size: int = 128,
    ) -> None:
        output_dim = int(np.prod(action_shape))
        super().__init__(output_dim)
        # NOTE: the attribute is named `nn`; inside methods `nn` still refers to torch.nn,
        # only `self.nn` is the LSTM.
        self.nn = nn.LSTM(
            input_size=hidden_layer_size,
            hidden_size=hidden_layer_size,
            num_layers=layer_num,
            batch_first=True,
        )
        self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size)
        self.fc2 = nn.Linear(hidden_layer_size, output_dim)

    def get_preprocess_net(self) -> ModuleWithVectorOutput:
        # No separate pre-processing stage: expose an identity adapter.
        return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim)

    def forward(
        self,
        obs: TObs,
        state: RecurrentStateBatch | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, RecurrentStateBatch]:
        """Mapping: obs -> flatten -> logits.
In the evaluation mode, `obs` should be with shape ``[bsz, dim]``; in the training mode, `obs` should be with shape ``[bsz, len, dim]``. See the code and comment for more detail. :param obs: :param state: either None or a dict with keys 'hidden' and 'cell' :param info: unused :return: predicted action, next state as dict with keys 'hidden' and 'cell' """ # Note: the original type of state is Batch but it might also be a dict # If it is a Batch, .issubset(state) will not work. However, # issubset(state.keys()) always works if state is not None and not {"hidden", "cell"}.issubset(state.keys()): raise ValueError( f"Expected to find keys 'hidden' and 'cell' but instead found {state.keys()}", ) device = torch_device(self) obs = torch.as_tensor(obs, device=device, dtype=torch.float32) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. if len(obs.shape) == 2: obs = obs.unsqueeze(-2) obs = self.fc1(obs) self.nn.flatten_parameters() if state is None: obs, (hidden, cell) = self.nn(obs) else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] obs, (hidden, cell) = self.nn( obs, ( state["hidden"].transpose(0, 1).contiguous(), state["cell"].transpose(0, 1).contiguous(), ), ) obs = self.fc2(obs[:, -1]) # please ensure the first dim is batch size: [bsz, len, ...] rnn_state_batch = cast( RecurrentStateBatch, Batch( { "hidden": hidden.transpose(0, 1).detach(), "cell": cell.transpose(0, 1).detach(), }, ), ) return obs, rnn_state_batch class ActorCritic(nn.Module): """An actor-critic network for parsing parameters. Using ``actor_critic.parameters()`` instead of set.union or list+list to avoid issue #449. :param nn.Module actor: the actor network. :param nn.Module critic: the critic network. 
""" def __init__(self, actor: nn.Module, critic: nn.Module) -> None: super().__init__() self.actor = actor self.critic = critic class DataParallelNet(nn.Module): """DataParallel wrapper for training agent with multi-GPU. This class does only the conversion of input data type, from numpy array to torch's Tensor. If the input is a nested dictionary, the user should create a similar class to do the same thing. :param net: the network to be distributed in different GPUs. """ def __init__(self, net: nn.Module) -> None: super().__init__() self.net = nn.DataParallel(net) def forward( self, obs: TObs, *args: Any, **kwargs: Any, ) -> tuple[Any, Any]: if not isinstance(obs, torch.Tensor): obs = torch.as_tensor(obs, dtype=torch.float32) obs = obs.cuda() return self.net(obs, *args, **kwargs) # The same functionality as DataParallelNet # The duplication is worth it because the ActionReprNet abstraction is so important class ActionReprNetDataParallelWrapper(ActionReprNet): def __init__(self, net: ActionReprNet) -> None: super().__init__() self.net = nn.DataParallel(net) def forward( self, obs: TObs, state: TRecurrentState | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, TRecurrentState | None]: if not isinstance(obs, torch.Tensor): obs = torch.as_tensor(obs, dtype=torch.float32) obs = obs.cuda() return self.net(obs, state=state, info=info) class EnsembleLinear(nn.Module): """Linear Layer of Ensemble network. :param ensemble_size: Number of subnets in the ensemble. :param in_feature: dimension of the input vector. :param out_feature: dimension of the output vector. :param bias: whether to include an additive bias, default to be True. 
""" def __init__( self, ensemble_size: int, in_feature: int, out_feature: int, bias: bool = True, ) -> None: super().__init__() # To be consistent with PyTorch default initializer k = np.sqrt(1.0 / in_feature) weight_data = torch.rand((ensemble_size, in_feature, out_feature)) * 2 * k - k self.weight = nn.Parameter(weight_data, requires_grad=True) self.bias_weights: nn.Parameter | None = None if bias: bias_data = torch.rand((ensemble_size, 1, out_feature)) * 2 * k - k self.bias_weights = nn.Parameter(bias_data, requires_grad=True) def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.matmul(x, self.weight) if self.bias_weights is not None: x = x + self.bias_weights return x class BranchingNet(ActionReprNet): """Branching dual Q network. Network for the BranchingDQNPolicy, it uses a common network module, a value module and action "branches" one for each dimension. It allows for a linear scaling of Q-value the output w.r.t. the number of dimensions in the action space. This network architecture efficiently handles environments with multiple independent action dimensions by using a branching structure. Instead of representing all action combinations (which grows exponentially), it represents each action dimension separately (linear scaling). For example, if there are 3 actions with 3 possible values each, then we would normally need to consider 3^4 = 81 unique actions, whereas with this architecture, we can instead use 3 branches with 4 actions per dimension, resulting in 3 * 4 = 12 values to be considered. Common use cases include multi-joint robotic control tasks, where each joint can be controlled independently. For more information, please refer to: arXiv:1711.08946. 
""" def __init__( self, *, state_shape: int | Sequence[int], num_branches: int = 0, action_per_branch: int = 2, common_hidden_sizes: list[int] | None = None, value_hidden_sizes: list[int] | None = None, action_hidden_sizes: list[int] | None = None, norm_layer: ModuleType | None = None, norm_args: ArgsType | None = None, activation: ModuleType | None = nn.ReLU, act_args: ArgsType | None = None, ) -> None: """ :param state_shape: int or a sequence of int of the shape of state. :param num_branches: number of action dimensions in the environment. Each branch represents one independent action dimension. For example, in a robot with 7 joints, you would set this to 7. :param action_per_branch: Number of possible discrete values for each action dimension. For example, if each joint can have 3 positions (left, center, right), you would set this to 3. :param common_hidden_sizes: shape of the common MLP network passed in as a list. :param value_hidden_sizes: shape of the value MLP network passed in as a list. :param action_hidden_sizes: shape of the action MLP network passed in as a list. :param norm_layer: use which normalization before activation, e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. You can also pass a list of normalization modules with the same length of hidden_sizes, to use different normalization module in different layers. Default to no normalization. :param activation: which activation to use after each layer, can be both the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. 
""" super().__init__() common_hidden_sizes = common_hidden_sizes or [] value_hidden_sizes = value_hidden_sizes or [] action_hidden_sizes = action_hidden_sizes or [] self.num_branches = num_branches self.action_per_branch = action_per_branch # common network common_input_dim = int(np.prod(state_shape)) common_output_dim = 0 self.common = MLP( input_dim=common_input_dim, output_dim=common_output_dim, hidden_sizes=common_hidden_sizes, norm_layer=norm_layer, norm_args=norm_args, activation=activation, act_args=act_args, ) # value network value_input_dim = common_hidden_sizes[-1] value_output_dim = 1 self.value = MLP( input_dim=value_input_dim, output_dim=value_output_dim, hidden_sizes=value_hidden_sizes, norm_layer=norm_layer, norm_args=norm_args, activation=activation, act_args=act_args, ) # action branching network action_input_dim = common_hidden_sizes[-1] action_output_dim = action_per_branch self.branches = nn.ModuleList( [ MLP( input_dim=action_input_dim, output_dim=action_output_dim, hidden_sizes=action_hidden_sizes, norm_layer=norm_layer, norm_args=norm_args, activation=activation, act_args=act_args, ) for _ in range(self.num_branches) ], ) def forward( self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, T | None]: """Mapping: obs -> model -> logits.""" common_out = self.common(obs) value_out = self.value(common_out) value_out = torch.unsqueeze(value_out, 1) action_out = [] for b in self.branches: action_out.append(b(common_out)) action_scores = torch.stack(action_out, 1) action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True) logits = value_out + action_scores return logits, state def get_dict_state_decorator( state_shape: dict[str, int | Sequence[int]], keys: Sequence[str], ) -> tuple[Callable, int]: """A helper function to make Net or equivalent classes (e.g. Actor, Critic) applicable to dict state. 
The first return item, ``decorator_fn``, will alter the implementation of forward function of the given class by preprocessing the observation. The preprocessing is basically flatten the observation and concatenate them based on the ``keys`` order. The batch dimension is preserved if presented. The result observation shape will be equal to ``new_state_shape``, the second return item. :param state_shape: A dictionary indicating each state's shape :param keys: A list of state's keys. The flatten observation will be according to this list order. :returns: a 2-items tuple ``decorator_fn`` and ``new_state_shape`` """ original_shape = state_shape flat_state_shapes = [] for k in keys: flat_state_shapes.append(int(np.prod(state_shape[k]))) new_state_shape = sum(flat_state_shapes) def preprocess_obs(obs: Batch | dict | torch.Tensor | np.ndarray) -> torch.Tensor: if isinstance(obs, dict) or (isinstance(obs, Batch) and keys[0] in obs): if original_shape[keys[0]] == obs[keys[0]].shape: # No batch dim new_obs = torch.Tensor([obs[k] for k in keys]).flatten() # new_obs = torch.Tensor([obs[k] for k in keys]).reshape(1, -1) else: bsz = obs[keys[0]].shape[0] new_obs = torch.cat([torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], dim=1) else: new_obs = torch.Tensor(obs) return new_obs @no_type_check def decorator_fn(net_class): class new_net_class(net_class): def forward(self, obs: TObs, *args, **kwargs) -> Any: return super().forward(preprocess_obs(obs), *args, **kwargs) return new_net_class return decorator_fn, new_state_shape class AbstractContinuousActorProbabilistic(Actor, ABC): """Type bound for probabilistic actors which output distribution parameters for continuous action spaces.""" class AbstractDiscreteActor(Actor, ABC): """ Type bound for discrete actors. For on-policy algos like Reinforce, this typically directly outputs unnormalized log probabilities, which can be interpreted as "logits" in conjunction with a `torch.distributions.Categorical` instance. 
In Tianshou, discrete actors are also used for computing action distributions within Q-learning type algorithms (e.g., DQN). In this case, the observations are mapped to a vector of Q-values (one for each action). In other words, the component is actually a critic, not an actor in the traditional sense. Note that when sampling actions, the Q-values can be interpreted as inputs for a `torch.distributions.Categorical` instance, similar to the on-policy case mentioned above. """ class RandomActor(AbstractContinuousActorProbabilistic, AbstractDiscreteActor): """An actor that returns random actions. For continuous action spaces, forward returns a batch of random actions sampled from the action space. For discrete action spaces, forward returns a batch of n-dimensional arrays corresponding to the uniform distribution over the n possible actions (same interface as in :class:`~.net.discrete.Actor`). """ def __init__(self, action_space: spaces.Box | spaces.Discrete) -> None: if isinstance(action_space, spaces.Discrete): output_dim = action_space.n else: output_dim = np.prod(action_space.shape) super().__init__(int(output_dim)) self._action_space = action_space self._space_info = ActionSpaceInfo.from_space(action_space) @property def action_space(self) -> spaces.Box | spaces.Discrete: return self._action_space @property def space_info(self) -> ActionSpaceInfo: return self._space_info def get_preprocess_net(self) -> ModuleWithVectorOutput: return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) def get_output_dim(self) -> int: return self.space_info.action_dim @property def is_discrete(self) -> bool: return isinstance(self.action_space, spaces.Discrete) def forward( self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, T | None]: batch_size = len(obs) if isinstance(self.action_space, spaces.Box): action = np.stack([self.action_space.sample() for _ in range(batch_size)]) else: # Discrete Actors currently return 
an n-dimensional array of probabilities for each action action = 1 / self.action_space.n * np.ones((batch_size, self.action_space.n)) return torch.Tensor(action), state def compute_action_batch(self, obs: TObs) -> torch.Tensor: if self.is_discrete: # Different from forward which returns discrete probabilities, see comment there assert isinstance(self.action_space, spaces.Discrete) # for mypy return torch.Tensor(np.random.randint(low=0, high=self.action_space.n, size=len(obs))) else: return self.forward(obs)[0] ================================================ FILE: tianshou/utils/net/continuous.py ================================================ import warnings from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any, TypeVar import numpy as np import torch from sensai.util.pickle import setstate from torch import nn from tianshou.data.types import TObs from tianshou.utils.net.common import ( MLP, AbstractContinuousActorProbabilistic, Actor, ModuleWithVectorOutput, TActionShape, TLinearLayer, ) from tianshou.utils.torch_utils import torch_device SIGMA_MIN = -20 SIGMA_MAX = 2 T = TypeVar("T") class AbstractContinuousActorDeterministic(Actor, ABC): """Marker interface for continuous deterministic actors (DDPG like).""" class ContinuousActorDeterministic(AbstractContinuousActorDeterministic): """Actor network that directly outputs actions for continuous action space. Used primarily in DDPG and its variants. It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. :param preprocess_net: first part of input processing. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. :param max_action: the scale for the final action. 
""" def __init__( self, *, preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, ) -> None: output_dim = int(np.prod(action_shape)) super().__init__(output_dim) self.preprocess = preprocess_net input_dim = preprocess_net.get_output_dim() self.last = MLP( input_dim=input_dim, output_dim=self.output_dim, hidden_sizes=hidden_sizes, ) self.max_action = max_action def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess def get_output_dim(self) -> int: return self.output_dim def forward( self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, T | None]: """Mapping: s_B -> action_values_BA, hidden_state_BH | None. Returns a tensor representing the actions directly, i.e, of shape `(n_actions, )`, and a hidden state (which may be None). The hidden state is only not None if a recurrent net is used as part of the learning algorithm (support for RNNs is currently experimental). """ action_BA, hidden_BH = self.preprocess(obs, state) action_BA = self.max_action * torch.tanh(self.last(action_BA)) return action_BA, hidden_BH class AbstractContinuousCritic(ModuleWithVectorOutput, ABC): @abstractmethod def forward( self, obs: np.ndarray | torch.Tensor, act: np.ndarray | torch.Tensor | None = None, info: dict[str, Any] | None = None, ) -> torch.Tensor: """Mapping: (s_B, a_B) -> Q(s, a)_B.""" class ContinuousCritic(AbstractContinuousCritic): """Simple critic network. It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value). :param preprocess_net: the pre-processing network, which returns a vector of a known dimension. Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. :param linear_layer: use this module as linear layer. :param flatten_input: whether to flatten input data for the last layer. 
:param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before concatenating with the action) - and without the observations being modified in any way beforehand. This allows the actor's preprocessing network to be reused for the critic. """ def __init__( self, *, preprocess_net: ModuleWithVectorOutput, hidden_sizes: Sequence[int] = (), linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, apply_preprocess_net_to_obs_only: bool = False, ) -> None: super().__init__(output_dim=1) self.preprocess = preprocess_net self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only input_dim = preprocess_net.get_output_dim() self.last = MLP( input_dim=input_dim, output_dim=1, hidden_sizes=hidden_sizes, linear_layer=linear_layer, flatten_input=flatten_input, ) def __setstate__(self, state: dict) -> None: setstate( ContinuousCritic, self, state, new_default_properties={"apply_preprocess_net_to_obs_only": False}, ) def forward( self, obs: np.ndarray | torch.Tensor, act: np.ndarray | torch.Tensor | None = None, info: dict[str, Any] | None = None, ) -> torch.Tensor: """Mapping: (s_B, a_B) -> Q(s, a)_B.""" device = torch_device(self) obs = torch.as_tensor( obs, device=device, dtype=torch.float32, ) if self.apply_preprocess_net_to_obs_only: obs, _ = self.preprocess(obs) obs = obs.flatten(1) if act is not None: act = torch.as_tensor( act, device=device, dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) if not self.apply_preprocess_net_to_obs_only: obs, _ = self.preprocess(obs) return self.last(obs) class ContinuousActorProbabilistic(AbstractContinuousActorProbabilistic): """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian). Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. :param preprocess_net: the pre-processing network, which returns a vector of a known dimension. 
:param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. :param max_action: the scale for the final action logits. :param unbounded: whether to apply tanh activation on final logits. :param conditioned_sigma: True when sigma is calculated from the input, False when sigma is an independent parameter. """ def __init__( self, *, preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, unbounded: bool = False, conditioned_sigma: bool = False, ) -> None: output_dim = int(np.prod(action_shape)) super().__init__(output_dim) if unbounded and not np.isclose(max_action, 1.0): warnings.warn("Note that max_action input will be discarded when unbounded is True.") max_action = 1.0 self.preprocess = preprocess_net input_dim = preprocess_net.get_output_dim() self.mu = MLP(input_dim=input_dim, output_dim=output_dim, hidden_sizes=hidden_sizes) self._c_sigma = conditioned_sigma if conditioned_sigma: self.sigma = MLP( input_dim=input_dim, output_dim=output_dim, hidden_sizes=hidden_sizes, ) else: self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1)) self.max_action = max_action self._unbounded = unbounded def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess def forward( self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[tuple[torch.Tensor, torch.Tensor], T | None]: if info is None: info = {} logits, hidden = self.preprocess(obs, state) mu = self.mu(logits) if not self._unbounded: mu = self.max_action * torch.tanh(mu) if self._c_sigma: sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() else: shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() return (mu, sigma), state class RecurrentActorProb(nn.Module): """Recurrent version of ActorProb.""" def __init__( self, *, layer_num: 
int, state_shape: Sequence[int], action_shape: Sequence[int], hidden_layer_size: int = 128, max_action: float = 1.0, unbounded: bool = False, conditioned_sigma: bool = False, ) -> None: super().__init__() if unbounded and not np.isclose(max_action, 1.0): warnings.warn("Note that max_action input will be discarded when unbounded is True.") max_action = 1.0 self.nn = nn.LSTM( input_size=int(np.prod(state_shape)), hidden_size=hidden_layer_size, num_layers=layer_num, batch_first=True, ) output_dim = int(np.prod(action_shape)) self.mu = nn.Linear(hidden_layer_size, output_dim) self._c_sigma = conditioned_sigma if conditioned_sigma: self.sigma = nn.Linear(hidden_layer_size, output_dim) else: self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1)) self.max_action = max_action self._unbounded = unbounded def forward( self, obs: np.ndarray | torch.Tensor, state: dict[str, torch.Tensor] | None = None, info: dict[str, Any] | None = None, ) -> tuple[tuple[torch.Tensor, torch.Tensor], dict[str, torch.Tensor]]: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" if info is None: info = {} device = torch_device(self) obs = torch.as_tensor( obs, device=device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. if len(obs.shape) == 2: obs = obs.unsqueeze(-2) self.nn.flatten_parameters() if state is None: obs, (hidden, cell) = self.nn(obs) else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] 
obs, (hidden, cell) = self.nn( obs, ( state["hidden"].transpose(0, 1).contiguous(), state["cell"].transpose(0, 1).contiguous(), ), ) logits = obs[:, -1] mu = self.mu(logits) if not self._unbounded: mu = self.max_action * torch.tanh(mu) if self._c_sigma: sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() else: shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() # please ensure the first dim is batch size: [bsz, len, ...] return (mu, sigma), { "hidden": hidden.transpose(0, 1).detach(), "cell": cell.transpose(0, 1).detach(), } class RecurrentCritic(nn.Module): """Recurrent version of Critic.""" def __init__( self, layer_num: int, state_shape: Sequence[int], action_shape: Sequence[int] = (0,), hidden_layer_size: int = 128, ) -> None: super().__init__() self.state_shape = state_shape self.action_shape = action_shape self.nn = nn.LSTM( input_size=int(np.prod(state_shape)), hidden_size=hidden_layer_size, num_layers=layer_num, batch_first=True, ) self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1) def forward( self, obs: np.ndarray | torch.Tensor, act: np.ndarray | torch.Tensor | None = None, info: dict[str, Any] | None = None, ) -> torch.Tensor: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" if info is None: info = {} device = torch_device(self) obs = torch.as_tensor( obs, device=device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. assert len(obs.shape) == 3 self.nn.flatten_parameters() obs, (hidden, cell) = self.nn(obs) obs = obs[:, -1] if act is not None: act = torch.as_tensor( act, device=device, dtype=torch.float32, ) obs = torch.cat([obs, act], dim=1) return self.fc2(obs) class Perturbation(nn.Module): """Implementation of perturbation network in BCQ algorithm. 
Given a state and action, it can generate perturbed action. :param preprocess_net: a self-defined preprocess_net which output a flattened hidden state. :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. :param phi: max perturbation parameter for BCQ. .. seealso:: You can refer to `examples/offline/offline_bcq.py` to see how to use it. """ def __init__( self, *, preprocess_net: nn.Module, max_action: float, phi: float = 0.05, ): # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim super().__init__() self.preprocess_net = preprocess_net self.max_action = max_action self.phi = phi def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: # preprocess_net logits = self.preprocess_net(torch.cat([state, action], -1))[0] noise = self.phi * self.max_action * torch.tanh(logits) # clip to [-max_action, max_action] return (noise + action).clamp(-self.max_action, self.max_action) class VAE(nn.Module): """Implementation of VAE. It models the distribution of action. Given a state, it can generate actions similar to those in batch. It is used in BCQ algorithm. :param encoder: the encoder in VAE. Its input_dim must be state_dim + action_dim, and output_dim must be hidden_dim. :param decoder: the decoder in VAE. Its input_dim must be state_dim + latent_dim, and output_dim must be action_dim. :param hidden_dim: the size of the last linear-layer in encoder. :param latent_dim: the size of latent layer. :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. .. seealso:: You can refer to `examples/offline/offline_bcq.py` to see how to use it. 
""" def __init__( self, *, encoder: nn.Module, decoder: nn.Module, hidden_dim: int, latent_dim: int, max_action: float, ): super().__init__() self.encoder = encoder self.mean = nn.Linear(hidden_dim, latent_dim) self.log_std = nn.Linear(hidden_dim, latent_dim) self.decoder = decoder self.max_action = max_action self.latent_dim = latent_dim def forward( self, state: torch.Tensor, action: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # [state, action] -> z , [state, z] -> action latent_z = self.encoder(torch.cat([state, action], -1)) # shape of z: (state.shape[:-1], hidden_dim) mean = self.mean(latent_z) # Clamped for numerical stability log_std = self.log_std(latent_z).clamp(-4, 15) std = torch.exp(log_std) # shape of mean, std: (state.shape[:-1], latent_dim) latent_z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim) reconstruction = self.decode(state, latent_z) # (state.shape[:-1], action_dim) return reconstruction, mean, std def decode( self, state: torch.Tensor, latent_z: torch.Tensor | None = None, ) -> torch.Tensor: # decode(state) -> action if latent_z is None: # state.shape[0] may be batch_size # latent vector clipped to [-0.5, 0.5] device = torch_device(self) latent_z = ( torch.randn(state.shape[:-1] + (self.latent_dim,)).to(device).clamp(-0.5, 0.5) ) # decode z with state! 
return self.max_action * torch.tanh(self.decoder(torch.cat([state, latent_z], -1))) ================================================ FILE: tianshou/utils/net/discrete.py ================================================ from collections.abc import Sequence from typing import Any, TypeVar import numpy as np import torch import torch.nn.functional as F from torch import nn from tianshou.data import Batch, to_torch from tianshou.data.types import TObs from tianshou.utils.net.common import ( MLP, AbstractDiscreteActor, ModuleWithVectorOutput, TActionShape, ) from tianshou.utils.torch_utils import torch_device T = TypeVar("T") def dist_fn_categorical_from_logits( logits: torch.Tensor, ) -> torch.distributions.Categorical: """Default distribution function for categorical actors.""" return torch.distributions.Categorical(logits=logits) class DiscreteActor(AbstractDiscreteActor): """ Generic discrete actor which uses a preprocessing network to generate a latent representation which is subsequently passed to an MLP to compute the output. For common output semantics, see :class:`DiscreteActorInterface`. """ def __init__( self, *, preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), softmax_output: bool = True, ) -> None: """ :param preprocess_net: the preprocessing network, which outputs a vector of a known dimension; typically an instance of :class:`~tianshou.utils.net.common.Net`. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains only a single linear layer). :param softmax_output: whether to apply a softmax layer over the last layer's output. 
""" output_dim = int(np.prod(action_shape)) super().__init__(output_dim) self.preprocess = preprocess_net input_dim = preprocess_net.get_output_dim() self.last = MLP( input_dim=input_dim, output_dim=self.output_dim, hidden_sizes=hidden_sizes, ) self.softmax_output = softmax_output def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess def forward( self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, T | None]: r"""Mapping: (s_B, ...) -> action_values_BA, hidden_state_BH | None. Returns a tensor representing the values of each action, i.e, of shape `(n_actions, )` (see class docstring for more info on the meaning of that), and a hidden state (which may be None). If `self.softmax_output` is True, they are the probabilities for taking each action. Otherwise, they will be action values. The hidden state is only not None if a recurrent net is used as part of the learning algorithm. """ x, hidden_BH = self.preprocess(obs, state) x = self.last(x) if self.softmax_output: x = F.softmax(x, dim=-1) # If we computed softmax, output is probabilities, otherwise it's the non-normalized action values output_BA = x return output_BA, hidden_BH class DiscreteCritic(ModuleWithVectorOutput): """Simple critic network for discrete action spaces. :param preprocess_net: the preprocessing network, which outputs a vector of a known dimension; typically an instance of :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains only a single linear layer). :param last_size: the output dimension of Critic network. 
""" def __init__( self, *, preprocess_net: ModuleWithVectorOutput, hidden_sizes: Sequence[int] = (), last_size: int = 1, ) -> None: super().__init__(output_dim=last_size) self.preprocess = preprocess_net input_dim = preprocess_net.get_output_dim() self.last = MLP(input_dim=input_dim, output_dim=last_size, hidden_sizes=hidden_sizes) def forward( self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None ) -> torch.Tensor: """Mapping: s_B -> V(s)_B.""" # TODO: don't use this mechanism for passing state logits, _ = self.preprocess(obs, state=state) return self.last(logits) class CosineEmbeddingNetwork(nn.Module): """Cosine embedding network for IQN. Convert a scalar in [0, 1] to a list of n-dim vectors. :param num_cosines: the number of cosines used for the embedding. :param embedding_dim: the dimension of the embedding/output. .. note:: From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . """ def __init__(self, num_cosines: int, embedding_dim: int) -> None: super().__init__() self.net = nn.Sequential(nn.Linear(num_cosines, embedding_dim), nn.ReLU()) self.num_cosines = num_cosines self.embedding_dim = embedding_dim def forward(self, taus: torch.Tensor) -> torch.Tensor: batch_size = taus.shape[0] N = taus.shape[1] # Calculate i * \pi (i=1,...,N). i_pi = np.pi * torch.arange( start=1, end=self.num_cosines + 1, dtype=taus.dtype, device=taus.device, ).view(1, 1, self.num_cosines) # Calculate cos(i * \pi * \tau). cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi).view( batch_size * N, self.num_cosines, ) # Calculate embeddings of taus. return self.net(cosines).view(batch_size, N, self.embedding_dim) class ImplicitQuantileNetwork(DiscreteCritic): """Implicit Quantile Network. :param preprocess_net: a self-defined preprocess_net which output a flattened hidden state. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. 
Default to empty sequence (where the MLP now contains only a single linear layer). :param num_cosines: the number of cosines to use for cosine embedding. Default to 64. .. note:: Although this class inherits Critic, it is actually a quantile Q-Network with output shape (batch_size, action_dim, sample_size). The second item of the first return value is tau vector. """ def __init__( self, *, preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, ) -> None: last_size = int(np.prod(action_shape)) super().__init__( preprocess_net=preprocess_net, hidden_sizes=hidden_sizes, last_size=last_size, ) self.input_dim = preprocess_net.get_output_dim() self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim) def forward( # type: ignore self, obs: np.ndarray | torch.Tensor, sample_size: int, **kwargs: Any, ) -> tuple[Any, torch.Tensor]: r"""Mapping: s -> Q(s, \*).""" logits, hidden = self.preprocess(obs, state=kwargs.get("state")) # Sample fractions. batch_size = logits.size(0) taus = torch.rand(batch_size, sample_size, dtype=logits.dtype, device=logits.device) embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view( batch_size * sample_size, -1, ) out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2) return (out, taus), hidden class FractionProposalNetwork(nn.Module): """Fraction proposal network for FQF. :param num_fractions: the number of factions to propose. :param embedding_dim: the dimension of the embedding/input. .. note:: Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . 
""" def __init__(self, num_fractions: int, embedding_dim: int) -> None: super().__init__() self.net = nn.Linear(embedding_dim, num_fractions) torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01) torch.nn.init.constant_(self.net.bias, 0) self.num_fractions = num_fractions self.embedding_dim = embedding_dim def forward( self, obs_embeddings: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Calculate (log of) probabilities q_i in the paper. dist = torch.distributions.Categorical(logits=self.net(obs_embeddings)) taus_1_N = torch.cumsum(dist.probs, dim=1) # Calculate \tau_i (i=0,...,N). taus = F.pad(taus_1_N, (1, 0)) # Calculate \hat \tau_i (i=0,...,N-1). tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0 # Calculate entropies of value distributions. entropies = dist.entropy() return taus, tau_hats, entropies class FullQuantileFunction(ImplicitQuantileNetwork): """Full(y parameterized) Quantile Function. :param preprocess_net: a self-defined preprocess_net which output a flattened hidden state. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains only a single linear layer). :param num_cosines: the number of cosines to use for cosine embedding. Default to 64. .. note:: The first return value is a tuple of (quantiles, fractions, quantiles_tau), where fractions is a Batch(taus, tau_hats, entropies). 
""" def __init__( self, *, preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, ) -> None: super().__init__( preprocess_net=preprocess_net, action_shape=action_shape, hidden_sizes=hidden_sizes, num_cosines=num_cosines, ) def _compute_quantiles(self, obs: torch.Tensor, taus: torch.Tensor) -> torch.Tensor: batch_size, sample_size = taus.shape embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view(batch_size * sample_size, -1) return self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2) def forward( # type: ignore self, obs: np.ndarray | torch.Tensor, propose_model: FractionProposalNetwork, fractions: Batch | None = None, **kwargs: Any, ) -> tuple[Any, torch.Tensor]: r"""Mapping: s -> Q(s, \*).""" logits, hidden = self.preprocess(obs, state=kwargs.get("state")) # Propose fractions if fractions is None: taus, tau_hats, entropies = propose_model(logits.detach()) fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies) else: taus, tau_hats = fractions.taus, fractions.tau_hats quantiles = self._compute_quantiles(logits, tau_hats) # Calculate quantiles_tau for computing fraction grad quantiles_tau = None if self.training: with torch.no_grad(): quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1]) return (quantiles, fractions, quantiles_tau), hidden class NoisyLinear(nn.Module): """Implementation of Noisy Networks. arXiv:1706.10295. :param in_features: the number of input features. :param out_features: the number of output features. :param noisy_std: initial standard deviation of noisy linear layers. .. note:: Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . """ def __init__(self, in_features: int, out_features: int, noisy_std: float = 0.5) -> None: super().__init__() # Learnable parameters. 
self.mu_W = nn.Parameter(torch.FloatTensor(out_features, in_features)) self.sigma_W = nn.Parameter(torch.FloatTensor(out_features, in_features)) self.mu_bias = nn.Parameter(torch.FloatTensor(out_features)) self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features)) # Factorized noise parameters. self.eps_p = nn.Parameter(torch.FloatTensor(in_features), requires_grad=False) self.eps_q = nn.Parameter(torch.FloatTensor(out_features), requires_grad=False) self.in_features = in_features self.out_features = out_features self.sigma = noisy_std self.reset() self.sample() def reset(self) -> None: bound = 1 / np.sqrt(self.in_features) self.mu_W.data.uniform_(-bound, bound) self.mu_bias.data.uniform_(-bound, bound) self.sigma_W.data.fill_(self.sigma / np.sqrt(self.in_features)) self.sigma_bias.data.fill_(self.sigma / np.sqrt(self.in_features)) def f(self, x: torch.Tensor) -> torch.Tensor: x = torch.randn(x.size(0), device=x.device) return x.sign().mul_(x.abs().sqrt_()) # TODO: rename or change functionality? Usually sample is not an inplace operation... def sample(self) -> None: self.eps_p.copy_(self.f(self.eps_p)) self.eps_q.copy_(self.f(self.eps_q)) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training: weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p)) bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() else: weight = self.mu_W bias = self.mu_bias return F.linear(x, weight, bias) class IntrinsicCuriosityModule(nn.Module): """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. :param feature_net: a self-defined feature_net which output a flattened hidden state. :param feature_dim: input dimension of the feature net. :param action_dim: dimension of the action space. :param hidden_sizes: hidden layer sizes for forward and inverse models. 
""" def __init__( self, *, feature_net: nn.Module, feature_dim: int, action_dim: int, hidden_sizes: Sequence[int] = (), ) -> None: super().__init__() self.feature_net = feature_net self.forward_model = MLP( input_dim=feature_dim + action_dim, output_dim=feature_dim, hidden_sizes=hidden_sizes, ) self.inverse_model = MLP( input_dim=feature_dim * 2, output_dim=action_dim, hidden_sizes=hidden_sizes, ) self.feature_dim = feature_dim self.action_dim = action_dim def forward( self, s1: np.ndarray | torch.Tensor, act: np.ndarray | torch.Tensor, s2: np.ndarray | torch.Tensor, **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: r"""Mapping: s1, act, s2 -> mse_loss, act_hat.""" device = torch_device(self) s1 = to_torch(s1, dtype=torch.float32, device=device) s2 = to_torch(s2, dtype=torch.float32, device=device) phi1, phi2 = self.feature_net(s1), self.feature_net(s2) act = to_torch(act, dtype=torch.long, device=device) phi2_hat = self.forward_model( torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1), ) mse_loss = 0.5 * F.mse_loss(phi2_hat, phi2, reduction="none").sum(1) act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1)) return mse_loss, act_hat ================================================ FILE: tianshou/utils/print.py ================================================ import pprint from collections.abc import Sequence from dataclasses import asdict, dataclass @dataclass class DataclassPPrintMixin: def pprint_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> None: """Pretty-print the object as a dict, excluding specified fields. :param exclude_fields: A sequence of field names to exclude from the output. If None, no fields are excluded. :param indent: The indentation to use when pretty-printing. 
""" print(self.pprints_asdict(exclude_fields=exclude_fields, indent=indent)) def pprints_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> str: """String corresponding to pretty-print of the object as a dict, excluding specified fields. :param exclude_fields: A sequence of field names to exclude from the output. If None, no fields are excluded. :param indent: The indentation to use when pretty-printing. """ prefix = f"{self.__class__.__name__}\n----------------------------------------\n" print_dict = asdict(self) exclude_fields = exclude_fields or [] for field in exclude_fields: print_dict.pop(field, None) return prefix + pprint.pformat(print_dict, indent=indent) ================================================ FILE: tianshou/utils/progress_bar.py ================================================ from typing import Any tqdm_config = { "dynamic_ncols": True, "ascii": True, } class DummyTqdm: """A dummy tqdm class that keeps stats but without progress bar. It supports ``__enter__`` and ``__exit__``, update and a dummy ``set_postfix``, which is the interface that trainers use. .. note:: Using ``disable=True`` in tqdm config results in infinite loop, thus this class is created. See the discussion at #641 for details. 
""" def __init__(self, total: int, **kwargs: Any): self.total = total self.n = 0 def set_postfix(self, **kwargs: Any) -> None: pass def update(self, n: int = 1) -> None: self.n += n def __enter__(self) -> "DummyTqdm": return self def __exit__(self, *args: Any, **kwargs: Any) -> None: pass ================================================ FILE: tianshou/utils/space_info.py ================================================ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Self import gymnasium as gym import numpy as np from gymnasium import spaces from sensai.util.string import ToStringMixin @dataclass(kw_only=True) class ActionSpaceInfo(ToStringMixin): """A data structure for storing the different attributes of the action space.""" action_shape: int | Sequence[int] """The shape of the action space.""" min_action: float """The smallest allowable action or in the continuous case the lower bound for allowable action value.""" max_action: float """The largest allowable action or in the continuous case the upper bound for allowable action value.""" @property def action_dim(self) -> int: """Return the number of distinct actions (must be greater than zero) an agent can take it its action space.""" if isinstance(self.action_shape, int): return self.action_shape else: return int(np.prod(self.action_shape)) @classmethod def from_space(cls, space: spaces.Space) -> Self: """Instantiate the `ActionSpaceInfo` object from a `Space`, supported spaces are Box and Discrete.""" if isinstance(space, spaces.Box): return cls( action_shape=space.shape, min_action=float(np.min(space.low)), max_action=float(np.max(space.high)), ) elif isinstance(space, spaces.Discrete): return cls( action_shape=int(space.n), min_action=float(space.start), max_action=float(space.start + space.n - 1), ) else: raise ValueError( f"Unsupported space type: {space.__class__}. 
Currently supported types are Discrete and Box.", ) def _tostring_additional_entries(self) -> dict[str, Any]: return {"action_dim": self.action_dim} @dataclass(kw_only=True) class ObservationSpaceInfo(ToStringMixin): """A data structure for storing the different attributes of the observation space.""" obs_shape: int | Sequence[int] """The shape of the observation space.""" @property def obs_dim(self) -> int: """Return the number of distinct features (must be greater than zero) or dimensions in the observation space.""" if isinstance(self.obs_shape, int): return self.obs_shape else: return int(np.prod(self.obs_shape)) @classmethod def from_space(cls, space: spaces.Space) -> Self: """Instantiate the `ObservationSpaceInfo` object from a `Space`, supported spaces are Box and Discrete.""" if isinstance(space, spaces.Box): return cls( obs_shape=space.shape, ) elif isinstance(space, spaces.Discrete): return cls( obs_shape=int(space.n), ) else: raise ValueError( f"Unsupported space type: {space.__class__}. 
Currently supported types are Discrete and Box.", ) def _tostring_additional_entries(self) -> dict[str, Any]: return {"obs_dim": self.obs_dim} @dataclass(kw_only=True) class SpaceInfo(ToStringMixin): """A data structure for storing the attributes of both the action and observation space.""" action_info: ActionSpaceInfo """Stores the attributes of the action space.""" observation_info: ObservationSpaceInfo """Stores the attributes of the observation space.""" @classmethod def from_env(cls, env: gym.Env) -> Self: """Instantiate the `SpaceInfo` object from `gym.Env.action_space` and `gym.Env.observation_space`.""" return cls.from_spaces(env.action_space, env.observation_space) @classmethod def from_spaces(cls, action_space: spaces.Space, observation_space: spaces.Space) -> Self: """Instantiate the `SpaceInfo` object from `ActionSpaceInfo` and `ObservationSpaceInfo`.""" action_info = ActionSpaceInfo.from_space(action_space) observation_info = ObservationSpaceInfo.from_space(observation_space) return cls( action_info=action_info, observation_info=observation_info, ) ================================================ FILE: tianshou/utils/statistics.py ================================================ from numbers import Number import numpy as np import torch class MovAvg: """Class for moving average. It will automatically exclude the infinity and NaN. Usage: :: >>> stat = MovAvg(size=66) >>> stat.add(torch.tensor(5)) 5.0 >>> stat.add(float('inf')) # which will not add to stat 5.0 >>> stat.add([6, 7, 8]) 6.5 >>> stat.get() 6.5 >>> print(f'{stat.mean():.2f}±{stat.std():.2f}') 6.50±1.12 """ def __init__(self, size: int = 100) -> None: super().__init__() self.size = size self.cache: list[np.number] = [] self.banned = [np.inf, np.nan, -np.inf] def add( self, data_array: Number | float | np.number | list | np.ndarray | torch.Tensor, ) -> float: """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with only one element, a python scalar, or a list of python scalar. 
""" if isinstance(data_array, torch.Tensor): data_array = data_array.flatten().cpu().numpy() if np.isscalar(data_array): data_array = [data_array] for number in data_array: # type: ignore if number not in self.banned: self.cache.append(number) if self.size > 0 and len(self.cache) > self.size: self.cache = self.cache[-self.size :] return self.get() def get(self) -> float: """Get the average.""" if len(self.cache) == 0: return 0.0 return float(np.mean(self.cache)) # type: ignore def mean(self) -> float: """Get the average. Same as :meth:`get`.""" return self.get() def std(self) -> float: """Get the standard deviation.""" if len(self.cache) == 0: return 0.0 return float(np.std(self.cache)) # type: ignore class RunningMeanStd: """Calculates the running mean and std of a data stream. https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm :param mean: the initial mean estimation for data array. Default to 0. :param std: the initial standard error estimation for data array. :param clip_max: the maximum absolute value for data array. Default to 10.0. :param epsilon: To avoid division by zero. 
""" def __init__( self, mean: float | np.ndarray = 0.0, std: float | np.ndarray = 1.0, clip_max: float | None = 10.0, epsilon: float = np.finfo(np.float32).eps.item(), ) -> None: self.mean, self.var = mean, std self.clip_max = clip_max self.count = 0 self.eps = epsilon def norm(self, data_array: float | np.ndarray) -> float | np.ndarray: data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps) if self.clip_max: data_array = np.clip(data_array, -self.clip_max, self.clip_max) return data_array def update(self, data_array: np.ndarray) -> None: """Add a batch of item into RMS with the same shape, modify mean/var/count.""" batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0) batch_count = len(data_array) delta = batch_mean - self.mean total_count = self.count + batch_count new_mean = self.mean + delta * batch_count / total_count m_a = self.var * self.count m_b = batch_var * batch_count m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count new_var = m_2 / total_count self.mean, self.var = new_mean, new_var self.count = total_count ================================================ FILE: tianshou/utils/torch_utils.py ================================================ from collections.abc import Iterator from contextlib import contextmanager from typing import TYPE_CHECKING, overload import torch import torch.distributions as dist from gymnasium import spaces from torch import nn if TYPE_CHECKING: from tianshou.algorithm import algorithm_base @contextmanager def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]: """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`.""" original_mode = module.training try: module.train(enabled) yield finally: module.train(original_mode) @contextmanager def policy_within_training_step( policy: "algorithm_base.Policy", enabled: bool = True ) -> Iterator[None]: """Temporarily switch to `policy.is_within_training_step=enabled`. 
Enabling this ensures that the policy is able to adapt its behavior, allowing it to differentiate between training and inference/evaluation, e.g., to sample actions instead of using the most probable action (where applicable) Note that for rollout, which also happens within a training step, one would usually want the wrapped torch module to be in evaluation mode, which can be achieved using `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be both within training step and in torch train mode. """ original_mode = policy.is_within_training_step try: policy.is_within_training_step = enabled yield finally: policy.is_within_training_step = original_mode @overload def create_uniform_action_dist(action_space: spaces.Box, batch_size: int = 1) -> dist.Uniform: ... @overload def create_uniform_action_dist( action_space: spaces.Discrete, batch_size: int = 1, ) -> dist.Categorical: ... def create_uniform_action_dist( action_space: spaces.Box | spaces.Discrete, batch_size: int = 1, ) -> dist.Uniform | dist.Categorical: """Create a Distribution such that sampling from it is equivalent to sampling a batch with `action_space.sample()`. :param action_space: the environment's action_space. :param batch_size: The number of environments or batch size for sampling. :return: A PyTorch distribution for sampling actions. """ if isinstance(action_space, spaces.Box): low = torch.FloatTensor(action_space.low).unsqueeze(0).repeat(batch_size, 1) high = torch.FloatTensor(action_space.high).unsqueeze(0).repeat(batch_size, 1) return dist.Uniform(low, high) elif isinstance(action_space, spaces.Discrete): return dist.Categorical(torch.ones(batch_size, int(action_space.n))) else: raise ValueError(f"Unsupported action space type: {type(action_space)}") def torch_device(module: torch.nn.Module) -> torch.device: """Gets the device of a torch module by retrieving the device of the parameters. If parameters are empty, it returns the CPU device as a fallback. 
""" try: return next(module.parameters()).device except StopIteration: return torch.device("cpu") ================================================ FILE: tianshou/utils/warning.py ================================================ import warnings warnings.simplefilter("once", DeprecationWarning) def deprecation(msg: str) -> None: """Deprecation warning wrapper.""" warnings.warn(msg, category=DeprecationWarning, stacklevel=2)