Full Code of lebrice/Sequoia for AI

master 7e12ff8ed67f cached
460 files
2.6 MB
715.6k tokens
3006 symbols
1 requests
Download .txt
Showing preview only (2,859K chars total). Download the full file or copy to clipboard to get everything.
Repository: lebrice/Sequoia
Branch: master
Commit: 7e12ff8ed67f
Files: 460
Total size: 2.6 MB

Directory structure:
gitextract_c6gc35b2/

├── .dockerignore
├── .gitattributes
├── .gitignore
├── .gitmodules
├── .travis.yml
├── LICENSE
├── MANIFEST.in
├── README.md
├── dockers/
│   ├── .gitignore
│   ├── base/
│   │   ├── Dockerfile
│   │   └── build.sh
│   └── branch/
│       ├── Dockerfile
│       └── build.sh
├── docs/
│   └── diagrams/
│       └── src/
│           ├── gym.puml
│           ├── pytorch_lightning.puml
│           └── seq_diagram.puml
├── examples/
│   ├── README.md
│   ├── __init__.py
│   ├── advanced/
│   │   ├── RL_and_SL_demo.py
│   │   ├── continual_rl_demo.py
│   │   ├── ewc_in_rl.py
│   │   ├── hat_demo.py
│   │   ├── hparam_tuning.py
│   │   ├── pnn/
│   │   │   ├── __init__.py
│   │   │   ├── layers.py
│   │   │   ├── model_rl.py
│   │   │   ├── model_sl.py
│   │   │   └── pnn_method.py
│   │   └── procgen_example.py
│   ├── basic/
│   │   ├── __init__.py
│   │   ├── base_method_demo.py
│   │   ├── pl_example.py
│   │   ├── pl_example_packnet.py
│   │   ├── pl_example_test.py
│   │   ├── quick_demo.ipynb
│   │   ├── quick_demo.py
│   │   ├── quick_demo_ewc.py
│   │   ├── quick_demo_packnet.py
│   │   └── quick_demo_test.py
│   ├── clcomp21/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── a2c_example.py
│   │   ├── a2c_example_test.py
│   │   ├── classifier.py
│   │   ├── classifier_test.py
│   │   ├── conftest.py
│   │   ├── dummy_method.py
│   │   ├── dummy_method_test.py
│   │   ├── multihead_classifier.py
│   │   ├── multihead_classifier_test.py
│   │   ├── regularization_example.py
│   │   ├── regularization_example_test.py
│   │   ├── sb3_example.py
│   │   └── sb3_example_test.py
│   ├── demo_utils.py
│   └── prerequisites/
│       └── dataclasses_example.py
├── mypy.ini
├── pytest.ini
├── requirements.txt
├── scripts/
│   ├── eai/
│   │   ├── cancel_all_queuing.sh
│   │   ├── cancel_all_running.sh
│   │   ├── job.sh
│   │   ├── rl_sweep.sh
│   │   ├── shell_job.sh
│   │   └── sl_sweep.sh
│   └── slurm/
│       ├── launch_many_sweeps.sh
│       ├── run.sh
│       └── sweep.sh
├── sequoia/
│   ├── README.md
│   ├── __init__.py
│   ├── _version.py
│   ├── client/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── env.proto
│   │   ├── env_proxy.py
│   │   ├── env_proxy_test.py
│   │   ├── server.py
│   │   ├── setting_proxy.py
│   │   └── setting_proxy_test.py
│   ├── common/
│   │   ├── __init__.py
│   │   ├── batch.py
│   │   ├── batch_test.py
│   │   ├── callbacks/
│   │   │   ├── __init__.py
│   │   │   ├── knn_callback.py
│   │   │   └── vae_callback.py
│   │   ├── config/
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── wandb_config.py
│   │   ├── gym_wrappers/
│   │   │   ├── __init__.py
│   │   │   ├── action_limit.py
│   │   │   ├── action_limit_test.py
│   │   │   ├── add_done.py
│   │   │   ├── add_info.py
│   │   │   ├── convert_tensors.py
│   │   │   ├── convert_tensors_test.py
│   │   │   ├── env_dataset.py
│   │   │   ├── env_dataset_test.py
│   │   │   ├── episode_limit.py
│   │   │   ├── episode_limit_test.py
│   │   │   ├── measure_performance.py
│   │   │   ├── multi_task_environment.py
│   │   │   ├── multi_task_environment_test.py
│   │   │   ├── observation_limit.py
│   │   │   ├── observation_limit_test.py
│   │   │   ├── pixel_observation.py
│   │   │   ├── pixel_observation_test.py
│   │   │   ├── policy_env.py
│   │   │   ├── policy_env_test.py
│   │   │   ├── smooth_environment.py
│   │   │   ├── smooth_environment_test.py
│   │   │   ├── step_callback_wrapper.py
│   │   │   ├── step_callback_wrapper_test.py
│   │   │   ├── transform_wrappers.py
│   │   │   ├── transform_wrappers_test.py
│   │   │   ├── utils.py
│   │   │   └── utils_test.py
│   │   ├── hparams/
│   │   │   └── __init__.py
│   │   ├── layers.py
│   │   ├── loss.py
│   │   ├── loss_test.py
│   │   ├── metrics/
│   │   │   ├── __init__.py
│   │   │   ├── classification.py
│   │   │   ├── classification_test.py
│   │   │   ├── get_metrics.py
│   │   │   ├── metrics.py
│   │   │   ├── metrics_utils.py
│   │   │   ├── metrics_utils_test.py
│   │   │   ├── regression.py
│   │   │   └── rl_metrics.py
│   │   ├── replay.py
│   │   ├── spaces/
│   │   │   ├── __init__.py
│   │   │   ├── image.py
│   │   │   ├── named_tuple.py
│   │   │   ├── named_tuple_test.py
│   │   │   ├── space.py
│   │   │   ├── sparse.py
│   │   │   ├── sparse_test.py
│   │   │   ├── tensor_spaces.py
│   │   │   ├── tensor_spaces_test.py
│   │   │   ├── typed_dict.py
│   │   │   └── typed_dict_test.py
│   │   ├── task.py
│   │   └── transforms/
│   │       ├── __init__.py
│   │       ├── channels.py
│   │       ├── compose.py
│   │       ├── resize.py
│   │       ├── split_batch.py
│   │       ├── to_tensor.py
│   │       ├── transform.py
│   │       ├── transform_enum.py
│   │       ├── transforms_test.py
│   │       └── utils.py
│   ├── common.puml
│   ├── conftest.py
│   ├── experiments/
│   │   ├── __init__.py
│   │   ├── experiment.py
│   │   ├── experiment_test.py
│   │   ├── hpo_sweep.py
│   │   └── hpo_sweep_test.py
│   ├── main.py
│   ├── methods/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── aux_tasks/
│   │   │   ├── __init__.py
│   │   │   ├── auxiliary_task.py
│   │   │   ├── ewc.py
│   │   │   ├── reconstruction/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── ae.py
│   │   │   │   ├── decoder_for_dataset.py
│   │   │   │   ├── decoders.py
│   │   │   │   └── vae.py
│   │   │   └── transformation_based/
│   │   │       ├── __init__.py
│   │   │       ├── bases.py
│   │   │       └── rotation.py
│   │   ├── avalanche_methods/
│   │   │   ├── __init__.py
│   │   │   ├── agem.py
│   │   │   ├── agem_test.py
│   │   │   ├── ar1.py
│   │   │   ├── ar1_test.py
│   │   │   ├── base.py
│   │   │   ├── base_test.py
│   │   │   ├── conftest.py
│   │   │   ├── cwr_star.py
│   │   │   ├── cwr_star_test.py
│   │   │   ├── ewc.py
│   │   │   ├── ewc_test.py
│   │   │   ├── experience.py
│   │   │   ├── gdumb.py
│   │   │   ├── gdumb_test.py
│   │   │   ├── gem.py
│   │   │   ├── gem_test.py
│   │   │   ├── lwf.py
│   │   │   ├── lwf_test.py
│   │   │   ├── naive.py
│   │   │   ├── naive_test.py
│   │   │   ├── patched_models.py
│   │   │   ├── plugins.py
│   │   │   ├── replay.py
│   │   │   ├── replay_test.py
│   │   │   ├── synaptic_intelligence.py
│   │   │   └── synaptic_intelligence_test.py
│   │   ├── base_method.py
│   │   ├── base_method_test.py
│   │   ├── conftest.py
│   │   ├── d3rlpy_methods/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── base_test.py
│   │   ├── ewc_method.py
│   │   ├── ewc_method_test.py
│   │   ├── experience_replay.py
│   │   ├── experience_replay_test.py
│   │   ├── hat.py
│   │   ├── method_test.py
│   │   ├── models/
│   │   │   ├── __init__.py
│   │   │   ├── base_model/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_model.py
│   │   │   │   ├── model.py
│   │   │   │   ├── multihead_model.py
│   │   │   │   ├── multihead_model_test.py
│   │   │   │   ├── self_supervised_model.py
│   │   │   │   ├── self_supervised_model_test.py
│   │   │   │   └── semi_supervised_model.py
│   │   │   ├── baseline_model.puml
│   │   │   ├── fcnet.py
│   │   │   ├── forward_pass.py
│   │   │   ├── output_heads/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── classification_head.py
│   │   │   │   ├── output_head.py
│   │   │   │   ├── regression_head.py
│   │   │   │   └── rl/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── actor_critic_head.py
│   │   │   │       ├── episodic_a2c.py
│   │   │   │       ├── episodic_a2c_test.py
│   │   │   │       ├── policy_head.py
│   │   │   │       ├── policy_head_test.py
│   │   │   │       └── wasted_steps_calc.py
│   │   │   ├── output_heads.puml
│   │   │   └── simple_convnet.py
│   │   ├── models.puml
│   │   ├── packnet_method.py
│   │   ├── packnet_method_test.py
│   │   ├── pl_bolts_methods/
│   │   │   └── __init__.py
│   │   ├── pl_dqn.py
│   │   ├── pnn/
│   │   │   ├── __init__.py
│   │   │   ├── layers.py
│   │   │   ├── model_rl.py
│   │   │   ├── model_sl.py
│   │   │   └── pnn_method.py
│   │   ├── random_baseline.py
│   │   ├── random_baseline_test.py
│   │   ├── stable_baselines3_methods/
│   │   │   ├── __init__.py
│   │   │   ├── a2c.py
│   │   │   ├── a2c_test.py
│   │   │   ├── base.py
│   │   │   ├── base_test.py
│   │   │   ├── ddpg.py
│   │   │   ├── ddpg_test.py
│   │   │   ├── dqn.py
│   │   │   ├── dqn_test.py
│   │   │   ├── off_policy_method.py
│   │   │   ├── off_policy_method_test.py
│   │   │   ├── on_policy_method.py
│   │   │   ├── policy_wrapper.py
│   │   │   ├── ppo.py
│   │   │   ├── ppo_test.py
│   │   │   ├── sac.py
│   │   │   ├── sac_test.py
│   │   │   ├── td3.py
│   │   │   └── td3_test.py
│   │   └── trainer.py
│   ├── methods.puml
│   ├── sequoia.puml
│   ├── settings/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── assumptions/
│   │   │   ├── __init__.py
│   │   │   ├── assumptions.puml
│   │   │   ├── base.py
│   │   │   ├── classification.py
│   │   │   ├── context_discreteness.py
│   │   │   ├── context_visibility.py
│   │   │   ├── continual.py
│   │   │   ├── discrete_results.py
│   │   │   ├── iid.py
│   │   │   ├── iid_results.py
│   │   │   ├── incremental.py
│   │   │   ├── incremental_results.py
│   │   │   ├── incremental_test.py
│   │   │   ├── task_incremental.py
│   │   │   └── task_type.py
│   │   ├── base/
│   │   │   ├── __init__.py
│   │   │   ├── base.puml
│   │   │   ├── bases.py
│   │   │   ├── environment.py
│   │   │   ├── objects.py
│   │   │   ├── results.py
│   │   │   ├── setting.py
│   │   │   ├── setting_meta.py
│   │   │   └── setting_test.py
│   │   ├── offline_rl/
│   │   │   └── setting.py
│   │   ├── presets/
│   │   │   ├── __init__.py
│   │   │   ├── cartpole_pixels.yaml
│   │   │   ├── cartpole_state.yaml
│   │   │   ├── cifar10.yaml
│   │   │   ├── cifar100.yaml
│   │   │   ├── classic_control/
│   │   │   │   ├── cartpole.yaml
│   │   │   │   └── mountaincar_continuous.yaml
│   │   │   ├── fashion_mnist.yaml
│   │   │   ├── mnist.yaml
│   │   │   ├── monsterkong/
│   │   │   │   ├── monsterkong_3each.yaml
│   │   │   │   ├── monsterkong_4each.yaml
│   │   │   │   ├── monsterkong_5each.yaml
│   │   │   │   ├── monsterkong_all.yaml
│   │   │   │   ├── monsterkong_jumps.yaml
│   │   │   │   ├── monsterkong_jumps_and_ladders.yaml
│   │   │   │   ├── monsterkong_ladders.yaml
│   │   │   │   └── monsterkong_mix.yaml
│   │   │   ├── mujoco/
│   │   │   │   └── half_cheetah.yaml
│   │   │   ├── rl_track.yaml
│   │   │   └── sl_track.yaml
│   │   ├── rl/
│   │   │   ├── __init__.py
│   │   │   ├── continual/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── environment.py
│   │   │   │   ├── environment_test.py
│   │   │   │   ├── make_env.py
│   │   │   │   ├── make_env_test.py
│   │   │   │   ├── objects.py
│   │   │   │   ├── results.py
│   │   │   │   ├── setting.py
│   │   │   │   ├── setting_test.py
│   │   │   │   ├── tasks.py
│   │   │   │   ├── tasks_test.py
│   │   │   │   └── test_environment.py
│   │   │   ├── discrete/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── multienv_wrappers.py
│   │   │   │   ├── multienv_wrappers_test.py
│   │   │   │   ├── results.py
│   │   │   │   ├── setting.py
│   │   │   │   ├── setting_test.py
│   │   │   │   ├── tasks.py
│   │   │   │   ├── tasks_test.py
│   │   │   │   └── test_environment.py
│   │   │   ├── environment.py
│   │   │   ├── environment_test.py
│   │   │   ├── envs/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── classic_control.py
│   │   │   │   ├── monsterkong.py
│   │   │   │   ├── mujoco/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── half_cheetah.py
│   │   │   │   │   ├── half_cheetah_test.py
│   │   │   │   │   ├── hopper.py
│   │   │   │   │   ├── hopper_test.py
│   │   │   │   │   ├── modified_friction.py
│   │   │   │   │   ├── modified_friction_test.py
│   │   │   │   │   ├── modified_gravity.py
│   │   │   │   │   ├── modified_gravity_test.py
│   │   │   │   │   ├── modified_mass.py
│   │   │   │   │   ├── modified_mass_test.py
│   │   │   │   │   ├── modified_size.py
│   │   │   │   │   ├── modified_size_test.py
│   │   │   │   │   ├── modified_wall.py
│   │   │   │   │   ├── mujoco_model_utils.py
│   │   │   │   │   ├── walker2d.py
│   │   │   │   │   └── walker2d_test.py
│   │   │   │   └── variant_spec.py
│   │   │   ├── incremental/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── objects.py
│   │   │   │   ├── results.py
│   │   │   │   ├── setting.py
│   │   │   │   ├── setting_test.py
│   │   │   │   └── tasks.py
│   │   │   ├── multi_task/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── setting.py
│   │   │   │   └── setting_test.py
│   │   │   ├── objects.py
│   │   │   ├── setting.py
│   │   │   ├── setting_test.py
│   │   │   ├── task_incremental/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── setting.py
│   │   │   │   ├── setting_test.py
│   │   │   │   └── tasks.py
│   │   │   ├── traditional/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── setting.py
│   │   │   │   └── setting_test.py
│   │   │   └── wrappers/
│   │   │       ├── __init__.py
│   │   │       ├── measure_performance.py
│   │   │       ├── measure_performance_test.py
│   │   │       ├── no_typed_objects.py
│   │   │       ├── task_labels.py
│   │   │       └── typed_objects.py
│   │   ├── settings.puml
│   │   └── sl/
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── continual/
│   │       │   ├── __init__.py
│   │       │   ├── environment.py
│   │       │   ├── environment_test.py
│   │       │   ├── envs.py
│   │       │   ├── objects.py
│   │       │   ├── results.py
│   │       │   ├── setting.py
│   │       │   ├── setting_test.py
│   │       │   └── wrappers.py
│   │       ├── discrete/
│   │       │   ├── __init__.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       ├── domain_incremental/
│   │       │   ├── __init__.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       ├── environment.py
│   │       ├── environment_test.py
│   │       ├── incremental/
│   │       │   ├── __init__.py
│   │       │   ├── environment.py
│   │       │   ├── environment_test.py
│   │       │   ├── objects.py
│   │       │   ├── results.py
│   │       │   ├── setting.py
│   │       │   ├── setting_test.py
│   │       │   └── unused_batch_transforms.py
│   │       ├── multi_task/
│   │       │   ├── __init__.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       ├── setting.py
│   │       ├── task_incremental/
│   │       │   ├── __init__.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       ├── traditional/
│   │       │   ├── __init__.py
│   │       │   ├── results.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       └── wrappers/
│   │           ├── __init__.py
│   │           ├── measure_performance.py
│   │           └── measure_performance_test.py
│   ├── settings.puml
│   └── utils/
│       ├── __init__.py
│       ├── categorical.py
│       ├── data_utils.py
│       ├── encode.py
│       ├── generic_functions/
│       │   ├── __init__.py
│       │   ├── _namedtuple.py
│       │   ├── _namedtuple_test.py
│       │   ├── concatenate.py
│       │   ├── detach.py
│       │   ├── move.py
│       │   ├── replace.py
│       │   ├── replace_test.py
│       │   ├── singledispatchmethod.py
│       │   ├── slicing.py
│       │   ├── slicing_test.py
│       │   ├── stack.py
│       │   └── to_from_tensor.py
│       ├── logging_utils.py
│       ├── module_dict.py
│       ├── parseable.py
│       ├── plotting.py
│       ├── pretrained_utils.py
│       ├── readme.py
│       ├── serialization.py
│       └── utils.py
├── setup.cfg
├── setup.py
└── versioneer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .dockerignore
================================================
data
lightning_logs
checkpoints
results


================================================
FILE: .gitattributes
================================================
sequoia/_version.py export-subst


================================================
FILE: .gitignore
================================================
**/__pycache__/
.vscode

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

examples/results/*
results/*
!results/**/*.csv
data/*
*/data/*
!data/**/*.py
scripts/*.png
wandb
.idea
.ipynb_checkpoints
checkpoints
lightning_logs
.pylintrc

**.png

*.gz
*.pt
build
dist
*.egg-info
sequoia/results

mjkey.txt

================================================
FILE: .gitmodules
================================================
[submodule "sequoia/methods/cn_dpm"]
	path = sequoia/methods/cn_dpm
	url = https://github.com/ryanlindeborg/CN-DPM.git
[submodule "examples/clcomp21/Real_DEEL"]
	path = examples/clcomp21/Real_DEEL
	url = https://github.com/mostafaelaraby/Real-DEEL-Dark-Experience.git
[submodule "sequoia/methods/continual_world"]
	path = sequoia/methods/continual_world
	url = https://www.github.com/lebrice/continual_world.git


================================================
FILE: .travis.yml
================================================
# Travis CI configuration: install dependencies, run the test suite with
# pytest, then report coverage when the tests pass.
language: python
python:
  - "3.7"
install:
  - pip install gym[atari]
  - pip install -r requirements.txt
script:
  - pytest
# NOTE: the key must be spelled `after_success`; the previous `after_sucess`
# was not a recognised Travis phase, so this step was never executed.
# NOTE(review): `coveralls` is not installed explicitly above -- confirm that
# requirements.txt provides it.
after_success:
  - coveralls


================================================
FILE: LICENSE
================================================
                    GNU GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

  The GNU General Public License is a free, copyleft license for
software and other kinds of works.

  The licenses for most software and other practical works are designed
to take away your freedom to share and change the works.  By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.  We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors.  You can apply it to
your programs, too.

  When we speak of free software, we are referring to freedom, not
price.  Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

  To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights.  Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

  For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received.  You must make sure that they, too, receive
or can get the source code.  And you must show them these terms so they
know their rights.

  Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

  For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software.  For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

  Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so.  This is fundamentally incompatible with the aim of
protecting users' freedom to change the software.  The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable.  Therefore, we
have designed this version of the GPL to prohibit the practice for those
products.  If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

  Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary.  To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

  The precise terms and conditions for copying, distribution and
modification follow.

                       TERMS AND CONDITIONS

  0. Definitions.

  "This License" refers to version 3 of the GNU General Public License.

  "Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

  "The Program" refers to any copyrightable work licensed under this
License.  Each licensee is addressed as "you".  "Licensees" and
"recipients" may be individuals or organizations.

  To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy.  The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

  A "covered work" means either the unmodified Program or a work based
on the Program.

  To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy.  Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

  To "convey" a work means any kind of propagation that enables other
parties to make or receive copies.  Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

  An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License.  If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

  1. Source Code.

  The "source code" for a work means the preferred form of the work
for making modifications to it.  "Object code" means any non-source
form of a work.

  A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

  The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form.  A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

  The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities.  However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work.  For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

  The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

  The Corresponding Source for a work in source code form is that
same work.

  2. Basic Permissions.

  All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met.  This License explicitly affirms your unlimited
permission to run the unmodified Program.  The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work.  This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

  You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force.  You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright.  Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

  Conveying under any other circumstances is permitted solely under
the conditions stated below.  Sublicensing is not allowed; section 10
makes it unnecessary.

  3. Protecting Users' Legal Rights From Anti-Circumvention Law.

  No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

  When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

  4. Conveying Verbatim Copies.

  You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

  You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

  5. Conveying Modified Source Versions.

  You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified
    it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is
    released under this License and any conditions added under section
    7.  This requirement modifies the requirement in section 4 to
    "keep intact all notices".

    c) You must license the entire work, as a whole, under this
    License to anyone who comes into possession of a copy.  This
    License will therefore apply, along with any applicable section 7
    additional terms, to the whole of the work, and all its parts,
    regardless of how they are packaged.  This License gives no
    permission to license the work in any other way, but it does not
    invalidate such permission if you have separately received it.

    d) If the work has interactive user interfaces, each must display
    Appropriate Legal Notices; however, if the Program has interactive
    interfaces that do not display Appropriate Legal Notices, your
    work need not make them do so.

  A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit.  Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

  6. Conveying Non-Source Forms.

  You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

    a) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by the
    Corresponding Source fixed on a durable physical medium
    customarily used for software interchange.

    b) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by a
    written offer, valid for at least three years and valid for as
    long as you offer spare parts or customer support for that product
    model, to give anyone who possesses the object code either (1) a
    copy of the Corresponding Source for all the software in the
    product that is covered by this License, on a durable physical
    medium customarily used for software interchange, for a price no
    more than your reasonable cost of physically performing this
    conveying of source, or (2) access to copy the
    Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the
    written offer to provide the Corresponding Source.  This
    alternative is allowed only occasionally and noncommercially, and
    only if you received the object code with such an offer, in accord
    with subsection 6b.

    d) Convey the object code by offering access from a designated
    place (gratis or for a charge), and offer equivalent access to the
    Corresponding Source in the same way through the same place at no
    further charge.  You need not require recipients to copy the
    Corresponding Source along with the object code.  If the place to
    copy the object code is a network server, the Corresponding Source
    may be on a different server (operated by you or a third party)
    that supports equivalent copying facilities, provided you maintain
    clear directions next to the object code saying where to find the
    Corresponding Source.  Regardless of what server hosts the
    Corresponding Source, you remain obligated to ensure that it is
    available for as long as needed to satisfy these requirements.

    e) Convey the object code using peer-to-peer transmission, provided
    you inform other peers where the object code and Corresponding
    Source of the work are being offered to the general public at no
    charge under subsection 6d.

  A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

  A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling.  In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage.  For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product.  A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

  "Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source.  The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

  If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information.  But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

  The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed.  Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

  Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

  7. Additional Terms.

  "Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law.  If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

  When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it.  (Additional permissions may be written to require their own
removal in certain cases when you modify the work.)  You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

  Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

    a) Disclaiming warranty or limiting liability differently from the
    terms of sections 15 and 16 of this License; or

    b) Requiring preservation of specified reasonable legal notices or
    author attributions in that material or in the Appropriate Legal
    Notices displayed by works containing it; or

    c) Prohibiting misrepresentation of the origin of that material, or
    requiring that modified versions of such material be marked in
    reasonable ways as different from the original version; or

    d) Limiting the use for publicity purposes of names of licensors or
    authors of the material; or

    e) Declining to grant rights under trademark law for use of some
    trade names, trademarks, or service marks; or

    f) Requiring indemnification of licensors and authors of that
    material by anyone who conveys the material (or modified versions of
    it) with contractual assumptions of liability to the recipient, for
    any liability that these contractual assumptions directly impose on
    those licensors and authors.

  All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10.  If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term.  If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

  If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

  Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

  8. Termination.

  You may not propagate or modify a covered work except as expressly
provided under this License.  Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

  However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

  Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

  Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License.  If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

  9. Acceptance Not Required for Having Copies.

  You are not required to accept this License in order to receive or
run a copy of the Program.  Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance.  However,
nothing other than this License grants you permission to propagate or
modify any covered work.  These actions infringe copyright if you do
not accept this License.  Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

  10. Automatic Licensing of Downstream Recipients.

  Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License.  You are not responsible
for enforcing compliance by third parties with this License.

  An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations.  If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

  You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License.  For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

  11. Patents.

  A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based.  The
work thus licensed is called the contributor's "contributor version".

  A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version.  For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

  Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

  In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement).  To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

  If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients.  "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

  If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

  A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License.  You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

  Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

  12. No Surrender of Others' Freedom.

  If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License.  If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all.  For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.

  13. Use with the GNU Affero General Public License.

  Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work.  The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.

  14. Revised Versions of this License.

  The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time.  Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

  Each version is given a distinguishing version number.  If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation.  If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.

  If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.

  Later license versions may give you additional or different
permissions.  However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.

  15. Disclaimer of Warranty.

  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.

  16. Limitation of Liability.

  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.

  17. Interpretation of Sections 15 and 16.

  If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.

                     END OF TERMS AND CONDITIONS

            How to Apply These Terms to Your New Programs

  If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

  To do so, attach the following notices to the program.  It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

    <one line to give the program's name and a brief idea of what it does.>
    Copyright (C) <year>  <name of author>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

Also add information on how to contact you by electronic and paper mail.

  If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:

    <program>  Copyright (C) <year>  <name of author>
    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
    This is free software, and you are welcome to redistribute it
    under certain conditions; type `show c' for details.

The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License.  Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".

  You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.

  The GNU General Public License does not permit incorporating your program
into proprietary programs.  If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library.  If this is what you want to do, use the GNU Lesser General
Public License instead of this License.  But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.


================================================
FILE: MANIFEST.in
================================================
include versioneer.py
include sequoia/_version.py


================================================
FILE: README.md
================================================
# Sequoia - The Research Tree 

A Playground for research at the intersection of Continual, Reinforcement, and Self-Supervised Learning.

- 5 minute intro: https://www.youtube.com/watch?v=0u48vr96zRQ
- Paper link: https://arxiv.org/abs/2108.01005
- [Continual Supervised Learning Study](https://wandb.ai/sequoia/csl_study) (~6K runs)
- [Continual Reinforcement Learning Study](https://wandb.ai/sequoia/crl_study) (~2300 runs)


## Note: This project is not being actively developed at the moment. If you encounter any difficulties, please create an issue and I'll help you out. 

If you have any questions or comments, please make an issue!

## Motivation:
Most applied ML research generally either proposes new Settings (research problems), new Methods (solutions to such problems), or both.

- When proposing new Settings, researchers almost always have to reimplement or heavily modify existing solutions before they can be applied to their new problem.

- Likewise, when creating new Methods, it's often necessary to first re-create the experimental setting of other baseline papers, or even the baseline methods themselves, as experimental conditions may be *slightly* different between papers!

The goal of this repo is to:

- Organize various research Settings into an inheritance hierarchy (a tree!), with more *general*, challenging settings with few assumptions at the top, and more constrained problems at the bottom.

- Provide a mechanism for easily reusing existing solutions (Methods) onto new Settings through **Polymorphism**!

- Allow researchers to easily create new, general Methods and quickly gather results on a multitude of Settings, ranging from Supervised to Reinforcement Learning!


## Installation
Requires python >= 3.7


### Basic installation:

```console
$ git clone https://www.github.com/lebrice/Sequoia.git
$ pip install -e Sequoia
```

### Optional Addons
You can also install optional "addons" for Sequoia, each of which adds new Methods, new environments/datasets, or both.
Addons can be installed either through the usual `extras_require` feature of setuptools, or by pip-installing other repositories which register Methods for Sequoia using an `entry_point` in their `setup.py` file.


```console
pip install -e Sequoia[all|<plugin name>]
```

Here are some of the optional addons:

- `avalanche`:
  
  Continual Supervised Learning methods, provided by the [Avalanche](https://github.com/ContinualAI/avalanche) library:
  
    ```console
    $ pip install -e Sequoia[avalanche]
    ```

- `CN-DPM`: Continual Neural Dirichlet Process Mixture model:
    ```console
    $ cd Sequoia
    $ git submodule init  # to setup the submodules
    $ pip install -e sequoia/methods/cn_dpm    
    ```


- `orion`:
  
    Hyper-parameter optimization using [Orion](https://github.com/epistimio/orion)
    ```console
    $ pip install -e Sequoia[orion]
    ```

- `metaworld`:
  
    Continual / Multi-Task Reinforcement Learning environments, thanks to the [metaworld](https://github.com/rlworkgroup/metaworld) package. The usual setup for MuJoCo needs to be done beforehand; Sequoia unfortunately can't do it for you ;(
    ```console
    $ pip install -e Sequoia[metaworld]
    ```

- `monsterkong`:
  
    Continual Reinforcement Learning environment from [the Meta-MonsterKong repo](https://github.com/lebrice/MetaMonsterkong).
    ```console
    $ pip install -e Sequoia[monsterkong]
    ```


- `continual_world`: The Continual World benchmark for Continual Reinforcement learning. Adds 6 different Continual RL Methods to Sequoia.
    ```console
    $ cd Sequoia
    $ git submodule init  # to setup the submodules
    $ pip install -e sequoia/methods/continual_world   
    ```

See the `setup.py` file for all the optional extras.

### Additional Installation Steps for Mac

Install the latest XQuartz app from here: https://www.xquartz.org/releases/index.html

Then run the following commands in a terminal:

```console
mkdir /tmp/.X11-unix 
sudo chmod 1777 /tmp/.X11-unix 
sudo chown root /tmp/.X11-unix/
```

## Documentation overview:


- ### **[Getting Started / Examples (take a look at this first)](examples/)**
- ### Running Experiments (below)
- ### [Settings overview](sequoia/settings/)
- ### [Methods overview](sequoia/methods/)


### Current Settings & Assumptions:

| Setting                                                                    | RL vs SL                                                                 | clear task boundaries? | Task boundaries given? | Task labels at training time? | task labels at test time | Stationary context? | Fixed action space |
| -------------------------------------------------------------------------- | ------------------------------------------------------------------------ | ---------------------- | ---------------------- | ----------------------------- | ------------------------ | ------------------- | ------------------ |
| [Continual RL](sequoia/settings/rl/continual/setting.py)                   | RL                                                                       | no                     | no                     | no                            | no                       | no                  | no(?)              |
| [Discrete Task-Agnostic RL](sequoia/settings/rl/discrete/setting.py)       | RL                                                                       | **yes**                | **yes**                | no                            | no                       | no                  | no(?)              |
| [Incremental RL](sequoia/settings/rl/incremental/setting.py)               | RL                                                                       | **yes**                | **yes**                | **yes**                       | no                       | no                  | no(?)              |
| [Task-Incremental RL](sequoia/settings/rl/task_incremental/setting.py)     | RL                                                                       | **yes**                | **yes**                | **yes**                       | **yes**                  | no                  | no(?)              |
| [Traditional RL](sequoia/settings/rl/traditional/setting.py)               | RL                                                                       | **yes**                | **yes**                | **yes**                       | no                       | **yes**             | no(?)              |
| [Multi-Task RL](sequoia/settings/rl/multi_task/setting.py)                 | RL                                                                       | **yes**                | **yes**                | **yes**                       | **yes**                  | **yes**             | no(?)              |
| [Continual SL](sequoia/settings/sl/continual/setting.py)                   | SL                                                                       | no                     | no                     | no                            | no                       | no                  | no                 |
| [Discrete Task-Agnostic SL](sequoia/settings/sl/discrete/setting.py)       | SL                                                                       | **yes**                | no                     | no                            | no                       | no                  | no                 |
| [(Class) Incremental SL](sequoia/settings/sl/incremental/setting.py)       | SL                                                                       | **yes**                | **yes**                | no                            | no                       | no                  | no                 |
| [Domain-Incremental SL](sequoia/settings/sl/domain_incremental/setting.py) | SL                                                                       | **yes**                | **yes**                | **yes**                       | no                       | no                  | **yes**            |
| [Task-Incremental SL](sequoia/settings/sl/task_incremental/setting.py)     | SL                                                                       | **yes**                | **yes**                | **yes**                       | **yes**                  | no                  | no                 |
| [Traditional SL](sequoia/settings/sl/traditional/setting.py)               | SL                                                                       | **yes**                | **yes**                | **yes**                       | no                       | **yes**             | no                 |
| [Multi-Task SL](sequoia/settings/sl/multi_task/setting.py)                 | SL                                                                       | **yes**                | **yes**                | **yes**                       | **yes**                  | **yes**             | no                 |
<!--|                                                                        | [Class-Incremental SL](sequoia/settings/sl/class_incremental/setting.py) | SL                     | **yes**                | **yes**                       | no                       | no                  | no                 |  |-->

#### Notes

- **Active / Passive**:
    Active settings are Settings where the next observation depends on the current action, i.e. where actions influence future observations, e.g. Reinforcement Learning.
    Passive settings are Settings where the current actions don't influence the next observations (e.g. Supervised Learning.)

- **Bold entries** in the table mark constant attributes which cannot be
   changed from their default value.

- \*: The environment is changing constantly over time in `ContinualRLSetting`, so
    there aren't really "tasks" to speak of.



## Running experiments

--> **(Reminder) First, take a look at the [Examples](/examples)** <--

#### Directly in code:

```python
from sequoia.settings import TaskIncrementalSLSetting
from sequoia.methods import BaseMethod
# Create the setting
setting = TaskIncrementalSLSetting(dataset="mnist")
# Create the method
method = BaseMethod(max_epochs=1)
# Apply the setting to the method to generate results.
results = setting.apply(method)
print(results.summary())
```

### Command-line:

```console
$ sequoia --help
usage: sequoia [-h] [--version] {run,sweep,info} ...

Sequoia - The Research Tree 

Used to run experiments, which consist in applying a Method to a Setting.

optional arguments:
  -h, --help        show this help message and exit
  --version         Displays the installed version of Sequoia and exits.

command:
  Command to execute

  {run,sweep,info}
    run             Run an experiment on a given setting.
    sweep           Run a hyper-parameter optimization sweep.
    info            Displays some information about a Setting or Method.
```
The general form of each command is:
```console
$ sequoia run [--debug] <setting> (setting arguments) <method> (method arguments)
$ sequoia sweep [--debug] <setting> (setting arguments) <method> (method arguments)
$ sequoia info [setting or method]
```

For a detailed description of all the arguments, use the `--help` flag with any of the commands:
```console 
$ sequoia --help
$ sequoia run --help
$ sequoia run <some_setting> --help
$ sequoia run <some_setting> <some_method> --help
$ sequoia sweep --help
$ sequoia sweep <some_setting> --help
$ sequoia sweep <some_setting> <some_method> --help
```

For example:

```console
$ sequoia run --debug task_incremental_sl --dataset mnist random_baseline
```

For example:
- Run the BaseMethod on task-incremental MNIST, with one epoch per task, and without wandb:
    ```console
    $ sequoia run task_incremental_sl --dataset mnist base --max_epochs 1
    ```
- Run the PPO Method from stable-baselines3 on an incremental RL setting, with the default dataset (CartPole) and 5 tasks: 
    ```console
    $ sequoia run incremental_rl --nb_tasks 5 sb3.ppo --steps_per_task 10_000
    ```

More questions? Please let us know by creating an issue or posting in the discussions!


================================================
FILE: dockers/.gitignore
================================================
# Hiding the 'eai' dockerfile
eai


================================================
FILE: dockers/base/Dockerfile
================================================
# syntax=docker/dockerfile:1
# Base image for Sequoia: system packages, a few conda packages, and an
# editable install of the `master` branch of Sequoia (without MuJoCo).
FROM pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime
USER root
EXPOSE 2222
EXPOSE 6000
EXPOSE 8088
ENV LANG=en_US.UTF-8
# NOTE: `-y` is passed to every apt command that may ask for confirmation, since
# there is no terminal to answer the prompt during `docker build`.
RUN apt update && \
    apt install -y \
    git wget zsh unzip rsync build-essential \
        ca-certificates supervisor openssh-server ssh \
        curl vim procps htop locales nano man net-tools iputils-ping \
        libosmesa6-dev libgl1-mesa-glx libgl1-mesa-dev libglu1-mesa-dev libglfw3 \
        libglfw3-dev freeglut3 xvfb ffmpeg patchelf cmake zlib1g zlib1g-dev \
        swig libopenmpi-dev aptitude screen xz-utils locate && \
    sed -i "s/# en_US.UTF-8/en_US.UTF-8/" /etc/locale.gen && locale-gen && \
    useradd -m -u 13011 -s /bin/zsh toolkit && passwd -d toolkit && \
    useradd -m -u 13011 -s /bin/zsh --non-unique console && passwd -d console && \
    useradd -m -u 13011 -s /bin/zsh --non-unique _toolchain && passwd -d _toolchain && \
    useradd -m -u 13011 -s /bin/bash --non-unique coder && passwd -d coder && \
    chown -R toolkit:toolkit /run /etc/shadow /etc/profile && \
    apt autoremove -y --purge && apt-get clean && \
    rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \
    echo ssh >> /etc/securetty && \
    rm -f /etc/legal /etc/motd

# RUN conda install -c conda-forge opencv
# `-y`: same reason as above (and consistent with dockers/branch/Dockerfile).
RUN conda install -y matplotlib numpy scipy hdf5 h5py cython
# RUN pip install \
#     # Needed to build atari_py: (WHY don't they put it in a build_requires?)
#     lockfile
    # fasteners \
    # pybullet \
    # wandb \
    # tqdm \
    # # tensorflow \
    # bs4 \
    # pandas notebook plotly tqdm pyamg lxml numba pyyaml torchmeta

# Removing this `torchtext` package, seems to be causing an import issue in pytorch!
RUN pip uninstall -y torchtext
RUN chown -R toolkit:root /workspace
RUN chmod -R 777 /workspace
# this doesn't do anything
RUN adduser toolkit sudo
RUN chown -R toolkit:root /mnt/
# RUN mkdir -p /mnt/home
RUN chmod 777 /opt/conda
RUN chmod 777 /mnt
RUN chmod -R 777 /workspace
# From here on, every RUN instruction is executed inside the `base` conda env.
SHELL [ "conda", "run", "-n", "base", "/bin/bash", "-c"]

## Unused zshell and oh-my-zsh stuff:
# RUN sh -c "$(wget -O- https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)"
# RUN sed -i 's/robbyrussell/clean/' ~/.zshrc
# RUN sed -i 's/plugins=(git)/plugins=(git debian history-substring-search)/' ~/.zshrc


# MuJoCo-related stuff:
# RUN curl -o ~/mujoco200_linux.zip -L -C - https://www.roboti.us/download/mujoco200_linux.zip
# RUN curl -o ~/mjpro150_linux.zip -L -C -  https://www.roboti.us/download/mjpro150_linux.zip
# RUN cd ~ && unzip mujoco200_linux.zip && rm mujoco200_linux.zip
# RUN cd ~ && unzip mjpro150_linux.zip && rm mjpro150_linux.zip
# RUN mkdir ~/.mujoco
# RUN mv ~/mujoco200_linux ~/.mujoco/mujoco200
# RUN mv ~/mjpro150 ~/.mujoco
# RUN echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin" >> ~/.bashrc
# RUN echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:~/.mujoco/mjpro150/bin" >> ~/.bashrc
# COPY mjkey.txt /home/toolkit/.mujoco/
# ENV LD_LIBRARY_PATH /home/toolkit/.mujoco/mujoco200/bin:${LD_LIBRARY_PATH}
# ENV LD_LIBRARY_PATH /home/toolkit/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH}
# RUN mkdir /workspace/tools
# RUN cd /workspace/tools && git clone https://github.com/openai/mujoco-py.git && pip install -e mujoco-py

# For Wandb (TODO: Doesn't appear to work, using env variable with WANDB_API_KEY
# instead.)
# COPY .netrc /home/toolkit/.netrc
# COPY .netrc /root/.netrc
# COPY .netrc /tmp/.netrc

VOLUME /mnt/data
VOLUME /mnt/results
# USER toolkit

ENV DATA_DIR=/mnt/data
ENV RESULTS_DIR=/mnt/results
ENV WANDB_DIR=/mnt/results

# VOLUME /mnt/home
# WORKDIR /mnt/home
ENV PATH=/home/toolkit/.local/bin:${PATH}
# RUN cd /workspace/tools && git clone https://github.com/openai/gym.git && cd gym && pip install -e '.[all]'
# RUN cd /workspace/tools && git clone https://github.com/openai/baselines.git && cd baselines && pip install -e .
RUN cd /workspace/ && git clone https://github.com/lebrice/Sequoia.git
# The extras spec is quoted so that bash never treats `[no_mujoco]` as a glob.
RUN pip install -e "/workspace/Sequoia[no_mujoco]"
ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "base", "/bin/bash", "-c"]


================================================
FILE: dockers/base/build.sh
================================================
#!/bin/bash
# Builds the `sequoia:base` image from dockers/base/Dockerfile, then tags it
# for and pushes it to $REGISTRY (defaults to the user reported by `docker info`).
# Must be run from the root of the repository, with a clean working tree.
set -o errexit    # Used to exit upon error, avoiding cascading errors
set -o errtrace    # Show error trace
set -o pipefail   # Unveils hidden failures
set -o nounset    # Exposes unset variables

if git diff-index --quiet HEAD --; then
    # No changes
    echo "All good, no uncommitted changes."
else
    # Changes
    echo "Can't build dockers when there are uncommited changes!"
    exit 1
fi


echo "Building the 'base' dockerfile"
docker build . --file dockers/base/Dockerfile --tag sequoia:base

REGISTRY="${REGISTRY:-$(docker info | sed '/Username:/!d;s/.* //')}"
# Without a registry the tag below would be the invalid reference "/sequoia:base":
# fail early with an actionable message instead.
if [[ -z "$REGISTRY" ]]; then
    echo "Unable to determine the registry: set REGISTRY or run 'docker login' first." >&2
    exit 1
fi
echo "Using registry $REGISTRY"

docker tag sequoia:base "$REGISTRY/sequoia:base"
docker push "$REGISTRY/sequoia:base"


================================================
FILE: dockers/branch/Dockerfile
================================================
# syntax=docker/dockerfile:1
# Image for a given branch of Sequoia, built on top of the `base` image, which
# already contains a clone of the repository in /workspace/Sequoia.
FROM lebrice/sequoia:base
USER root
SHELL [ "conda", "run", "-n", "base", "/bin/bash", "-c"]
ARG BRANCH=master
RUN conda install -y cudatoolkit
# `git checkout` alone does not move a branch that already exists in the clone
# shipped with the base image (e.g. `master`): it would stay at the commit it had
# when the base image was built. When BRANCH names a remote branch, align the
# local branch with it. Tags and commit hashes are left as checked out.
RUN cd /workspace/Sequoia && git fetch -p && git checkout "${BRANCH}" && \
    if git rev-parse --verify --quiet "refs/remotes/origin/${BRANCH}" > /dev/null; then \
        git reset --hard "origin/${BRANCH}"; \
    fi && \
    pip install -e ".[no_mujoco]"
ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "base", "/bin/bash", "-c"]


================================================
FILE: dockers/branch/build.sh
================================================
#!/bin/bash
# Pushes the current branch, then builds (without cache) the image for $BRANCH
# from dockers/branch/Dockerfile and pushes it to $REGISTRY.
# BRANCH defaults to the current git branch, REGISTRY to the user reported by
# `docker info`. Must be run from the root of the repository, with a clean
# working tree.
set -o errexit    # Used to exit upon error, avoiding cascading errors
set -o errtrace    # Show error trace
set -o pipefail   # Unveils hidden failures
set -o nounset    # Exposes unset variables

# Assign first, export after: `export VAR="$(cmd)"` would hide a failure of `cmd`
# from errexit, since the exit status would be the one of `export`.
CURRENT_BRANCH="$(git branch --show-current)"
export CURRENT_BRANCH
export BRANCH="${BRANCH:-$CURRENT_BRANCH}"
# `git branch --show-current` prints nothing on a detached HEAD.
if [[ -z "$BRANCH" ]]; then
    echo "Unable to determine the branch (detached HEAD?): set BRANCH explicitly." >&2
    exit 1
fi
echo "Using branch $BRANCH"

REGISTRY="${REGISTRY:-$(docker info | sed '/Username:/!d;s/.* //')}"
export REGISTRY
if [[ -z "$REGISTRY" ]]; then
    echo "Unable to determine the registry: set REGISTRY or run 'docker login' first." >&2
    exit 1
fi
echo "Using registry $REGISTRY"


if git diff-index --quiet HEAD --; then
    # No changes
    echo "all good."
else
    # Changes
    echo "Can't build dockers when you have uncommited changes!"
    exit 1
fi
# The Dockerfile fetches the branch from the remote, so it has to be pushed first.
git push

echo "Building the container for branch $BRANCH (no cache)"
docker build . --file dockers/branch/Dockerfile \
    --no-cache \
    --build-arg BRANCH="$BRANCH" \
    --tag "sequoia:$BRANCH"

docker tag "sequoia:$BRANCH" "$REGISTRY/sequoia:$BRANCH"
docker push "$REGISTRY/sequoia:$BRANCH"


================================================
FILE: docs/diagrams/src/gym.puml
================================================
@startuml gym
' Class diagram of the parts of the gym API drawn here: the `spaces`
' sub-package (Space, Box, Discrete, Tuple, Dict), `Env` and `Wrapper`.

package gym {
    package spaces as gym.spaces {
        ' Base class of all spaces, generic in the type `T` of its samples.
        abstract class Space<T> {
            + contains(T sample) -> bool
            + sample() -> T
        }
        ' Space whose samples are numpy arrays.
        class Box extends Space {
            + low: np.ndarray
            + high: np.ndarray
            + shape: Tuple[int, ...]
            + dtype: np.dtype
            + contains(np.ndarray sample) -> bool
            + sample() -> np.ndarray
        }

        ' Space whose samples are integers.
        class Discrete extends Space {
            + n: int
            + contains(int sample) -> bool
            + sample() -> int
        }

        ' Space whose samples are tuples.
        class Tuple extends Space {
            + spaces: Tuple[Space]
            + contains(Tuple sample) -> bool
            + sample() -> Tuple
        }
        ' Tuple spaces contain other spaces.
        Tuple *--  Space

        ' Space whose samples are dicts.
        class Dict extends Space {
            + spaces: dict[str, Space]
            + contains(dict sample) -> bool
            + sample() -> dict
        }
        ' Same for Dicts.
        Dict *--  Space
    }

    ' Environment, generic in its observation, action and reward types.
    abstract class gym.Env<Obs, Act, Rew> {
        + observation_space: Space<Obs>
        + action_space: Space<Act> 
        + step(Actions) -> Tuple[Obs, Rew, bool, dict]
        + reset() -> Obs
    }
    ' An Env describes its observations and actions with Spaces.
    gym.Env .. Space

    ' A Wrapper is itself an Env, and holds the Env it wraps.
    abstract class Wrapper extends gym.Env{
        + env: gym.Env
    }
}

@enduml

================================================
FILE: docs/diagrams/src/pytorch_lightning.puml
================================================
@startuml pytorch_lightning
' Simplified view of the two PyTorch-Lightning base classes drawn in the docs.
' Method names follow the actual Lightning hooks (`training_step`,
' `validation_step`, `test_step`); signatures are abbreviated.
package pytorch_lightning {
    abstract class LightningDataModule {
        {abstract} + prepare_data()
        {abstract} + setup()
        {abstract} + train_dataloader(): torch.utils.data.DataLoader
        {abstract} + val_dataloader(): torch.utils.data.DataLoader
        {abstract} + test_dataloader(): torch.utils.data.DataLoader
    }
    abstract class LightningModule {
        {abstract} + training_step(batch)
        + validation_step()
        + test_step()
    }
}
@enduml

================================================
FILE: docs/diagrams/src/seq_diagram.puml
================================================
@startuml ContinualRLSetting
header Page Header
footer Page %page% of %lastpage%
title Overall Evaluation loop - Sequoia
note over User, Setting
Even though this diagram is somewhat large,
keep in mind that there are but a few key methods:
1. Method.configure()
2. Method.fit()
3. Method.get_actions()
4. Method.on_task_switch()  
end note

actor User
participant Setting << (A,#2121FF) Setting >>
collections TrainEnv
collections ValidEnv
collections TestEnv
' autoactivate on
participant Method << (C,#ADD1B2) Method >>
participant Model << (C,#ADD1B2) nn.Module >>
' activate Setting
' autoactivate on



User -> Setting: Create the Setting
Setting -> TrainEnv: Create temp env
return observation / action / reward spaces
User <-- Setting


User -> Method: Create the Method
User <-- Method


User -> Setting: setting.apply(method)

Setting -> Method: **method.configure(setting)**

    Method -> Method: create model, optimizer, etc.
    ' deactivate Method

    Method -> Model: Create
    ' activate Model
Setting <-- Method

autoactivate off

== training ==


group train_loop [for each task `i`]
    alt task_labels_at_train_time?
    else True
        Setting -> Method: **on_task_switch(i)**
        Method -> Method: consolidate knowledge, \n switch output heads, etc.
        Setting <-- Method
    else False 
        Setting -> Method: **on_task_switch(None)**
        Method -> Method: consolidate knowledge etc.
        Setting <-- Method

    end

    Setting -> TrainEnv: Create train env for task i
    Setting -> ValidEnv: Create valid env for task i
    ' activate ValidEnv
    Setting -> Method: **Method.fit(train_env, valid_env)**
    ' loop
    
    ' alt loop
    group loop
        note right
        The Method is free to do whatever
        it wants with the Train and Valid envs
        of the current task.
        end note
        Method -> Model: train()
        return

        ' group training
        Model <--> TrainEnv: train with the env
        ...

        Method -> Model: eval()
        return
        Model <--> ValidEnv: Evaluate performance
        ...
        ' autoactivate on
        ' Model -> TrainEnv: reset
        ' return Observations
        ' Model -> TrainEnv: step(actions)
        ' return Observations, Rewards, done, info
    end

end


== testing ==

note over Setting, Method
We currently only perform the test loop after training is complete on all tasks,
however, in the future we will run this test loop after the end of training on
each task. See issue#46 on GitHub for more info.
end note

group test_loop
    Setting --> Setting: Concatenate datasets for all tasks, \n create test wrappers, etc.
    Setting --> TestEnv: Create test environment (all tasks)
    autoactivate on
    Setting -> TestEnv: reset
    return observations
    ' loop
        alt
        else normal step

            Setting -> Method: **get_actions(observations)**
            Method -> Model: predict(x)
            return y_pred
            return actions
            Setting -> TestEnv: step(actions)
            return observations, rewards, done, info

        else end of episode reached
            Setting -> TestEnv: reset
            return observations

        else task boundary is reached
            ' TestEnv --> Method: **on_task_switch(i)**
            
            alt known_task_boundaries?
            else False: do nothing
                note over Method
                When known_task_boundaries=False, the Method doesn't get informed
                of task boundaries (it might have to perform some kind of change-point
                detection, for instance).
                end note
            else True
                note over TestEnv
                Minor note: here it's the TestEnv
                that calls the Method when a
                task boundary is reached.
                end note

                alt task_labels_at_test_time?
                else true
                    ' note right of Setting: If task labels are given
                    TestEnv -> Method: **on_task_switch(i)**
                    autoactivate off
                    Method -> Method
                    autoactivate on
                    return

                else false 
                    TestEnv -> Method: **on_task_switch(None)**
                    autoactivate off
                    Method -> Method
                    autoactivate on
                    return
                end
            end
        end
    autoactivate off
    note over TestEnv
    The test environment uses a `Monitor` wrapper, which gathers
    statistics of interest like the mean reward, accuracy, etc.
    end note
    TestEnv -> Setting: report performance of the Method
end
Setting -> Setting: Weigh performance of each task \n depending on the Setting
User <-- Setting: Results
' return Results
@enduml

================================================
FILE: examples/README.md
================================================
# Examples

Here's a brief description of the examples in this folder:

## Prerequisites:
- [Intro to dataclasses & simple-parsing](prerequisites/dataclasses_example.py)
- [Basics of openai gym](https://github.com/openai/gym#basics)


## Basic examples:

- [pl_example.py](basic/pl_example.py):
    **Recommended entry-point for ML Practitioners**. Shows an example method and model
    using [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning).
    This is the best way to get started if you don't mind some level of abstraction in your code
    (a good thing in general!)


- [quick_demo.ipynb](basic/quick_demo.ipynb):
    **Recommended entry-point for new users**. Simple demo showing how to create a `Method`
    from scratch that targets a Supervised CL `Setting`, as well as how to
    improve this simple Method using a simple regularization loss.

    - [quick_demo.py](basic/quick_demo.py): First part of the above
        notebook: shows how to create a Method from scratch that
        targets a Supervised CL Setting.
    - [quick_demo_ewc.py](basic/quick_demo_ewc.py): Second part of the
        above notebook: shows how to improve upon an existing Method by adding a
        CL regularization loss.

- [base_method_demo.py](basic/base_method_demo.py): Shows how the
    BaseMethod can be applied to get results in both RL and SL Settings.


## CLVision Workshop Submission Examples:

Examples in this folder are aimed at solving the supervised learning track of the competition.

Each example builds on top of the previous, in a manner that improves the overall performance you can expect on any given CL setting.

As such, it is recommended that you take a look at the examples in the following order:

0. [DummyMethod](clcomp21/dummy_method.py)
    Non-parametric method that simply returns a random prediction for each observation.

1. [Simple Classifier](clcomp21/classifier.py):
    Standard neural net classifier without any CL-related mechanism. Works in the SL track, but has very poor performance.

2. [Multi-Head / Task Inference Classifier](clcomp21/multihead_classifier.py):
    Performs multi-head prediction, and a simple form of task inference. Gets better results than the previous example.

3. [CL Regularized Classifier](clcomp21/regularization_example.py):
    Adds a simple CL regularization loss to the multihead classifier above.


## Advanced examples:

- [RL_and_SL_demo.py](advanced/RL_and_SL_demo.py):
    
    Example that shows how the BaseMethod can easily be extended by adding
    AuxiliaryTasks to it, which allows you to get results in both RL and SL.

- [continual_rl_demo.py](advanced/continual_rl_demo.py):
    
    Demonstrates how to create Reinforcement Learning (RL) Settings, as well as
    how methods from [stable-baselines3](https://github.com/DLR-RM/stable-baselines3)
    can be applied to these settings.


- [Extending Stable-Baselines3 (RL Settings only)](advanced/ewc_in_rl.py):

    (Not recommended for new users!)
    Very specific example which shows how, if you really wanted to, you could
    extend one or more of the Methods from SB3 with some kind of regularization
    loss hooking into the internal optimization loop of SB3.


================================================
FILE: examples/__init__.py
================================================


================================================
FILE: examples/advanced/RL_and_SL_demo.py
================================================
""" Demo where we add the same regularization loss from the other examples, but
this time as an `AuxiliaryTask` on top of the BaseMethod.

This makes it easy to create CL methods that apply to both RL and SL Settings!
"""

import copy
import random
import sys
from argparse import Namespace
from dataclasses import dataclass
from typing import ClassVar, Dict, List

import torch
from simple_parsing import ArgumentParser, field
from torch import Tensor

# This "hack" is required so we can run `python examples/advanced/RL_and_SL_demo.py`
sys.path.extend([".", ".."])

from sequoia.common.config import Config
from sequoia.common.loss import Loss
from sequoia.methods import BaseMethod
from sequoia.methods.aux_tasks import AuxiliaryTask
from sequoia.methods.models import BaseModel, ForwardPass
from sequoia.methods.trainer import TrainerConfig
from sequoia.settings import Environment, RLSetting, Setting
from sequoia.utils.utils import camel_case, dict_intersection
from sequoia.utils.logging_utils import get_logger

logger = get_logger(__name__)


class SimpleRegularizationAuxTask(AuxiliaryTask):
    """Same regularization loss as in the previous examples, this time
    implemented as an `AuxiliaryTask`, which gets added to the BaseModel,
    making it applicable to both RL and SL.

    This adds a CL regularization loss to the BaseModel.

    The most important method of `AuxiliaryTask` is `get_loss`, which should
    return a `Loss` for the given forward pass and resulting rewards/labels.
    Take a look at the `AuxiliaryTask` class for more info.
    """

    name: ClassVar[str] = "simple_regularization"

    @dataclass
    class Options(AuxiliaryTask.Options):
        """Hyper-parameters / configuration options of this auxiliary task."""

        # Coefficient used to scale this regularization loss before it gets
        # added to the 'base' loss of the model.
        coefficient: float = 0.01
        # Wether to use the absolute difference of the weights or the difference
        # in the `regularize` method below.
        use_abs_diff: bool = False
        # The norm term for the 'distance' between the current and old weights.
        distance_norm: int = 2

    def __init__(
        self,
        *args,
        name: str = None,
        options: "SimpleRegularizationAuxTask.Options" = None,
        **kwargs,
    ):
        """Creates the auxiliary task; all arguments are forwarded to `AuxiliaryTask`."""
        super().__init__(*args, options=options, name=name, **kwargs)
        self.options: SimpleRegularizationAuxTask.Options
        # Id of the task whose starting weights are used as the anchor, if known.
        self.previous_task: int = None
        # TODO: Figure out a clean way to persist this dict into the state_dict.
        # 'Anchor' weights: copy of the model parameters taken at the last task switch.
        self.previous_model_weights: Dict[str, Tensor] = {}
        # Number of times `on_task_switch` was called while this task was enabled.
        self.n_switches: int = 0

    def get_loss(self, forward_pass: ForwardPass, y: Tensor = None) -> Loss:
        """Get a `Loss` for the given forward pass and resulting rewards/labels.

        Take a look at the `AuxiliaryTask` class for more info,

        NOTE: This is the same simplified version of EWC used throughout the
        other examples: the loss is the P-norm between the current weights and
        the weights as they were on the beginning of the task.
        Also note, this particular example doesn't actually use the provided
        arguments, nor the `use_abs_diff` option.

        Returns an empty `Loss` as long as no anchor weights have been stored,
        i.e. during the first task.
        """
        # Check for stored anchor weights rather than `previous_task is None`:
        # when task labels aren't available, `on_task_switch(None)` stores the
        # anchor weights but leaves `previous_task` at None, and the
        # regularization would otherwise never be applied.
        if not self.previous_model_weights:
            # We're in the first task: do nothing.
            return Loss(name=self.name)

        old_weights: Dict[str, Tensor] = self.previous_model_weights
        new_weights: Dict[str, Tensor] = dict(self.model.named_parameters())

        loss = 0.0
        for weight_name, (new_w, old_w) in dict_intersection(new_weights, old_weights):
            loss += torch.dist(new_w, old_w.type_as(new_w), p=self.options.distance_norm)

        ewc_loss = Loss(name=self.name, loss=loss)
        return ewc_loss

    def on_task_switch(self, task_id: int) -> None:
        """Executed when the task switches (to either a new or known task).

        Does nothing when this auxiliary task is disabled. On the very first
        switch nothing is stored. On later switches, the current model weights
        become the new 'anchor' weights whenever the task id is unknown (None)
        or differs from the previous one.
        """
        if not self.enabled:
            return
        if self.previous_task is None and self.n_switches == 0:
            logger.debug("Starting the first task, no update.")
        elif task_id is None or task_id != self.previous_task:
            logger.debug(
                f"Switching tasks: {self.previous_task} -> {task_id}: "
                f"Updating the 'anchor' weights."
            )
            self.previous_task = task_id
            self.previous_model_weights.clear()
            # Detach + deepcopy, so the anchors don't follow the live parameters.
            self.previous_model_weights.update(
                copy.deepcopy({k: v.detach() for k, v in self.model.named_parameters()})
            )
        self.n_switches += 1


class CustomizedBaselineModel(BaseModel):
    """`BaseModel` to which the `SimpleRegularizationAuxTask` is added."""

    @dataclass
    class HParams(BaseModel.HParams):
        """Hyper-parameters of our customized baseline model."""

        # Hyper-parameters of our simple new auxiliary task.
        simple_reg: SimpleRegularizationAuxTask.Options = field(
            default_factory=SimpleRegularizationAuxTask.Options
        )

    def __init__(
        self,
        setting: Setting,
        hparams: "CustomizedBaselineModel.HParams",
        config: Config,
    ):
        """Builds the base model, then registers the regularization task on it."""
        super().__init__(setting=setting, hparams=hparams, config=config)
        self.hp: CustomizedBaselineModel.HParams

        # Our new auxiliary task, configured from the `simple_reg` hyper-parameters.
        regularization_task = SimpleRegularizationAuxTask(options=self.hp.simple_reg)
        self.add_auxiliary_task(regularization_task)

        # Other additions could go here too, e.g. a replay buffer of some sort:
        self.replay_buffer: List = []

        # (...)


@dataclass
class CustomMethod(BaseMethod, target_setting=Setting):
    """Example method which adds regularization to the baseline in RL and SL.

    This extends the `BaseMethod` by adding the simple regularization
    auxiliary task defined above to the `BaseModel`.

    NOTE: Since this class inherits from `BaseMethod`, which targets the
    `Setting` setting, i.e. the "root" node, it is applicable to all settings,
    both in RL and SL. However, you could customize the `target_setting`
    argument above to limit this to any particular subtree (only SL, only RL,
    only when task labels are present, etc).
    """

    # Hyper-parameters of the customized Baseline Model used by this method.
    hparams: CustomizedBaselineModel.HParams = field(
        default_factory=CustomizedBaselineModel.HParams
    )

    def __init__(
        self,
        hparams: CustomizedBaselineModel.HParams = None,
        config: Config = None,
        trainer_options: TrainerConfig = None,
        **kwargs,
    ):
        """Creates the Method; all arguments are forwarded to `BaseMethod`."""
        super().__init__(
            hparams=hparams,
            config=config,
            trainer_options=trainer_options,
            **kwargs,
        )

    def create_model(self, setting: Setting) -> CustomizedBaselineModel:
        """Creates the Model to be used for the given `Setting`."""
        return CustomizedBaselineModel(setting=setting, hparams=self.hparams, config=self.config)

    def configure(self, setting: Setting):
        """Configure this Method before being trained / tested on this Setting."""
        super().configure(setting)

        # For example, change the value of the coefficient of our
        # regularization loss when in RL vs SL:
        if isinstance(setting, RLSetting):
            self.hparams.simple_reg.coefficient = 0.01
        else:
            self.hparams.simple_reg.coefficient = 1.0

    def fit(self, train_env: Environment, valid_env: Environment):
        """Called by the Setting to let the Method train on a given task.

        You can do whatever you want with the train and valid
        environments. As it is currently, in most `Settings`, the valid
        environment will contain data from only the current task. (See issue at
        https://github.com/lebrice/Sequoia/issues/46 for more context).
        """
        return super().fit(train_env=train_env, valid_env=valid_env)

    @classmethod
    def add_argparse_args(cls, parser: ArgumentParser, dest: str = None):
        """Adds command-line arguments for this Method to an argument parser.

        Parameters
        ----------
        parser : ArgumentParser
            The parser to add the arguments to.
        dest : str, optional
            Attribute of the parsed namespace under which the Method is stored.
            Defaults to `camel_case(cls.__qualname__)`. Accepted so that callers
            which pass `dest=...` (as `demo_command_line` does) don't get a
            `TypeError`.
        """
        # 'dest' is where the arguments will be stored on the namespace.
        dest = dest or camel_case(cls.__qualname__)
        # Add all command-line arguments. This adds arguments for all fields of
        # this dataclass.
        parser.add_arguments(cls, dest=dest)
        # You could add arguments here if you wanted to:
        # parser.add_argument("--foo", default=1.23, help="example argument")

    @classmethod
    def from_argparse_args(cls, args: Namespace, dest: str = None):
        """Create an instance of this class from the parsed arguments.

        `dest` must be the same value that was given to `add_argparse_args`
        (it has the same default).
        """
        # Retrieve the parsed arguments:
        dest = dest or camel_case(cls.__qualname__)
        method: CustomMethod = getattr(args, dest)
        # You could retrieve other arguments like so:
        # foo: int = args.foo
        return method


def demo_manual():
    """Runs the BaseMethod, then the CustomMethod, on a Setting created in code.

    Both methods are applied to the same setting, and the summaries of their
    results are printed one after the other.
    """
    # Any Setting from the tree could be used here:
    from sequoia.settings import TaskIncrementalRLSetting, TaskIncrementalSLSetting

    # setting = TaskIncrementalSLSetting(dataset="mnist", nb_tasks=5)  # SL
    cartpole_schedule = {
        0: {"gravity": 10, "length": 0.5},
        5000: {"gravity": 10, "length": 1.0},
    }
    setting = TaskIncrementalRLSetting(  # RL
        dataset="cartpole",
        train_task_schedule=cartpole_schedule,
        train_max_steps=10_000,
    )

    ## The baseline: BaseMethod with the default hyper-parameters.
    config = Config(debug=True)
    trainer_options = TrainerConfig(max_epochs=1)
    hparams = BaseModel.HParams()
    base_method = BaseMethod(hparams=hparams, config=config, trainer_options=trainer_options)
    base_results = setting.apply(base_method, config=config)

    ## The 'improved' method: CustomMethod, with its own fresh options.
    config = Config(debug=True)
    trainer_options = TrainerConfig(max_epochs=1)
    hparams = CustomizedBaselineModel.HParams()
    new_method = CustomMethod(hparams=hparams, config=config, trainer_options=trainer_options)
    new_results = setting.apply(new_method, config=config)

    print("\n\nComparison: BaseMethod vs CustomMethod")
    for method_name, method_results in (
        ("BaseMethod", base_results),
        ("CustomMethod", new_results),
    ):
        print(f"\n {method_name} results: ")
        print(method_results.summary())


def demo_command_line():
    """Run the same demo as above, but customizing the Setting and Method from
    the command-line.

    NOTE: Remember to uncomment the function call below to use this instead of
    demo_manual!

    NOTE(review): this assumes that `BaseMethod.add_argparse_args(parser)` and
    `BaseMethod.from_argparse_args(args)` can be called like the `CustomMethod`
    overrides, whose docstring says they do the same as the base implementation
    — confirm against `BaseMethod`.
    """
    ## Create the `Setting` and the `Config` from the command-line, like in
    ## the other examples.
    parser = ArgumentParser(description=__doc__)

    ## Add command-line arguments for any Setting in the tree:
    from sequoia.settings import TaskIncrementalRLSetting, TaskIncrementalSLSetting

    # parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    parser.add_arguments(TaskIncrementalRLSetting, dest="setting")
    parser.add_arguments(Config, dest="config")

    # Add the command-line arguments of both methods, each under its own default
    # destination. (`CustomMethod.add_argparse_args` is declared without a `dest`
    # argument, and sharing one destination would make both `from_argparse_args`
    # calls below retrieve the very same object.)
    BaseMethod.add_argparse_args(parser)
    # This includes the arguments for our simple regularization aux task.
    CustomMethod.add_argparse_args(parser)

    args = parser.parse_args()

    setting: TaskIncrementalRLSetting = args.setting
    config: Config = args.config

    # Create the BaseMethod:
    base_method = BaseMethod.from_argparse_args(args)
    # Get the results of the BaseMethod:
    base_results = setting.apply(base_method, config=config)

    ## Create the CustomMethod:
    new_method = CustomMethod.from_argparse_args(args)
    # Get the results for the CustomMethod:
    new_results = setting.apply(new_method, config=config)

    print("\n\nComparison: BaseMethod vs CustomMethod:")
    print(base_results.summary())
    print(new_results.summary())


if __name__ == "__main__":
    demo_manual()
    # demo_command_line()


================================================
FILE: examples/advanced/continual_rl_demo.py
================================================
import sys

# This "hack" is required so we can run `python examples/advanced/continual_rl_demo.py`
sys.path.extend([".", ".."])
from sequoia.methods.stable_baselines3_methods import A2CMethod, DQNMethod
from sequoia.settings import (
    ContinualRLSetting,
    IncrementalRLSetting,
    RLSetting,
    TaskIncrementalRLSetting,
)

if __name__ == "__main__":
    # Environment attributes to apply over the course of training.
    # NOTE(review): keys look like training steps at which the attributes
    # change -- confirm against the `ContinualRLSetting` documentation.
    task_schedule = {
        0: {"gravity": 10, "length": 0.2},
        1000: {"gravity": 100, "length": 1.2},
        2000: {"gravity": 10, "length": 0.2},
    }

    # Setting class to use. The other classes imported above can be tried here
    # as well: IncrementalRLSetting, TaskIncrementalRLSetting, RLSetting.
    setting_type = ContinualRLSetting
    setting = setting_type(
        dataset="CartPole-v1",
        train_max_steps=2000,
        train_task_schedule=task_schedule,
    )

    # Method to apply on the setting.
    # NOTE: The DQN method doesn't seem to work nearly as well as A2C.
    # method = DQNMethod(train_steps_per_task=1_000)
    method = A2CMethod(train_steps_per_task=1_000)
    # The hyper-parameters of the method can also be changed, for instance:
    # method.hparams.buffer_size = 100

    results = setting.apply(method)
    print(results.summary())


================================================
FILE: examples/advanced/ewc_in_rl.py
================================================
""" Example of how to add a simplified regularization method to algos from
stable-baseline-3.
"""
from collections import deque
from copy import deepcopy
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional, Type, TypeVar, Union

import gym
import torch
from nngeometry.generator.jacobian import Jacobian
from nngeometry.layercollection import LayerCollection
from nngeometry.object.pspace import PMatAbstract, PMatDiag, PMatKFAC, PVector
from simple_parsing import choice
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.policies import BasePolicy
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

from sequoia.methods import register_method
from sequoia.methods.stable_baselines3_methods import StableBaselines3Method
from sequoia.methods.stable_baselines3_methods.policy_wrapper import PolicyWrapper
from sequoia.settings import TaskIncrementalRLSetting
from sequoia.settings.base import Actions, Environment, Method, Observations
from sequoia.utils.utils import dict_intersection
from sequoia.utils.logging_utils import get_logger

logger = get_logger(__name__)

Policy = TypeVar("Policy", bound=BasePolicy)


class NormRegularizer(PolicyWrapper[Policy]):
    """Policy wrapper that adds `on_task_switch` and `ewc_loss` to an nn.Module
    (here, a Policy from SB3).

    The regularization loss is the p-norm distance between the current
    parameters and a snapshot of the parameters taken at the last task switch.

    Subclassing `PolicyWrapper` gives access to some 'hooks' around the
    optimizer of the policy (see `after_zero_grad` / `before_optimizer_step`).
    """

    def __init__(self: Policy, *args, reg_coefficient: float = 1.0, ewc_p_norm: int = 2, **kwargs):
        """Forward `args`/`kwargs` to the wrapped class and set up the state.

        Args:
            reg_coefficient: Multiplier applied to the loss in `get_loss`.
            ewc_p_norm: Order `p` of the norm used in `ewc_loss`.
        """
        super().__init__(*args, **kwargs)
        self.reg_coefficient = reg_coefficient
        self.ewc_p_norm = ewc_p_norm

        # Snapshot of the parameters (name -> detached copy) at the last switch.
        self.previous_model_weights: Dict[str, Tensor] = {}

        # Task id stored at the last anchor update; None until the first update.
        self._previous_task: Optional[int] = None
        # Number of times `on_task_switch` was called.
        self._n_switches: int = 0

    def on_task_switch(self: Policy, task_id: Optional[int], *args, **kwargs) -> None:
        """Executed when the task switches (to either a known or unknown task).

        The very first call (with a falsy `task_id`) only logs a message. Later
        calls refresh the anchor weights whenever `task_id` is None or differs
        from the task stored at the previous update.
        """
        logger.info(f"On task switch called: task_id={task_id}")

        is_first_call = self._previous_task is None and self._n_switches == 0
        if is_first_call and not task_id:
            logger.info("Starting the first task, no EWC update.")
        elif task_id is None or task_id != self._previous_task:
            # NOTE: Switches between unknown tasks (task_id=None) also land here.
            logger.info(
                f"Switching tasks: {self._previous_task} -> {task_id}: "
                f"Updating the EWC 'anchor' weights."
            )
            self._previous_task = task_id
            self.previous_model_weights.clear()
            anchors = {name: param.detach() for name, param in self.named_parameters()}
            # Deep copy, so that later optimizer steps can't modify the anchors.
            self.previous_model_weights.update(deepcopy(anchors))
        self._n_switches += 1

    def get_loss(self: Policy) -> Union[float, Tensor]:
        """Return the extra loss of this wrapper: `reg_coefficient * ewc_loss()`.

        Meant to be used from within the `train` method of the algos from
        stable-baselines3, before the call to `policy.optimizer.step()`.
        """
        return self.reg_coefficient * self.ewc_loss()

    def after_zero_grad(self: Policy):
        """Called after `self.policy.optimizer.zero_grad()` in the training
        loop of the SB3 algos.

        The wrapper loss is backpropagated here, so that any gradient clipping
        done afterwards also applies to the gradients of this loss.
        """
        wrapper_loss = self.get_loss()
        if not isinstance(wrapper_loss, Tensor):
            # Plain float (e.g. 0.0 during the first task): nothing to backprop.
            return
        if wrapper_loss != 0.0 and wrapper_loss.requires_grad:
            logger.info(f"{type(self).__name__} loss: {wrapper_loss.item()}")
            wrapper_loss.backward(retain_graph=True)

    def before_optimizer_step(self: Policy):
        """Called before `self.policy.optimizer.step()` in the training
        loop of the SB3 algos.

        Intentionally does nothing in this wrapper.
        """

    def ewc_loss(self: Policy) -> Union[float, Tensor]:
        """Gets an 'ewc-like' regularization loss.

        NOTE: This is a simplified version of EWC: the loss is the sum, over the
        parameters present in both dicts, of the p-norm distance between the
        current weights and the anchor weights saved in `on_task_switch`.

        Returns:
            0.0 while no anchor update happened yet, otherwise the summed distance.
        """
        if self._previous_task is None:
            # We're in the first task: do nothing.
            return 0.0

        current: Dict[str, Tensor] = dict(self.named_parameters())
        anchors: Dict[str, Tensor] = self.previous_model_weights

        total = 0.0
        for _name, (current_w, anchor_w) in dict_intersection(current, anchors):
            total += torch.dist(current_w, anchor_w.type_as(current_w), p=self.ewc_p_norm)
        return total


class EWCPolicy(NormRegularizer):
    """A Wrapper class that adds a `on_task_switch` and a `ewc_loss` method to
    an nn.Module (in this particular case, a Policy from SB3) and implements the EWC method.

    On each task switch a Fisher Information Matrix (FIM) is computed and merged
    into the stored ones (see `consolidate`). `ewc_loss` then sums
    `fim.vTMv(current_weights - anchor_weights)` over the stored FIMs.
    """

    def __init__(
        self: Policy,
        *args,
        reg_coefficient: float = 1.0,
        ewc_p_norm: int = 2,
        fim_representation: Type[PMatAbstract] = PMatDiag,
        **kwargs,
    ):
        """Set up the wrapper.

        Args:
            reg_coefficient: Multiplier of the EWC loss (see `get_loss`).
            ewc_p_norm: Forwarded to `NormRegularizer`; unused by `ewc_loss` here.
            fim_representation: nngeometry matrix class used to represent the FIMs.
        """
        # `reg_coefficient` and `ewc_p_norm` are keyword-only in
        # `NormRegularizer.__init__`: passing them positionally would silently
        # send them into `*args` and leave the coefficient at its default.
        super().__init__(
            *args, reg_coefficient=reg_coefficient, ewc_p_norm=ewc_p_norm, **kwargs
        )
        # One FIM per regularized function (policy, plus critic for a2c).
        self.FIMs: Optional[List[PMatAbstract]] = None
        # Flattened anchor weights, saved at the last consolidation.
        self.previous_model_weights: Optional[PVector] = None
        self.FIM_representation = fim_representation

    def consolidate(self, new_fims: List[PMatAbstract], task: int) -> None:
        """Consolidates the previous FIMs and the new ones.

        See online EWC in https://arxiv.org/pdf/1805.06370.pdf.

        Args:
            new_fims: FIMs computed for the task that just ended; must have the
                same length as the stored list (if any).
            task: Weight given to the stored FIM in the running average
                `(new + previous * task) / (task + 1)`.
        """
        if self.FIMs is None:
            self.FIMs = new_fims
            return
        assert len(new_fims) == len(self.FIMs)
        for i, (fim_previous, fim_new) in enumerate(zip(self.FIMs, new_fims)):
            if fim_previous is None:
                self.FIMs[i] = fim_new
            else:
                self.FIMs[i] = EWCPolicy._consolidate_fims(fim_previous, fim_new, task)

    @staticmethod
    def _consolidate_fims(
        fim_previous: PMatAbstract, fim_new: PMatAbstract, task: int
    ) -> PMatAbstract:
        """Merge `fim_new` into `fim_previous` in place and return `fim_previous`.

        Diagonal FIMs hold a single tensor in `.data`; otherwise `.data` is
        expected to be a dict whose values are iterables of tensors. Any other
        representation is returned unchanged.
        """
        if isinstance(fim_new, PMatDiag):
            fim_previous.data = ((deepcopy(fim_new.data)) + fim_previous.data * (task)) / (task + 1)

        elif isinstance(fim_new.data, dict):
            for blocks_previous, blocks_new in zip(
                fim_previous.data.values(), fim_new.data.values()
            ):
                for item, item_ in zip(blocks_previous, blocks_new):
                    item.data = ((item.data * (task)) + deepcopy(item_.data)) / (task + 1)
        return fim_previous

    def on_task_switch(
        self: Policy, task_id: Optional[int], dataloader: DataLoader, method: str = "a2c"
    ) -> None:
        """Executed when the task switches (to either a known or unknown task).

        Args:
            task_id: Index of the new task, or None when task labels are unavailable.
            dataloader: Batches of observations used to estimate the FIM(s).
            method: Which kind of RL algorithm is wrapped ("a2c" or "dqn"); selects
                the function whose FIM is computed.
        """
        logger.info(f"On task switch called: task_id={task_id}")
        if self._previous_task is None and self._n_switches == 0 and not task_id:
            self._previous_task = task_id
            logger.info("Starting the first task, no EWC update.")
            self._n_switches += 1
        elif task_id is None or self._previous_task is None or task_id > self._previous_task:
            # We don't want to go here at test time (when going back to an
            # earlier task id).
            # NOTE: We also switch between unknown tasks.

            # Imported locally: the file header does not import `FIM`.
            # NOTE(review): the "a2c" / "dqn" variants passed below may require a
            # version of nngeometry which supports them -- confirm.
            from nngeometry.metrics import FIM

            logger.info(
                f"Switching tasks: {self._previous_task} -> {task_id}: "
                f"Updating the EWC 'anchor' weights."
            )
            self._previous_task = task_id
            self.previous_model_weights = PVector.from_model(self).clone().detach()

            # TODO: keeping two FIMs might not be the optimal way of doing this.
            if method == "dqn":
                function = self.q_net
                n_output = self.action_space.n
            else:
                function = self
                n_output = 1
            new_fims = [
                FIM(
                    model=self,
                    loader=dataloader,
                    representation=self.FIM_representation,
                    n_output=n_output,
                    variant=method,
                    function=function,
                    device=self.device.type,
                )
            ]
            if method == "a2c":
                # Apply EWC to the value net as well.
                new_fims.append(
                    FIM(
                        model=self,
                        loader=dataloader,
                        representation=self.FIM_representation,
                        n_output=1,
                        variant="regression",
                        function=lambda *x: self(x[0])[1],
                        device=self.device.type,
                    )
                )
            # Without task labels there is no task index to weight the running
            # average with; the number of switches seen so far is used instead
            # (it matches the task index when tasks are visited in order).
            consolidation_index = task_id if task_id is not None else self._n_switches
            self.consolidate(new_fims, task=consolidation_index)
            self._n_switches += 1

    def ewc_loss(self: Policy) -> Union[float, Tensor]:
        """Gets the EWC regularization loss.

        Returns:
            0.0 while there is nothing to regularize against (first task, zero
            coefficient, or no FIM yet), otherwise the sum over the stored FIMs
            of `fim.vTMv(current_weights - anchor_weights)`.
        """
        regularizer = 0.0
        if self._previous_task is None or self.reg_coefficient == 0 or self.FIMs is None:
            # We're in the first task: do nothing.
            return regularizer
        v_current = PVector.from_model(self)
        for fim in self.FIMs:
            regularizer += fim.vTMv(v_current - self.previous_model_weights)
        return regularizer


from sequoia.methods.stable_baselines3_methods import (
    A2CModel,
    DDPGModel,
    DQNModel,
    PPOModel,
    SACModel,
    TD3Model,
)


@register_method
@dataclass
class ExampleRegularizationMethod(StableBaselines3Method):
    """SB3-based Method whose policy is wrapped with a `NormRegularizer`."""

    # You could use any of these 'backbones' from SB3:
    Model: ClassVar[Type[BaseAlgorithm]] = A2CModel  # Works great! (fastest)
    # Model = PPOModel  # Works great! (somewhat fast)
    # Model = SACModel  # Works (seems to be quite a bit slower).

    # These don't yet work, they have the same error, which seems to be
    # related to the action space being Discrete:
    #     stable_baselines3/td3/td3.py", line 143, in train
    #     noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
    # RuntimeError: "normal_kernel_cuda" not implemented for 'Long'
    # Model = TD3Model  # TODO
    # Model = DDPGModel  # TODO
    # Model = DQNModel  # Doesn't work: predictions have more than one value?!

    # Coefficient for the EWC-like loss.
    reg_coefficient: float = 1.0
    # norm of the 'distance' used in the ewc-like loss above.
    ewc_p_norm: int = 2

    def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> BaseAlgorithm:
        """Create the SB3 model as usual, then wrap its policy with the regularizer."""
        base_model = super().create_model(train_env, valid_env)
        return NormRegularizer.wrap_algorithm(
            base_model,
            reg_coefficient=self.reg_coefficient,
            ewc_p_norm=self.ewc_p_norm,
        )

    def on_task_switch(self, task_id: Optional[int]) -> None:
        """Called when switching tasks in a CL setting.

        If task labels are available, `task_id` will correspond to the index of
        the new task. Otherwise, if task labels aren't available, `task_id` will
        be `None`.

        The switch is forwarded to the wrapped policy, once a model exists.

        todo: use this to customize how your method handles task transitions.
        """
        if not self.model:
            return
        self.model.policy.on_task_switch(task_id)


@register_method
@dataclass
class EWCExampleMethod(StableBaselines3Method):
    """SB3-based Method whose policy is wrapped with an `EWCPolicy`."""

    # Model = A2CModel  # Works great! (fastest)
    Model: ClassVar[Type[BaseAlgorithm]] = DQNModel  # Works great! (fastest)
    # Coefficient for the EWC-like loss.
    reg_coefficient: float = 1.0
    # Number of observations to use for FIM calculation
    total_steps_fim: int = 1000
    # Fisher information type  (diagonal or block diagobnal)
    fim_representation: PMatAbstract = choice(
        {"diagonal": PMatDiag, "block_diagonal": PMatKFAC}, default=PMatKFAC
    )

    def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> BaseAlgorithm:
        """Create the SB3 model as usual, then wrap its policy with the EWC wrapper."""
        base_model = super().create_model(train_env, valid_env)
        return EWCPolicy.wrap_algorithm(
            base_model,
            reg_coefficient=self.reg_coefficient,
            fim_representation=self.fim_representation,
        )

    def on_task_switch(self, task_id: Optional[int]) -> None:
        """Called when switching tasks in a CL setting.

        If task labels are available, `task_id` will correspond to the index of
        the new task. Otherwise, if task labels aren't available, `task_id` will
        be `None`.

        Once a model exists, observations are collected by acting in the
        model's environment, and are handed to the policy (in a DataLoader) so
        that it can compute its Fisher Information Matrix.
        """
        if not self.model:
            return

        env = self.model.env
        # Roll out episodes (of at most 1000 steps each) until enough
        # observations were gathered for the FIM calculation.
        collected = []
        while len(collected) < self.total_steps_fim:
            state = env.reset()
            for _ in range(1000):
                action = self.get_actions(Observations(state), env.action_space)
                state, _, done, _ = env.step(action)
                collected.append(torch.tensor(state).to(self.model.device))
                if done:
                    break
        observations = TensorDataset(torch.cat(collected))
        dataloader = DataLoader(observations, batch_size=100, shuffle=False)

        # The kind of algorithm is inferred from the string form of the model class.
        model_class_repr = str(self.model.__class__)
        if "a2c" in model_class_repr:
            rl_method = "a2c"
        elif "dqn" in model_class_repr:
            rl_method = "dqn"
        else:
            raise NotImplementedError
        self.model.policy.on_task_switch(task_id, dataloader, method=rl_method)


if __name__ == "__main__":
    # Two cartpole tasks which only differ in the `length` attribute.
    schedule = {
        0: {"gravity": 10, "length": 0.3},
        1000: {"gravity": 10, "length": 0.5},  # second task is 'easier' than the first one.
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        nb_tasks=2,
        train_task_schedule=schedule,
        train_max_steps=2000,
    )

    # Baseline run: a coefficient of 0 turns the EWC loss off.
    baseline_method = EWCExampleMethod(reg_coefficient=0.0)
    results_without_reg = setting.apply(baseline_method)

    # Same method, this time with the EWC loss turned on.
    method = EWCExampleMethod(reg_coefficient=100)
    results_with_reg = setting.apply(method)

    print("-" * 40)
    print("WITHOUT EWC ")
    print(results_without_reg.summary())
    print(f"With EWC (coefficient={method.reg_coefficient}):")
    print(results_with_reg.summary())


================================================
FILE: examples/advanced/hat_demo.py
================================================
import sys
from argparse import Namespace
from dataclasses import dataclass
from typing import Dict, NamedTuple, Optional, Tuple

import gym
import numpy as np
import torch
import tqdm
from gym import Space, spaces
from numpy import inf
from simple_parsing import ArgumentParser
from torch import Tensor

from sequoia.common import Config
from sequoia.common.spaces import Image
from sequoia.methods import register_method
from sequoia.settings import Environment, Method
from sequoia.settings.sl import TaskIncrementalSLSetting
from sequoia.settings.sl.environment import PassiveEnvironment
from sequoia.settings.sl.incremental import Actions, Observations, Rewards


class Masks(NamedTuple):
    """Named tuple for the masked tensors created in the HATNet.

    Holds one gate tensor per gated layer of `HatNet`, in the order in which
    `HatNet.mask` creates them. In `HatNet.forward`, each gate is multiplied
    with the activations that follow the layer it is named after.
    """

    # Gate applied after the first conv layer (`c1`).
    gc1: Tensor
    # Gate applied after the second conv layer (`c2`).
    gc2: Tensor
    # Gate applied after the third conv layer (`c3`).
    gc3: Tensor
    # Gate applied after the first fully-connected layer (`fc1`).
    gfc1: Tensor
    # Gate applied after the second fully-connected layer (`fc2`).
    gfc2: Tensor


class HatNet(torch.nn.Module):
    """
    @inproceedings{serra2018overcoming,
      title={Overcoming Catastrophic Forgetting with Hard Attention to the Task},
      author={Serra, Joan and Suris, Didac and Miron, Marius and Karatzoglou, Alexandros},
      booktitle={International Conference on Machine Learning},
      pages={4548--4557},
      year={2018}
    }

    Network made of three conv layers and two fully-connected layers, each
    followed by a per-task gate (a sigmoid over a task embedding), and of one
    output head per task.
    As in any PyTorch module, the layers are created in `__init__`.
    """

    def __init__(self, image_space: Image, n_classes_per_task: Dict[int, int], s_hat: int = 50):
        """Create the layers.

        Args:
            image_space: Space of the input images; its `channels` and `width`
                are used to size the conv layers.
            n_classes_per_task: Number of classes of each task, by task index.
            s_hat: Scaling applied to the task embeddings before the sigmoid gate.
        """
        super().__init__()

        self.n_classes_per_task = n_classes_per_task
        self.s_hat = s_hat

        in_channels = image_space.channels
        side = image_space.width
        first_kernel = side // 8
        second_kernel = side // 10

        self.c1 = torch.nn.Conv2d(in_channels, 64, kernel_size=first_kernel)
        self.c2 = torch.nn.Conv2d(64, 128, kernel_size=second_kernel)
        self.c3 = torch.nn.Conv2d(128, 256, kernel_size=2)

        # Spatial size of the feature map after the three conv + max-pool stages.
        spatial = side
        for kernel in (first_kernel, second_kernel, 2):
            spatial = compute_conv_output_size(spatial, kernel) // 2
        self.smid = spatial

        self.maxpool = torch.nn.MaxPool2d(2)
        self.relu = torch.nn.ReLU()

        self.drop1 = torch.nn.Dropout(0.2)
        self.drop2 = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(256 * self.smid * self.smid, 2048)
        self.fc2 = torch.nn.Linear(2048, 2048)

        n_tasks = len(self.n_classes_per_task)
        # TODO: (@lebrice) Here I'm 'fixing' this, by making it so each output head has
        # as many outputs as there are classes in total. It's not super efficient, but
        # it should work.
        total_classes = sum(self.n_classes_per_task.values())
        self.output_layers = torch.nn.ModuleList(
            [torch.nn.Linear(2048, total_classes) for _ in range(n_tasks)]
        )

        self.gate = torch.nn.Sigmoid()
        # All the task embeddings have a name that starts with 'e'.
        self.ec1 = torch.nn.Embedding(n_tasks, 64)
        self.ec2 = torch.nn.Embedding(n_tasks, 128)
        self.ec3 = torch.nn.Embedding(n_tasks, 256)
        self.efc1 = torch.nn.Embedding(n_tasks, 2048)
        self.efc2 = torch.nn.Embedding(n_tasks, 2048)

        self.flatten = torch.nn.Flatten()

        self.loss = torch.nn.CrossEntropyLoss()
        self.current_task: Optional[int] = 0

    def forward(self, observations: TaskIncrementalSLSetting.Observations) -> Tuple[Tensor, Masks]:
        """Compute the logits for a batch of observations.

        Returns:
            The logits (each row produced by the output head of its task), and
            the gates that were applied.
        """
        observations.as_list_of_tuples()
        x = observations.x
        t = observations.task_labels
        # BUG: This won't work if task_labels is None (which is the case at
        # test-time in the ClassIncrementalSetting)
        masks = self.mask(t, s_hat=self.s_hat)

        # Gated conv stages: conv -> relu -> dropout -> max-pool -> gate.
        conv_stages = (
            (self.c1, self.drop1, masks.gc1),
            (self.c2, self.drop1, masks.gc2),
            (self.c3, self.drop2, masks.gc3),
        )
        h = x
        for conv, dropout, gate in conv_stages:
            h = self.maxpool(dropout(self.relu(conv(h))))
            h = h * gate.unsqueeze(2).unsqueeze(3)

        # Gated fully-connected stages: linear -> relu -> dropout -> gate.
        h = self.flatten(h)
        for linear, gate in ((self.fc1, masks.gfc1), (self.fc2, masks.gfc2)):
            h = self.drop2(self.relu(linear(h)))
            h = h * gate.expand_as(h)

        # Each batch can have elements of more than one Task (in test).
        # In Task Incremental Learning, each task has its own classification head.
        y: Optional[Tensor] = None
        for task_id in set(t.tolist()):
            head_output = self.output_layers[task_id](h.clone())
            if y is None:
                # First head: its output is the starting point for all the rows.
                y = head_output
                continue
            # Later heads only overwrite the rows which belong to their task.
            rows_of_task = t == task_id
            y[rows_of_task] = head_output[rows_of_task]
        assert y is not None
        return y, masks

    def mask(self, t: Tensor, s_hat: float) -> Masks:
        """Create the gates of the task labels `t`: `sigmoid(s_hat * embedding(t))`."""
        embeddings = (self.ec1, self.ec2, self.ec3, self.efc1, self.efc2)
        return Masks(*(self.gate(s_hat * embedding(t)) for embedding in embeddings))

    def shared_step(
        self, batch: Tuple[Observations, Optional[Rewards]], environment: Environment
    ) -> Tuple[Tensor, Dict]:
        """Shared step used for both training and validation.

        Parameters
        ----------
        batch : Tuple[Observations, Optional[Rewards]]
            Batch containing Observations, and optional Rewards. When the Rewards are
            None, the predictions (actions) are sent to the environment in order to
            get the Rewards (e.g. image labels) back.

        environment : Environment
            The environment we're currently interacting with. Used to obtain the
            rewards when they aren't already part of the batch.

        Returns
        -------
        Tuple[Tensor, Dict]
            The Loss tensor, and a dict of metrics to be logged.
        """
        observations: Observations = batch[0]
        rewards: Optional[Rewards] = batch[1]

        # Predictions for this batch.
        logits, _ = self(observations)
        y_pred = logits.argmax(-1)

        if rewards is None:
            # No rewards in the batch: they are obtained from the environment
            # in exchange for our actions.
            rewards = environment.send(Actions(y_pred))
        assert rewards is not None

        image_labels = rewards.y
        loss = self.loss(logits, image_labels)

        n_correct = (y_pred == image_labels).sum().float()
        accuracy = n_correct / len(image_labels)
        return loss, {"accuracy": accuracy}


def compute_conv_output_size(
    Lin: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1
) -> int:
    """Return the output length of a convolution along one dimension.

    Computes `floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)`.

    Args:
        Lin: Length of the input along that dimension.
        kernel_size: Size of the kernel.
        stride: Stride of the convolution.
        padding: Padding added on both sides of the input.
        dilation: Spacing between the kernel elements.
    """
    numerator = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    n_strides = numerator / float(stride)
    return int(np.floor(n_strides + 1))


@register_method
class HatDemoMethod(Method, target_setting=TaskIncrementalSLSetting):
    """
    Here we implement the method according to the characteristics and methodology of the current proposal.
    It should be as much as possible agnostic to the model and setting we are going to use.

    The method proposed can be specific to a setting to make comparisons easier.
    Here what we control is the model's training process, given a setting that delivers data in a certain way.
    """

    @dataclass
    class HParams:
        """Hyper-parameters of the HatDemoMethod."""

        # Learning rate of the optimizer.
        learning_rate: float = 0.001
        # Batch size
        batch_size: int = 128
        # weight/importance of the task embedding to the gate function
        s_hat: float = 50.0
        # Maximum number of training epochs per task
        max_epochs_per_task: int = 2

    def __init__(self, hparams: HParams = None):
        """Create the method; default hyper-parameters are used when `hparams` is None."""
        self.hparams: HatDemoMethod.HParams = hparams or self.HParams()

        # We will create those when `configure` will be called, before training.
        self.model: HatNet
        self.optimizer: torch.optim.Optimizer

    def configure(self, setting: TaskIncrementalSLSetting):
        """Called before the method is applied on a setting (before training).

        You can use this to instantiate your model, for instance, since this is
        where you get access to the observation & action spaces.
        """
        setting.batch_size = self.hparams.batch_size
        assert (
            setting.increment == setting.test_increment
        ), "Assuming same number of classes per task for training and testing."
        n_classes_per_task = {
            i: setting.num_classes_in_task(i, train=True) for i in range(setting.nb_tasks)
        }
        image_space: Image = setting.observation_space["x"]
        self.model = HatNet(
            image_space=image_space,
            n_classes_per_task=n_classes_per_task,
            s_hat=self.hparams.s_hat,
        )
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.hparams.learning_rate,
        )

    def fit(self, train_env: PassiveEnvironment, valid_env: PassiveEnvironment):
        """Train the model for `max_epochs_per_task` epochs on `train_env`,
        evaluating it on `valid_env` after each epoch.

        Different Settings can return elements from tasks in another way,
        be it class incremental, task incremental, etc.

        Batches can hold information about the environment, rewards, inputs, task
        labels, etc. Each batch is passed to the `shared_step` of the model,
        independently of the setting.
        """
        # configure() will have been called by the setting before we get here.

        # Best validation loss seen so far, and the epoch in which it was obtained.
        best_val_loss = inf
        best_epoch = 0
        for epoch in range(self.hparams.max_epochs_per_task):
            self.model.train()
            print(f"Starting epoch {epoch}")
            # Training loop:
            with tqdm.tqdm(train_env) as train_pbar:
                postfix = {}
                train_pbar.set_description(f"Training Epoch {epoch}")
                for batch in train_pbar:
                    loss, metrics_dict = self.model.shared_step(
                        batch,
                        environment=train_env,
                    )
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    postfix.update(metrics_dict)
                    train_pbar.set_postfix(postfix)

            # Validation loop:
            self.model.eval()
            epoch_val_loss = 0.0
            # `no_grad` as a context manager restores the previous grad mode even
            # if the validation step raises, unlike toggling the global flag.
            with torch.no_grad(), tqdm.tqdm(valid_env) as val_pbar:
                postfix = {}
                val_pbar.set_description(f"Validation Epoch {epoch}")

                for batch in val_pbar:
                    batch_val_loss, metrics_dict = self.model.shared_step(
                        batch,
                        environment=valid_env,
                    )
                    epoch_val_loss += batch_val_loss
                    postfix.update(metrics_dict, val_loss=epoch_val_loss)
                    val_pbar.set_postfix(postfix)

            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                # Record the epoch (not the index of the last validation batch).
                best_epoch = epoch

    def get_actions(self, observations: Observations, action_space: gym.Space) -> Actions:
        """Get a batch of predictions (aka actions) for these observations."""
        with torch.no_grad():
            logits, _ = self.model(observations)
        # Get the predicted classes
        y_pred = logits.argmax(dim=-1)
        return self.target_setting.Actions(y_pred)

    def on_task_switch(self, task_id: Optional[int]):
        """Store the id of the new task on the model.

        This method gets called if task boundaries are known in the current
        setting. Furthermore, if task labels are available, task_id will be
        the index of the new task. If not, task_id will be None.
        """
        # TODO: Does this method actually work when task_id is None?
        self.model.current_task = task_id

    @staticmethod
    def _hparams_dest(dest: Optional[str]) -> str:
        """Name of the attribute under which the parsed HParams are stored."""
        return f"{dest}_hparams" if dest else "hparams"

    @classmethod
    def add_argparse_args(cls, parser: ArgumentParser, dest: str = None) -> None:
        """Add the command-line arguments of this method to `parser`.

        Args:
            parser: The parser to add the `HParams` arguments to.
            dest: Optional prefix of the attribute where the parsed `HParams` are
                stored: `"<dest>_hparams"` when given, `"hparams"` otherwise. Must
                match the value given to `from_argparse_args`.
        """
        parser.add_arguments(cls.HParams, dest=cls._hparams_dest(dest))
        # You can also add arguments as usual:
        # parser.add_argument("--foo", default=123)

    @classmethod
    def from_argparse_args(cls, args: Namespace, dest: str = None) -> "HatDemoMethod":
        """Create the method from the parsed arguments.

        Args:
            args: Namespace returned by a parser that `add_argparse_args` was
                called on.
            dest: Same value as was passed to `add_argparse_args`.
        """
        hparams: HatDemoMethod.HParams = getattr(args, cls._hparams_dest(dest))
        # foo: int = args.foo
        method = cls(hparams=hparams)
        return method


if __name__ == "__main__":
    # Example: Evaluate a Method on a single CL setting.
    #
    # We must define 3 main components:
    #  1.- Setting: The continual learning scenario we are working in (SL or RL,
    #               TI or CI). Each setting has its own parameters that can be
    #               customized.
    #  2.- Model: The parameters and layers of the model, just like in PyTorch.
    #             We can use a predefined model or create our own.
    #  3.- Method: How we use what the setting gives us to train our model.
    #              Same as before, we can define our own or use pre-defined Methods.
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    ## Add arguments for the Method, the Setting, and the Config.
    ## (Config contains options like the log_dir, the data_dir, etc.)
    # No `dest` is passed to the classmethods of HatDemoMethod: called this way
    # they store / read the hyper-parameters at `args.hparams`.
    HatDemoMethod.add_argparse_args(parser)
    parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    parser.add_arguments(Config, "config")

    args = parser.parse_args()

    ## Create the Method from the args, and extract the Setting, and the Config:
    method: HatDemoMethod = HatDemoMethod.from_argparse_args(args)
    setting: TaskIncrementalSLSetting = args.setting
    config: Config = args.config

    ## Apply the method to the setting, optionally passing in a Config,
    ## producing Results.
    results = setting.apply(method, config=config)
    print(results.summary())
    print(f"objective: {results.objective}")


================================================
FILE: examples/advanced/hparam_tuning.py
================================================
"""Runs a hyper-parameter tuning sweep, using Orion for HPO and wandb for visualization. 

# PREREQUISITES:


1.  (Optional): If you want to run the sweep on the monsterkong env:
    At the time of writing, the monsterkong repo is private. Once the challenge is out,
    it will most probably be made public. In the meantime, you'll need to ask
    @mattriemer for access to the MonsterKong_examples repo.

    ```
    pip install -e .[rl]
    ```

2.  Install the repo, along with the optional dependencies for Hyper-Parameter
    Optimization (HPO):

    ```console
    pip install -e .[hpo]
    ```

    NOTE: You can also fuse the two steps above with `pip install -e .[rl,hpo]`

3.  (Optional) Setup a database to hold the hyper-parameter configurations, following
    the [Orion database configuration documentation](https://orion.readthedocs.io/en/stable/install/database.html)

    The quickest way to get this setup is to run the `orion db setup` wizard, entering
    "pickleddb" as the database type:

    ```console
    $ orion db setup
    Enter the database type:  (default: mongodb) pickleddb
    Enter the database name:  (default: test) 
    Enter the database host:  (default: localhost)
    Default configuration file will be saved at: 
    /home/<your username>/.config/orion.core/orion_config.yaml
    ```

"""
import wandb
from sequoia.common import Config
from sequoia.methods.base_method import BaseMethod
from sequoia.settings import Results, Setting, TraditionalSLSetting
from sequoia.utils.logging_utils import get_logger

logger = get_logger(__name__)


if __name__ == "__main__":
    # Script entry point: build a Setting and a BaseMethod, then launch a
    # hyper-parameter sweep of that method on that setting.
    from simple_parsing import ArgumentParser

    ## Create the Setting:
    from sequoia.settings import RLSetting

    setting = RLSetting(dataset="monsterkong")
    # Supervised alternative:
    # from sequoia.settings import TaskIncrementalSLSetting
    # setting = TaskIncrementalSLSetting(dataset="cifar10")

    ## Create the BaseMethod:
    # It could be instantiated directly (`method = BaseMethod()`); here it is built
    # from the command-line instead, and arguments it doesn't use are tolerated.
    method, _extra_cli_args = BaseMethod.from_known_args()
    # Longer, equivalent form using an explicit parser:
    # parser = ArgumentParser(description=__doc__)
    # BaseMethod.add_argparse_args(parser, dest="method")
    # args, unused_args = parser.parse_known_args()
    # method: BaseMethod = BaseMethod.from_argparse_args(args, dest="method")

    # Search space for the Hyper-Parameter optimization algorithm.
    # NOTE: This is just a copy of the spaces that are auto-generated from the fields of
    # the `BaseModel.HParams` class. You can change those as you wish though.
    # The priors of the output head are nested under their own "output_head" key.
    output_head_space = {
        "activation": "choices(['relu', 'tanh', 'elu', 'gelu', 'relu6'], default_value='tanh')",
        "dropout_prob": "uniform(0, 0.8, default_value=0.2)",
        "gamma": "uniform(0.9, 0.999, default_value=0.99)",
        "normalize_advantages": "choices([True, False])",
        "actor_loss_coef": "uniform(0.1, 1, default_value=0.5)",
        "critic_loss_coef": "uniform(0.1, 1, default_value=0.5)",
        "entropy_loss_coef": "uniform(0, 1, discrete=True, default_value=0)",
    }
    search_space = {
        "learning_rate": "loguniform(1e-06, 1e-02, default_value=0.001)",
        "weight_decay": "loguniform(1e-12, 1e-03, default_value=1e-06)",
        "optimizer": "choices(['sgd', 'adam', 'rmsprop'], default_value='adam')",
        "encoder": "choices({'resnet18': 0.5, 'simple_convnet': 0.5}, default_value='resnet18')",
        "output_head": output_head_space,
    }

    sweep_outcome = method.hparam_sweep(setting, search_space=search_space, experiment_id="123")
    best_hparams, best_results = sweep_outcome

    print(f"Best hparams: {best_hparams}, best perf: {best_results}")
    # results = setting.apply(method, config=Config(debug=True))


================================================
FILE: examples/advanced/pnn/__init__.py
================================================


================================================
FILE: examples/advanced/pnn/layers.py
================================================
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

"""
Based on https://github.com/TomVeniat/ProgressiveNeuralNetworks.pytorch
"""


class PNNConvLayer(nn.Module):
    """Convolutional layer of one PNN column, with optional lateral connections.

    `layer` is the convolution applied to the activations of this column. When
    `depth > 0`, `u` additionally holds `col` convolutions, one per earlier column,
    each applied to the activations of that column.
    """

    def __init__(self, col, depth, n_in, n_out, kernel_size=3):
        super().__init__()
        self.col = col
        self.layer = nn.Conv2d(n_in, n_out, kernel_size, stride=2, padding=1)

        # Only layers past the first one (depth > 0) get lateral connections.
        n_laterals = col if depth > 0 else 0
        laterals = []
        for _ in range(n_laterals):
            laterals.append(nn.Conv2d(n_in, n_out, kernel_size, stride=2, padding=1))
        self.u = nn.ModuleList(laterals)

    def forward(self, inputs):
        """Combine the columns' activations and apply a ReLU.

        `inputs` is either a single tensor or a list of tensors (one per column); the
        last entry is treated as the activation of this layer's own column.
        """
        columns_in = inputs if isinstance(inputs, list) else [inputs]

        own_out = self.layer(columns_in[-1])
        # `zip` pairs each lateral connection with the matching earlier column.
        lateral_outs = [lateral(prev) for lateral, prev in zip(self.u, columns_in)]

        return F.relu(own_out + sum(lateral_outs))


class PNNLinearBlock(nn.Module):
    """Fully-connected layer of one PNN column, with optional lateral connections.

    `layer` is the linear map applied to the activations of this column. When
    `depth > 0`, `u` additionally holds `col` linear maps, one per earlier column,
    each applied to the activations of that column.
    """

    def __init__(self, col: int, depth: int, n_in: int, n_out: int):
        super().__init__()
        self.layer = nn.Linear(n_in, n_out)

        # Only layers past the first one (depth > 0) get lateral connections.
        n_laterals = col if depth > 0 else 0
        laterals = []
        for _ in range(n_laterals):
            laterals.append(nn.Linear(n_in, n_out))
        self.u = nn.ModuleList(laterals)

    def forward(self, inputs):
        """Combine the columns' activations and apply a ReLU.

        `inputs` is either a single tensor or a list of tensors (one per column); the
        last entry is treated as the activation of this block's own column.
        """
        columns_in = inputs if isinstance(inputs, list) else [inputs]

        own_out = self.layer(columns_in[-1])
        # `zip` pairs each lateral connection with the matching earlier column.
        lateral_outs = [lateral(prev) for lateral, prev in zip(self.u, columns_in)]

        return F.relu(own_out + sum(lateral_outs))


================================================
FILE: examples/advanced/pnn/model_rl.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

from .layers import PNNConvLayer, PNNLinearBlock


class PnnA2CAgent(nn.Module):
    """Actor-critic agent made of Progressive Neural Network (PNN) columns.

    Every call to `new_task` adds one column: an actor head, a critic head and, when
    `arch == "conv"`, a convolutional feature extractor. The `PNNLinearBlock` /
    `PNNConvLayer` layers of a column receive the activations of the columns added
    before it, in addition to those of their own column.

    @article{rusu2016progressive,
      title={Progressive neural networks},
      author={Rusu, Andrei A and Rabinowitz, Neil C and Desjardins, Guillaume and Soyer, Hubert and Kirkpatrick, James and Kavukcuoglu, Koray and Pascanu, Razvan and Hadsell, Raia},
      journal={arXiv preprint arXiv:1606.04671},
      year={2016}
    }
    """

    def __init__(self, arch="mlp", hidden_size=256):
        """Create an agent without any column.

        Parameters
        ----------
        arch : str
            "mlp" for vector observations; any other value selects the convolutional
            path in `forward` (columns for it are only created when `arch == "conv"`).
        hidden_size : int
            Width of the hidden layers of the actor and critic heads (and number of
            output channels of the last convolution).
        """
        super().__init__()
        self.columns_actor = nn.ModuleList([])
        self.columns_critic = nn.ModuleList([])
        self.columns_conv = nn.ModuleList([])
        self.arch = arch
        self.hidden_size = hidden_size

        # Original size 3 x 400 x 600
        self.transformation = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )

    def forward(self, observations):
        """Compute the state value and the action probabilities for one observation.

        Parameters
        ----------
        observations
            Object with an `x` attribute (a single, un-batched observation) and a
            `task_labels` attribute used to index the list of columns.

        Returns
        -------
        Tuple[Tensor, Tensor]
            The critic output and the softmax-normalized actor output of the column
            selected by `observations.task_labels`.
        """
        assert (
            self.columns_actor
        ), "PNN should at least have one column (missing call to `new_task` ?)"
        t = observations.task_labels

        if self.arch == "mlp":
            # `as_tensor` accepts numpy arrays (like `from_numpy` did) as well as
            # observations that were already converted to tensors.
            x = torch.as_tensor(observations.x).unsqueeze(0).float()
            # First linear layer + ReLU of each column (no lateral connections there).
            inputs_critic = [c[1](c[0](x)) for c in self.columns_critic]
            inputs_actor = [c[1](c[0](x)) for c in self.columns_actor]

            # PNN block of column `i` sees the activations of columns 0..i.
            outputs_critic = []
            outputs_actor = []
            for i, column in enumerate(self.columns_critic):
                outputs_critic.append(column[2](inputs_critic[: i + 1]))
                outputs_actor.append(self.columns_actor[i][2](inputs_actor[: i + 1]))

            # Index of the final linear layer inside each head (see `new_task`).
            ind_depth = 3

        else:
            x = self.transfor_img(observations.x).unsqueeze(0).float()
            # Conv1 + MaxPool1 of each column.
            inputs = [c[1](c[0](x)) for c in self.columns_conv]

            # Conv2 + MaxPool2, with lateral connections from previous columns.
            outputs = []
            for i, column in enumerate(self.columns_conv):
                outputs.append(column[3](column[2](inputs[: i + 1])))

            # Conv3 + MaxPool3, with lateral connections from previous columns.
            inputs = outputs
            outputs = []
            for i, column in enumerate(self.columns_conv):
                outputs.append(column[5](column[4](inputs[: i + 1])))

            # Global average pooling, flattened to a (1, features) row per column.
            inputs_critic = [c[6](outputs[i]).view(1, -1) for i, c in enumerate(self.columns_conv)]
            inputs_actor = inputs_critic[:]

            outputs_critic = []
            outputs_actor = []
            for i, column in enumerate(self.columns_critic):
                outputs_critic.append(column[0](inputs_critic[: i + 1]))
                outputs_actor.append(self.columns_actor[i][0](inputs_actor[: i + 1]))

            # Without the first linear layer + ReLU, the final layer sits at index 1.
            ind_depth = 1

        critic = [
            column[ind_depth](outputs_critic[i]) for i, column in enumerate(self.columns_critic)
        ]
        actor = [
            F.softmax(column[ind_depth](outputs_actor[i]), dim=1)
            for i, column in enumerate(self.columns_actor)
        ]

        return critic[t], actor[t]

    def new_task(self, device, num_inputs, num_actions=5):
        """Append a new column (actor, critic and, for "conv", a feature extractor).

        Parameters
        ----------
        device
            NOTE(review): currently unused; the new modules are not moved to it.
            The RL training loop of `PnnMethod` calls `.numpy()` on the outputs,
            so moving the columns off the CPU would need changes there too.
        num_inputs : int
            Input size of the first linear layer ("mlp"), or number of input channels
            of the first convolution ("conv").
        num_actions : int
            Number of outputs of the actor head.
        """
        task_id = len(self.columns_actor)

        if self.arch == "conv":
            sizes = [num_inputs, 32, 64, self.hidden_size]
            modules_conv = nn.Sequential()

            modules_conv.add_module("Conv1", PNNConvLayer(task_id, 0, sizes[0], sizes[1]))
            modules_conv.add_module("MaxPool1", nn.MaxPool2d(3))
            modules_conv.add_module("Conv2", PNNConvLayer(task_id, 1, sizes[1], sizes[2]))
            modules_conv.add_module("MaxPool2", nn.MaxPool2d(3))
            modules_conv.add_module("Conv3", PNNConvLayer(task_id, 2, sizes[2], sizes[3]))
            modules_conv.add_module("MaxPool3", nn.MaxPool2d(3))
            modules_conv.add_module("globavgpool2d", nn.AdaptiveAvgPool2d((1, 1)))
            self.columns_conv.append(modules_conv)

        modules_actor = nn.Sequential()
        modules_critic = nn.Sequential()

        # NOTE: `forward` indexes into these Sequentials by position, so the order in
        # which the modules are added below matters.
        if self.arch == "mlp":
            modules_actor.add_module("linAc1", nn.Linear(num_inputs, self.hidden_size))
            modules_actor.add_module("relAc", nn.ReLU(inplace=True))
        modules_actor.add_module(
            "linAc2", PNNLinearBlock(task_id, 1, self.hidden_size, self.hidden_size)
        )
        modules_actor.add_module("linAc3", nn.Linear(self.hidden_size, num_actions))

        if self.arch == "mlp":
            modules_critic.add_module("linCr1", nn.Linear(num_inputs, self.hidden_size))
            modules_critic.add_module("relCr", nn.ReLU(inplace=True))
        modules_critic.add_module(
            "linCr2", PNNLinearBlock(task_id, 1, self.hidden_size, self.hidden_size)
        )
        modules_critic.add_module("linCr3", nn.Linear(self.hidden_size, 1))

        self.columns_actor.append(modules_actor)
        self.columns_critic.append(modules_critic)

        print("Add column of the new task")

    def unfreeze_columns(self):
        """Set `requires_grad = True` on the parameters of every column."""
        for i, c in enumerate(self.columns_actor):
            for params in c.parameters():
                params.requires_grad = True

            for params in self.columns_critic[i].parameters():
                params.requires_grad = True

        for c in self.columns_conv:
            for params in c.parameters():
                params.requires_grad = True

    def freeze_columns(self, skip=None):
        """Freeze every column, except those whose index is in `skip`.

        All columns are first unfrozen, so columns listed in `skip` end up trainable
        even if they were frozen by an earlier call.
        """
        # `is None` rather than `== None`: the latter is evaluated element-wise by
        # array-like containers and can't be used reliably as a condition.
        if skip is None:
            skip = []

        self.unfreeze_columns()

        for i, c in enumerate(self.columns_actor):
            if i not in skip:
                for params in c.parameters():
                    params.requires_grad = False

                for params in self.columns_critic[i].parameters():
                    params.requires_grad = False

        for i, c in enumerate(self.columns_conv):
            if i not in skip:
                for params in c.parameters():
                    params.requires_grad = False

        print("Freeze columns from previous tasks")

    def parameters(self, task_id=None, recurse=True):
        """Return the parameters of one column, or of the whole module.

        Parameters
        ----------
        task_id : Optional[int]
            Index of the column. When given, a list with the critic, actor and (if
            any) conv parameters of that column is returned. When omitted, this
            behaves like the regular `nn.Module.parameters`, which keeps the
            `nn.Module` methods that call `self.parameters()` (e.g. `zero_grad`)
            working.
        recurse : bool
            Forwarded to `nn.Module.parameters`; only used when `task_id` is None.
        """
        if task_id is None:
            return super().parameters(recurse=recurse)

        param = list(self.columns_critic[task_id].parameters())
        param.extend(self.columns_actor[task_id].parameters())

        if len(self.columns_conv) > 0:
            param.extend(self.columns_conv[task_id].parameters())

        return param

    def transfor_img(self, img):
        """Apply the image pre-processing pipeline created in `__init__`."""
        return self.transformation(img)
        # return lambda img: imresize(img[35:195].mean(2), (80,80)).astype(np.float32).reshape(1,80,80)/255.


================================================
FILE: examples/advanced/pnn/model_sl.py
================================================
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from sequoia.settings import Actions, PassiveEnvironment
from sequoia.settings.sl.incremental import Observations, Rewards

from .layers import PNNConvLayer, PNNLinearBlock


class PnnClassifier(nn.Module):
    """Classifier made of Progressive Neural Network (PNN) columns.

    Every call to `new_task` adds one column of `n_layers` `PNNLinearBlock`s. In
    `forward`, each sample gets the output of the column selected by its task label.

    @article{rusu2016progressive,
      title={Progressive neural networks},
      author={Rusu, Andrei A and Rabinowitz, Neil C and Desjardins, Guillaume and Soyer, Hubert and Kirkpatrick, James and Kavukcuoglu, Koray and Pascanu, Razvan and Hadsell, Raia},
      journal={arXiv preprint arXiv:1606.04671},
      year={2016}
    }
    """

    def __init__(self, n_layers):
        """Create a classifier without any column.

        Parameters
        ----------
        n_layers : int
            Number of `PNNLinearBlock` layers in each column.
        """
        super().__init__()
        self.n_layers = n_layers
        self.columns = nn.ModuleList([])

        self.loss = torch.nn.CrossEntropyLoss()
        # Set by `new_task`; `shared_step` moves the batches to this device.
        self.device = None
        self.n_tasks = 0
        self.n_classes_per_task: List[int] = []

    def forward(self, observations):
        """Return, for each sample, the output of the column of its task.

        Parameters
        ----------
        observations
            Object with an `x` attribute (batch of inputs, flattened from dim 1 on)
            and a `task_labels` attribute (tensor with one column index per sample).

        Returns
        -------
        Tensor
            One row per sample, taken from the last layer of the column selected by
            that sample's task label.
        """
        assert self.columns, "PNN should at least have one column (missing call to `new_task` ?)"
        x = observations.x
        x = torch.flatten(x, start_dim=1)
        labels = observations.task_labels
        # TODO: Debug this:
        # NOTE(review): this adds the number of classes of each task to the
        # activations of the first layer; kept as-is, intent to be confirmed.
        inputs = [
            c[0](x) + n_classes_in_task
            for n_classes_in_task, c in zip(self.n_classes_per_task, self.columns)
        ]
        for l in range(1, self.n_layers):
            # Layer `l` of column `i` sees the activations of columns 0..i.
            inputs = [column[l](inputs[: i + 1]) for i, column in enumerate(self.columns)]

        y: Optional[Tensor] = None
        for task_id in set(labels.tolist()):
            task_mask = labels == task_id

            if y is None:
                # Clone, so that the in-place writes below don't modify the output
                # of this column, which autograd needs intact for `backward()`.
                y = inputs[task_id].clone()
            else:
                y[task_mask] = inputs[task_id][task_mask]

        assert y is not None, "Can't get prediction in model PNN"
        return y

    # def new_task(self, device, num_inputs, num_actions = 5):
    def new_task(self, device, sizes: List[int]):
        """Append a new column and move it to `device`.

        Parameters
        ----------
        device
            Device on which the new column is created; also stored in `self.device`.
        sizes : List[int]
            Input size followed by the output size of each of the `n_layers` layers.
        """
        assert len(sizes) == self.n_layers + 1, (
            f"Should have the out size for each layer + input size (got {len(sizes)} "
            f"sizes but {self.n_layers} layers)."
        )
        self.n_tasks += 1
        # TODO: Fix this to use the actual number of classes per task.
        self.n_classes_per_task.append(2)
        task_id = len(self.columns)
        modules = [
            PNNLinearBlock(col=task_id, depth=i, n_in=sizes[i], n_out=sizes[i + 1])
            for i in range(0, self.n_layers)
        ]

        new_column = nn.ModuleList(modules).to(device)
        self.columns.append(new_column)
        self.device = device

        print("Add column of the new task")

    def freeze_columns(self, skip=None):
        """Freeze every column, except those whose index is in `skip`.

        All columns are first unfrozen, so columns listed in `skip` end up trainable
        even if they were frozen by an earlier call.
        """
        # `is None` rather than `== None`: the latter is evaluated element-wise by
        # array-like containers and can't be used reliably as a condition.
        if skip is None:
            skip = []

        for c in self.columns:
            for params in c.parameters():
                params.requires_grad = True

        for i, c in enumerate(self.columns):
            if i not in skip:
                for params in c.parameters():
                    params.requires_grad = False

        print("Freeze columns from previous tasks")

    def shared_step(
        self,
        batch: Tuple["Observations", Optional["Rewards"]],
        environment: "PassiveEnvironment",
    ):
        """Shared step used for both training and validation.

        Parameters
        ----------
        batch : Tuple[Observations, Optional[Rewards]]
            Batch containing Observations, and optional Rewards. When the Rewards are
            None, it means that we'll need to provide the Environment with actions
            before we can get the Rewards (e.g. image labels) back.

            This happens for example when being applied in a Setting which cares about
            sample efficiency or training performance, for example.

        environment : Environment
            The environment we're currently interacting with. Used to provide the
            rewards when they aren't already part of the batch (as mentioned above).

        Returns
        -------
        Tuple[Tensor, Dict]
            The Loss tensor, and a dict of metrics to be logged.
        """
        # Since we're training on a Passive environment, we will get both observations
        # and rewards, unless we're being evaluated based on our training performance,
        # in which case we will need to send actions to the environments before we can
        # get the corresponding rewards (image labels).
        observations: Observations = batch[0].to(self.device)
        rewards: Optional[Rewards] = batch[1]

        # Get the predictions:
        logits = self(observations)
        y_pred = logits.argmax(-1)
        # TODO: PNN is coded for the DomainIncrementalSetting, where the action space
        # is the same for each task.

        # Get the rewards, if necessary:
        if rewards is None:
            rewards = environment.send(Actions(y_pred))

        image_labels = rewards.y.to(self.device)
        # print(logits.size())
        loss = self.loss(logits, image_labels)

        accuracy = (y_pred == image_labels).sum().float() / len(image_labels)
        metrics_dict = {"accuracy": accuracy}
        return loss, metrics_dict

    def parameters(self, task_id=None, recurse=True):
        """Return the parameters of one column, or of the whole module.

        Parameters
        ----------
        task_id : Optional[int]
            Index of the column whose parameters are returned. When omitted, this
            behaves like the regular `nn.Module.parameters`, which keeps the
            `nn.Module` methods that call `self.parameters()` (e.g. `zero_grad`)
            working.
        recurse : bool
            Forwarded to `nn.Module.parameters`; only used when `task_id` is None.
        """
        if task_id is None:
            return super().parameters(recurse=recurse)
        return self.columns[task_id].parameters()


================================================
FILE: examples/advanced/pnn/pnn_method.py
================================================
import sys
from argparse import Namespace
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from gym import spaces
from gym.spaces import Box
from numpy import inf
from scipy.signal import lfilter
from simple_parsing import ArgumentParser
from torchvision import transforms

from examples.advanced.pnn.model_rl import PnnA2CAgent
from examples.advanced.pnn.model_sl import PnnClassifier
from sequoia import Environment
from sequoia.common import Config
from sequoia.common.spaces import Image
from sequoia.common.transforms.utils import is_image
from sequoia.settings import Actions, Method, Observations, Rewards, Setting
from sequoia.settings.assumptions import IncrementalAssumption
from sequoia.settings.rl import ActiveEnvironment, RLSetting, TaskIncrementalRLSetting
from sequoia.settings.sl import (
    DomainIncrementalSLSetting,
    PassiveEnvironment,
    SLSetting,
    TaskIncrementalSLSetting,
)


class PnnMethod(Method, target_setting=Setting):
    """
    Here we implement the PNN Method according to the characteristics and methodology of
    the current proposal.  It should be as much as possible agnostic to the model and
    setting we are going to use.

    The method proposed can be specific to a setting to make comparisons easier.
    Here what we control is the model's training process, given a setting that delivers
    data in a certain way.

    The model is a `PnnA2CAgent` when applied to an `RLSetting` and a `PnnClassifier`
    otherwise; a new column is added to it the first time each task is encountered
    (see `on_task_switch`).
    """

    @dataclass
    class HParams:
        """Hyper-parameters of the Pnn method."""

        # Learning rate of the optimizer. Defaults to 0.0001 when in SL.
        learning_rate: float = 2e-4
        num_steps: int = 200  # (only applicable in RL settings.)
        # Discount factor (Only used in RL settings).
        gamma: float = 0.99
        # Number of hidden units (only used in RL settings.)
        hidden_size: int = 256
        # Batch size in SL, and number of parallel environments in RL.
        # Defaults to None in RL, and 32 when in SL.
        batch_size: Optional[int] = None
        # Maximum number of training epochs per task. (only used in SL Settings)
        max_epochs_per_task: int = 2

    def __init__(self, hparams: Optional[HParams] = None):
        """Create the method.

        Parameters
        ----------
        hparams : Optional[HParams]
            Hyper-parameters. When None, setting-specific defaults are created in
            `configure`.
        """
        # We will create those when `configure` will be called, before training.
        self.config: Optional[Config] = None
        self.task_id: Optional[int] = 0
        self.hparams: Optional[PnnMethod.HParams] = hparams
        self.model: Union[PnnA2CAgent, PnnClassifier]
        self.optimizer: torch.optim.Optimizer

    def configure(self, setting: Setting):
        """Called before the method is applied on a setting (before training).

        You can use this to instantiate your model, for instance, since this is
        where you get access to the observation & action spaces.
        """

        input_space: Box = setting.observation_space["x"]

        # For now all Settings have `Discrete` (i.e. classification) action spaces.
        action_space: spaces.Discrete = setting.action_space

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_actions = action_space.n
        # Plain Python int (rather than a numpy integer) for the layer sizes.
        self.num_inputs = int(np.prod(input_space.shape))

        # Ids of the tasks for which a column was already added to the model.
        self.added_tasks = []

        if isinstance(setting, RLSetting):
            # If we're applied to an RL setting:

            # Used these as the default hparams in RL:
            self.hparams = self.hparams or self.HParams(
                learning_rate=2e-4,
                num_steps=200,
                gamma=0.99,
                hidden_size=256,
                batch_size=None,
            )
            assert self.hparams
            self.train_steps_per_task = setting.steps_per_task

            # We want a batch_size of None, i.e. only one observation at a time.
            setting.batch_size = None

            self.num_steps = self.hparams.num_steps
            # Otherwise, we can train basically as long as we want on each task.
            self.loss_function = {
                "gamma": self.hparams.gamma,
            }

            if is_image(setting.observation_space.x):
                # Observing pixel input.
                # NOTE(review): `self.num_inputs` (the product of the whole input
                # shape) is later passed to `PnnA2CAgent.new_task`, which uses it as
                # the number of input channels of the first convolution. That looks
                # wrong for image inputs -- confirm and pass the channel count.
                self.arch = "conv"
            else:
                # Observing state input (e.g. the 4 floats in cartpole rather than images)
                self.arch = "mlp"
            self.model = PnnA2CAgent(self.arch, self.hparams.hidden_size)

        else:
            # If we're applied to a Supervised Learning setting:
            # Used these as the default hparams in SL:
            self.hparams = self.hparams or self.HParams(
                learning_rate=0.0001,
                batch_size=32,
            )
            if self.hparams.batch_size is None:
                self.hparams.batch_size = 32

            # Set the batch size on the setting.
            setting.batch_size = self.hparams.batch_size
            # For now all Settings on the supervised side of the tree have images as
            # inputs, so the observation spaces are of type `Image` (same as Box, but with
            # additional `h`, `w`, `c` and `b` attributes).
            assert isinstance(input_space, Image)
            assert (
                setting.increment == setting.test_increment
            ), "Assuming same number of classes per task for training and testing."
            # TODO: (@lebrice): Temporarily 'fixing' this by making it so each output
            # head has as many outputs as there are classes in total, which might make
            # no sense, but currently works.
            # It would be better to refactor this so that each output head can have only
            # as many outputs as is required, and then reshape / offset the predictions.
            n_outputs = setting.action_space.n
            self.layer_size = [self.num_inputs, 256, n_outputs]
            self.model = PnnClassifier(
                n_layers=len(self.layer_size) - 1,
            )

    def on_task_switch(self, task_id: Optional[int]) -> None:
        """Called when switching tasks in a CL setting."""
        # This method gets called if task boundaries are known in the current
        # setting. Furthermore, if task labels are available, task_id will be
        # the index of the new task. If not, task_id will be None.
        # For example, you could do something like this:
        # self.model.current_task = task_id
        # This freezes all columns except the one for the next task.. but there might
        # not yet be a column for the new task!
        self.model.freeze_columns(skip=[task_id])
        if task_id not in self.added_tasks:
            if isinstance(self.model, PnnA2CAgent):
                self.model.new_task(
                    device=self.device,
                    num_inputs=self.num_inputs,
                    num_actions=self.num_actions,
                )
            else:
                self.model.new_task(device=self.device, sizes=self.layer_size)

            self.added_tasks.append(task_id)

        self.task_id = task_id

    def set_optimizer(self):
        """Create an Adam optimizer over the parameters of the current task's column."""
        self.optimizer = torch.optim.Adam(
            self.model.parameters(self.task_id),
            lr=self.hparams.learning_rate,
        )

    def get_actions(self, observations: Observations, action_space: spaces.Space) -> Actions:
        """Get a batch of predictions (aka actions) for the given observations.

        In RL this is the index of the most probable action (a Python int); in SL, the
        tensor of predicted classes. The value is returned as-is (it is not wrapped in
        an `Actions` object) after checking that it belongs to `action_space`.
        """

        observations = observations.to(self.device)
        with torch.no_grad():
            if isinstance(self.model, PnnA2CAgent):
                predictions = self.model(observations)
                _, logit = predictions
                # get the predicted action:
                action = torch.argmax(logit).item()
            else:
                logits = self.model(observations)
                # Get the predicted classes
                y_pred = logits.argmax(dim=-1)
                action = y_pred

        assert action in action_space, (action, action_space)
        return action

    def fit(self, train_env: Environment, valid_env: Environment):
        """Train and validate this method using the "environments" for the current task.

        NOTE: `train_env` and `valid_env` are both `gym.Env`s as well as `DataLoader`s.
        This means that if you want to write a "regular" SL training loop, you totally
        can, and if you want to write you RL-style training loop, you can also do that.
        """
        if isinstance(train_env.unwrapped, PassiveEnvironment):
            self.fit_sl(train_env, valid_env)
        else:
            self.fit_rl(train_env, valid_env)

    def fit_rl(self, train_env: gym.Env, valid_env: gym.Env):
        """Training loop for Reinforcement Learning (a.k.a. "active") environment.

        Advantage actor-critic with one update per episode, based on
        https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f

        Runs `self.train_steps_per_task` episodes of at most `self.num_steps` steps.
        `valid_env` is not used.
        """
        if self.model is None:
            self.model = PnnA2CAgent(self.arch, self.hparams.hidden_size)
        assert isinstance(self.model, PnnA2CAgent)

        self.set_optimizer()
        assert self.hparams
        # self.model.float()

        all_lengths = []
        average_lengths = []
        all_rewards = []

        for episode in range(self.train_steps_per_task):
            values = []
            rewards = []
            log_probs = []
            entropy_term = 0.0
            # Value used to bootstrap the return after the last step of the episode.
            bootstrap_value = 0.0

            state = train_env.reset()
            for steps in range(self.num_steps):
                value, policy_dist = self.model(state)

                dist = policy_dist.detach().numpy()

                action = np.random.choice(self.num_actions, p=np.squeeze(dist))
                log_prob = torch.log(policy_dist.squeeze(0)[action])
                # Shannon entropy of the policy; clipped to avoid log(0).
                entropy = -np.sum(dist * np.log(np.clip(dist, 1e-12, None)))
                new_state, reward, done, _ = train_env.step(action)

                rewards.append(reward.y)
                # Keep the value as a tensor (rather than `.item()`), so that the
                # critic loss below can back-propagate into the critic.
                values.append(value.squeeze())
                log_probs.append(log_prob)
                entropy_term += entropy
                state = new_state

                if done or steps == self.num_steps - 1:
                    if not done:
                        # Episode cut short by the step limit: bootstrap from the
                        # critic's estimate. A terminal state has a value of 0.
                        with torch.no_grad():
                            bootstrap_value = self.model(state)[0].item()
                    all_rewards.append(np.sum(rewards))
                    all_lengths.append(steps)
                    average_lengths.append(np.mean(all_lengths[-10:]))

                    if episode % 10 == 0:
                        print(
                            f"episode: {episode}, "
                            f"reward: {np.sum(rewards)}, "
                            f"total length: {steps}, "
                            f"average length: {average_lengths[-1]}"
                        )
                    break

            # Discounted returns, computed backwards from the bootstrap value.
            Qvals = np.zeros(len(rewards))
            Qval = bootstrap_value
            for t in reversed(range(len(rewards))):
                Qval = rewards[t] + self.hparams.gamma * Qval
                Qvals[t] = Qval

            # update actor critic
            values_tensor = torch.stack(values)
            Qvals = torch.as_tensor(Qvals, dtype=torch.float)
            log_probs_tensor = torch.stack(log_probs)

            advantage = Qvals - values_tensor
            # The advantage is a constant for the policy gradient: the critic is
            # trained through `critic_loss` only.
            actor_loss = (-log_probs_tensor * advantage.detach()).mean()
            critic_loss = 0.5 * advantage.pow(2).mean()
            # NOTE: `entropy_term` is computed with numpy, so it is a constant w.r.t.
            # the parameters: it changes the loss value but not the gradients.
            ac_loss = actor_loss + critic_loss + 0.001 * entropy_term

            self.optimizer.zero_grad()
            ac_loss.backward()
            self.optimizer.step()

    def fit_sl(self, train_env: PassiveEnvironment, valid_env: PassiveEnvironment):
        """Train on a Supervised Learning (a.k.a. "passive") environment.

        Runs `hparams.max_epochs_per_task` epochs, each made of a training pass over
        `train_env` followed by a validation pass over `valid_env`.
        """
        # The returned observations aren't needed; the environments are iterated on
        # as dataloaders below.
        train_env.reset()
        assert isinstance(self.model, PnnClassifier)
        assert self.hparams

        self.set_optimizer()

        for epoch in range(self.hparams.max_epochs_per_task):
            self.model.train()
            print(f"Starting epoch {epoch}")
            # Training loop:
            with torch.set_grad_enabled(True), tqdm.tqdm(train_env) as train_pbar:
                postfix: Dict[str, Any] = {}
                train_pbar.set_description(f"Training Epoch {epoch}")
                for batch in train_pbar:
                    loss, metrics_dict = self.model.shared_step(
                        batch,
                        environment=train_env,
                    )
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    postfix.update(metrics_dict)
                    train_pbar.set_postfix(postfix)

            # Validation loop:
            self.model.eval()
            with torch.set_grad_enabled(False), tqdm.tqdm(valid_env) as val_pbar:
                postfix = {}
                val_pbar.set_description(f"Validation Epoch {epoch}")
                epoch_val_loss = 0.0

                for batch in val_pbar:
                    batch_val_loss, metrics_dict = self.model.shared_step(
                        batch,
                        environment=valid_env,
                    )
                    # Accumulate a plain float, so the progress bar shows a number.
                    epoch_val_loss += batch_val_loss.item()
                    postfix.update(metrics_dict, val_loss=epoch_val_loss)
                    val_pbar.set_postfix(postfix)

    @classmethod
    def add_argparse_args(cls, parser: ArgumentParser, dest: str = "hparams") -> None:
        """Add the command-line arguments for the hyper-parameters to `parser`.

        Parameters
        ----------
        parser : ArgumentParser
            The parser to add the arguments to.
        dest : str
            Attribute of the parsed namespace under which the `HParams` are stored.
            Must match the `dest` later given to `from_argparse_args`.
        """
        parser.add_arguments(cls.HParams, dest=dest, default=None)

    @classmethod
    def from_argparse_args(cls, args: Namespace, dest: str = "hparams") -> "PnnMethod":
        """Create the method from parsed command-line arguments.

        Parameters
        ----------
        args : Namespace
            The parsed arguments.
        dest : str
            Attribute of `args` holding the `HParams` (see `add_argparse_args`).
        """
        hparams: PnnMethod.HParams = getattr(args, dest)
        method = cls(hparams=hparams)
        return method


def main_rl():
    """Run the PnnMethod on a two-task incremental cartpole RL setting.

    The Config and the method's hyper-parameters come from the command-line; the
    setting itself is hard-coded. Prints and returns the results.
    """
    cli = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)
    Config.add_argparse_args(cli, dest="config")
    PnnMethod.add_argparse_args(cli, dest="method")

    # 1. Creating the Setting: same gravity in both tasks, different pole length;
    # the second task starts at step 1000.
    task_schedule = {
        0: {"gravity": 10, "length": 0.3},
        1000: {"gravity": 10, "length": 0.5},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole", nb_tasks=2, train_task_schedule=task_schedule
    )

    parsed = cli.parse_args()

    # 2. Creating the Method (and the Config it is run with):
    config: Config = Config.from_argparse_args(parsed, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(parsed, dest="method")
    method.config = config

    # 3. Applying the method to the setting:
    results = setting.apply(method, config=config)

    print(results.summary())
    print(f"objective: {results.objective}")
    return results


def main_sl():
    """Runs the PnnMethod on a task-incremental SL setting created from the command-line.

    The Setting, the `Config` and the Method are all created from the parsed arguments. The
    summary of the results is printed and the results are returned.
    """
    arg_parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # Register the argument groups: Setting first, then Config, then Method.
    # TODO: PNN is coded for the DomainIncrementalSetting, where the action space is the same
    # for each task. (Alternative: register `DomainIncrementalSetting` under dest="setting".)
    arg_parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    Config.add_argparse_args(arg_parser, dest="config")
    PnnMethod.add_argparse_args(arg_parser, dest="method")

    parsed = arg_parser.parse_args()

    sl_setting: TaskIncrementalSLSetting = TaskIncrementalSLSetting.from_argparse_args(
        parsed,
        dest="setting",
    )
    experiment_config: Config = Config.from_argparse_args(parsed, dest="config")
    pnn_method: PnnMethod = PnnMethod.from_argparse_args(parsed, dest="method")
    pnn_method.config = experiment_config

    outcome = sl_setting.apply(pnn_method, config=experiment_config)
    print(outcome.summary())
    return outcome


if __name__ == "__main__":
    # Run SL Setting
    main_sl()
    # Run RL Setting
    # main_rl()


================================================
FILE: examples/advanced/procgen_example.py
================================================
""" Example of how to create an incremental RL Setting with custom environments for each task.

In this example, we create environments using [the `procgen` package](https://github.com/openai/procgen).
"""

import dataclasses
from dataclasses import dataclass, replace
from typing import Dict, List, NamedTuple, Optional, Type, TypeVar

import gym
import numpy as np

from sequoia.settings.rl import (
    IncrementalRLSetting,
    MultiTaskRLSetting,
    TaskIncrementalRLSetting,
    TraditionalRLSetting,
)


@dataclass
class ProcGenConfig:
    """Options for creating an environment from ProcGen.

    The fields on this dataclass match the arguments that can be passed to `gym.make`, based on the
    README of the procgen repo.
    """

    # Name of environment, or comma-separated list of environment names to instantiate as each env
    # in the VecEnv.
    env_name: str = "coinrun-v0"
    # The number of unique levels that can be generated. Set to 0 to use unlimited levels.
    num_levels: int = 0
    # The lowest seed that will be used to generate levels. 'start_level' and 'num_levels' fully
    # specify the set of possible levels.
    start_level: int = 0
    # Paint player velocity info in the top left corner. Only supported by certain games.
    paint_vel_info: bool = False
    # Use randomly generated assets in place of human designed assets.
    use_generated_assets: bool = False
    # Set to True to use the debug build if building from source.
    debug: bool = False
    # Useful flag that's passed through to procgen envs. Use however you want during debugging.
    debug_mode: int = 0
    # Determines whether observations are centered on the agent or display the full level.
    # Override at your own risk.
    center_agent: bool = True
    # When you reach the end of a level, the episode is ended and a new level is selected.
    # If use_sequential_levels is set to True, reaching the end of a level does not end the episode,
    # and the seed for the new level is derived from the current level seed.
    # If you combine this with start_level=<some seed> and num_levels=1, you can have a single
    # linear series of levels similar to a gym-retro or ALE game.
    use_sequential_levels: bool = False
    # What variant of the levels to use, the options are "easy", "hard", "extreme", "memory",
    # "exploration". All games support "easy" and "hard", while other options are game-specific.
    # The default is "hard". Switching to "easy" will reduce the number of timesteps required to
    # solve each game and is useful for testing or when working with limited compute resources.
    distribution_mode: str = "hard"
    # Normally games use human designed backgrounds, if this flag is set to False, games will use
    # pure black backgrounds.
    use_backgrounds: bool = True
    # Some games select assets from multiple themes, if this flag is set to True, those games will
    # only use a single theme.
    restrict_themes: bool = False
    # If set to True, games will use monochromatic rectangles instead of human designed assets.
    # Best used with restrict_themes=True.
    use_monochrome_assets: bool = False

    def make_env(self) -> gym.Env:
        """Creates the environment using these options.

        Every field of this dataclass other than `env_name` is forwarded as a keyword argument to
        `gym.make`; `env_name` is used to build the id of the environment.

        Returns
        -------
        gym.Env
            The procgen environment, wrapped in a `SequoiaProcGenAdapterWrapper`.
        """
        env_id = f"procgen:procgen-{self.env_name}"
        # Create the env by passing the arguments to gym.make, same as what is done in the README of
        # the procgen repo.
        procgen_env = gym.make(
            id=env_id,
            num_levels=self.num_levels,
            start_level=self.start_level,
            paint_vel_info=self.paint_vel_info,
            use_generated_assets=self.use_generated_assets,
            debug=self.debug,
            # NOTE: `debug_mode` used to be left out of this call, which meant that the value of
            # the field was silently ignored.
            debug_mode=self.debug_mode,
            center_agent=self.center_agent,
            use_sequential_levels=self.use_sequential_levels,
            distribution_mode=self.distribution_mode,
            use_backgrounds=self.use_backgrounds,
            restrict_themes=self.restrict_themes,
            use_monochrome_assets=self.use_monochrome_assets,
        )
        # NOTE: The environments that are created with `gym.make("procgen:procgen-...")` are
        # instances of the `gym3.interop:ToGymEnv` class, which has a slightly different API than
        # the `gym.Env` class:
        # (Taken From gym3/interop.py:)
        # > - The `render()` method does nothing in "human" mode, in "rgb_array" mode the info dict
        #     is checked for a key named "rgb" and info["rgb"][0] is returned if present
        # > - `seed()` and `close() are ignored since gym3 environments do not require these methods
        #
        # Therefore, for now, since in Sequoia we assume that the envs fit the gym.Env API, we have to
        # "patch" these different methods up a bit. Here I suggest we do this using a wrapper
        # (defined below)
        wrapped_env = SequoiaProcGenAdapterWrapper(env=procgen_env)
        return wrapped_env


class SequoiaProcGenAdapterWrapper(gym.Wrapper):
    """Adapter that smooths over the differences between a ProcGen env and the gym API:

    - The `seed` method of the wrapped env doesn't have the right number of arguments.
    - The `done` value is a `np.bool_` rather than a plain bool.
    - `render` returns None.
    """

    def __init__(self, env):
        super().__init__(env=env)

    def step(self, action):
        observation, reward, is_done, info = self.env.step(action)
        # Numpy booleans are converted to the builtin `bool`; anything else is passed through.
        is_done = bool(is_done) if isinstance(is_done, np.bool_) else is_done
        return observation, reward, is_done, info

    def seed(self, seed: Optional[int] = None) -> List[int]:
        # The `seed` method of the procgen env doesn't accept a `seed` argument, so the call is
        # not forwarded to it, and an empty list of seeds is returned.
        return []

    def render(self, mode: str = "rgb_array"):
        # note: rendering doesn't seem to be working: `self.env.render("rgb_array")` returns None.
        # The wrapped env is always asked for "rgb_array", whatever the value of `mode` is.
        frame: Optional[np.ndarray] = self.env.render("rgb_array")
        return frame


# Type variable for a type of setting that supports passing envs for each task (all settings below
# `IncrementalRLSetting`).
SettingType = TypeVar("SettingType", bound=IncrementalRLSetting)

# Environment names, without a version suffix such as the "-v0" of the default
# `ProcGenConfig.env_name`.
# NOTE(review): presumably the list of games offered by procgen -- confirm against its README.
available_envs = [
    "bigfish",
    "bossfight",
    "caveflyer",
    "chaser",
    "climber",
    "coinrun",
    "dodgeball",
    "fruitbot",
    "heist",
    "jumper",
    "leaper",
    "maze",
    "miner",
    "ninja",
    "plunder",
    "starpilot",
]


def make_procgen_setting(
    env_name: str,
    nb_tasks: int,
    num_levels_per_task: int = 1,
    overlapping_levels_between_tasks: int = 0,
    common_options: ProcGenConfig = None,
    setting_type: Type[SettingType] = TaskIncrementalRLSetting,
) -> SettingType:
    """Builds an RL Setting in which every task uses environments from procgen.

    Each task gets a window of `num_levels_per_task` consecutive levels. The window of a task
    starts `num_levels_per_task - overlapping_levels_between_tasks` levels after the window of the
    previous task, and the first window starts at `common_options.start_level`.

    Parameters
    ----------
    env_name : str
        Name of the environment from procgen to use. Should include the version tag.
        For example: "coinrun-v0".
    nb_tasks : int
        Number of tasks in the setting.
    num_levels_per_task : int, optional
        Number of generated levels per task, by default 1
    overlapping_levels_between_tasks : int, optional
        Number of levels shared by neighbouring tasks. Needs to be less than
        `num_levels_per_task`. Defaults to 0, in which case all tasks have distinct levels.
    common_options : ProcGenConfig, optional
        Options shared by the envs of all the tasks (can be used to set the starting level, for
        example). Its `env_name` is overwritten with the `env_name` argument. Defaults to None,
        in which case the default options from `ProcGenConfig` are used.
    setting_type : Type[SettingType], optional
        The type of setting to create, by default TaskIncrementalRLSetting.

    Examples
    --------
    With `nb_tasks`=5, `num_levels_per_task`=2, `overlapping_levels_between_tasks`=1:

    task #1: levels: [0, 1]
    task #2: levels: [1, 2]
    task #3: levels: [2, 3]
    task #4: levels: [3, 4]
    task #5: levels: [4, 5]

    With `nb_tasks`=5, `num_levels_per_task`=5, `overlapping_levels_between_tasks`=2:

    task #1: levels: [0, 1, 2, 3, 4]
    task #2: levels: [3, 4, 5, 6, 7]
    task #3: levels: [6, 7, 8, 9, 10]
    task #4: levels: [9, 10, 11, 12, 13]
    task #5: levels: [12, 13, 14, 15, 16]

    NOTE: (lebrice): Maybe this (and other benchmark-creating functions) could be classmethods on
    the settings, instead of passing the setting_type as a parameter!

    Returns
    -------
    SettingType
        A Setting of type `setting_type` (`TaskIncrementalRLSetting` by default), where each task
        uses environments from ProcGen.
    """
    assert overlapping_levels_between_tasks < num_levels_per_task

    # Options shared by all tasks: always use the requested environment name.
    base_options = (
        ProcGenConfig(env_name=env_name)
        if common_options is None
        else replace(common_options, env_name=env_name)
    )

    # First level of each task: consecutive tasks are `stride` levels apart.
    stride = num_levels_per_task - overlapping_levels_between_tasks
    first_level = base_options.start_level
    task_first_levels = range(first_level, first_level + stride * nb_tasks, stride)

    # One configuration per task: the shared options, with the level window of that task.
    per_task_train: List[ProcGenConfig] = [
        replace(base_options, start_level=task_first_level, num_levels=num_levels_per_task)
        for task_first_level in task_first_levels
    ]
    # NOTE: The validation and test environments currently use the same configurations as the
    # training environments. This could easily be changed: for example, the test environments
    # could have a background while the train/valid environments don't, which could be of
    # interest for Out-of-Distribution RL research.
    per_task_valid: List[ProcGenConfig] = list(per_task_train)
    per_task_test: List[ProcGenConfig] = list(per_task_train)

    # The setting receives the (bound) `make_env` methods rather than already-created envs: this
    # saves some memory, and the envs can be closed after each task since they can always be
    # re-created.
    return setting_type(
        dataset=None,
        train_envs=[task_config.make_env for task_config in per_task_train],
        val_envs=[task_config.make_env for task_config in per_task_valid],
        test_envs=[task_config.make_env for task_config in per_task_test],
    )


from sequoia.common.config import Config
from sequoia.methods.random_baseline import RandomBaselineMethod


def main_simple():
    """Simple example: applies a random baseline to a 5-task procgen ("coinrun-v0") setting."""
    # `make_procgen_setting` creates a Task-Incremental RL setting by default.
    procgen_setting = make_procgen_setting(env_name="coinrun-v0", nb_tasks=5)
    random_method = RandomBaselineMethod()
    # NOTE: The `render` option isn't yet working (see above)
    run_config = Config(debug=True, render=False)
    outcome = procgen_setting.apply(random_method, config=run_config)
    print(outcome.summary())


def main_using_other_setting():
    """Example where the type of setting to create is chosen from its assumptions."""

    class Key(NamedTuple):
        stationary_context: bool
        task_labels_at_test_time: bool

    # This table is only here to give an idea of the differences between these settings.
    settings_by_assumptions: Dict[Key, Type[IncrementalRLSetting]] = {
        Key(stationary_context=False, task_labels_at_test_time=False): IncrementalRLSetting,
        Key(stationary_context=False, task_labels_at_test_time=True): TaskIncrementalRLSetting,
        Key(stationary_context=True, task_labels_at_test_time=False): TraditionalRLSetting,
        Key(stationary_context=True, task_labels_at_test_time=True): MultiTaskRLSetting,
    }

    # Any of the settings above could be chosen. Here, as an example: task labels are available
    # at test time, and the context is not stationary.
    wanted = Key(stationary_context=False, task_labels_at_test_time=True)
    chosen_setting_type = settings_by_assumptions[wanted]

    # The Method is created before the Setting.
    random_method = RandomBaselineMethod()

    procgen_setting = make_procgen_setting(
        env_name="coinrun-v0", nb_tasks=5, setting_type=chosen_setting_type
    )
    run_config = Config(debug=True, render=False)
    outcome = procgen_setting.apply(random_method, config=run_config)
    print(outcome.summary())


if __name__ == "__main__":
    # Only the simple example is run; `main_using_other_setting` is not called from here.
    main_simple()


================================================
FILE: examples/basic/__init__.py
================================================


================================================
FILE: examples/basic/base_method_demo.py
================================================
""" Example showing how the BaseMethod can be applied to get results in both
RL and SL settings.
"""

from simple_parsing import ArgumentParser

from sequoia.common import Config
from sequoia.methods import BaseMethod
from sequoia.settings import Setting, TaskIncrementalRLSetting, TaskIncrementalSLSetting


def baseline_demo_simple():
    """Creates the BaseMethod and a Setting manually, applies one to the other and prints a summary.

    Returns the results of the experiment.
    """
    experiment_config = Config()
    base_method = BaseMethod(config=experiment_config, max_epochs=1)

    ## *Any* Setting from the tree could be created here. For example:
    # Supervised Learning Setting:
    chosen_setting = TaskIncrementalSLSetting(dataset="cifar10", nb_tasks=2)
    ## Reinforcement Learning Setting:
    # chosen_setting = TaskIncrementalRLSetting(
    #     dataset="cartpole",
    #     train_max_steps=4000,
    #     nb_tasks=2,
    # )
    outcome = chosen_setting.apply(base_method, config=experiment_config)
    print(outcome.summary())
    return outcome


def baseline_demo_command_line():
    """Creates the Setting, the Config and the BaseMethod from the command-line, then runs them.

    Prints the summary of the results and returns the results.
    """
    # NOTE: The module docstring has to be passed as `description`: the first positional argument
    # of an argparse-style `ArgumentParser` is `prog` (the name of the program), so passing
    # `__doc__` positionally would show the whole docstring as the program name in the usage
    # message.
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # Supervised Learning Setting:
    parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    # Reinforcement Learning Setting:
    # parser.add_arguments(TaskIncrementalRLSetting, dest="setting")

    parser.add_arguments(Config, dest="config")
    BaseMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    # The Setting and the Config are read directly from their destination on the namespace, while
    # the Method is created through its `from_argparse_args` classmethod.
    setting: Setting = args.setting
    config: Config = args.config
    method: BaseMethod = BaseMethod.from_argparse_args(args, dest="method")

    results = setting.apply(method, config=config)
    print(results.summary())
    return results


if __name__ == "__main__":
    # Only Option 1 is enabled; uncomment the call under Option 2 to use the command-line instead.
    ### Option 1: Create the BaseMethod and Settings manually.
    baseline_demo_simple()

    ### Option 2: Create the BaseMethod and Settings from the command-line.
    # baseline_demo_command_line()


================================================
FILE: examples/basic/pl_example.py
================================================
"""A simple example for creating a Method using PyTorch-Lightning.

Run this as:

```console
$> python examples/basic/pl_example.py
```
"""
from dataclasses import asdict, dataclass
from typing import Optional, Tuple

import torch
from gym import spaces
from pytorch_lightning import LightningModule, Trainer
from torch import Tensor, nn
from torch.optim import Adam

from sequoia.common.config import Config
from sequoia.common.spaces import Image
from sequoia.methods import Method
from sequoia.settings.assumptions.task_type import ClassificationActions
from sequoia.settings.sl.continual import (
    Actions,
    ContinualSLSetting,
    Observations,
    ObservationSpace,
    Rewards,
)


class Model(LightningModule):
    """Example Pytorch Lightning model used for continual image classification.

    Used by the `ExampleMethod` below.

    The model is a small convolutional feature extractor (`self.features`) followed by a
    fully-connected classifier (`self.fc`), trained with a cross-entropy loss and the Adam
    optimizer.
    """

    @dataclass
    class HParams:
        """Hyper-parameters of our model.

        NOTE: dataclasses are totally optional. This is just much nicer than dicts or
        ugly namespaces.
        """

        # Learning rate.
        learning_rate: float = 1e-3
        # Maximum number of training epochs per task.
        max_epochs_per_task: int = 1

    def __init__(
        self,
        input_space: ObservationSpace,
        output_space: spaces.Discrete,
        hparams: HParams = None,
    ):
        """Creates the layers of the model.

        Parameters
        ----------
        input_space : ObservationSpace
            Space of the observations. Only its `x` entry (the image space) is used here.
        output_space : spaces.Discrete
            Space of the actions (predictions). Its `n` attribute gives the number of classes.
        hparams : HParams, optional
            Hyper-parameters of the model. When not given (or falsy), the default values of
            `HParams` are used.
        """
        super().__init__()
        hparams = hparams or self.HParams()
        # NOTE: `input_space` is a subclass of `gym.spaces.Dict`. It contains (at least)
        # the `x` key, but can also contain other things, for example the task labels.
        # Doing things this way makes sure that this Model can also be applied to any
        # more specific Setting in the future (any setting with more information given)!
        image_space: Image = input_space.x
        # NOTE: `Image` is just a subclass of `gym.spaces.Box` with a few extra properties

        self.input_dims = image_space.shape
        # NOTE: Can't set the `hparams` attribute in PL, so use hp instead:
        self.hp = hparams
        # The hyper-parameters are saved as a plain dict, under the "hparams" key.
        self.save_hyperparameters({"hparams": asdict(hparams)})
        in_channels: int = image_space.channels
        num_classes: int = output_space.n

        # Imitates the SimpleConvNet from  sequoia.common.models.simple_convnet
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6, kernel_size=5, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(6),
            nn.ReLU(inplace=True),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.AdaptiveAvgPool2d(output_size=(8, 8)),  # [16, 8, 8]
            # [32, 6, 6]
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # [32, 4, 4]
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(32),
            nn.Flatten(),
        )
        # Quick tip: In this case we have a fixed hidden size (thanks to the Adaptive
        # pooling layer above), but you could also use the cool new `nn.LazyLinear` when
        # you don't know the hidden size in advance!
        # NOTE: `self.features` already ends with a `nn.Flatten()`. The 512 input features of
        # the first linear layer match the flattened [32, 4, 4] output noted above
        # (32 * 4 * 4 = 512).
        self.fc = nn.Sequential(
            nn.Flatten(),
            # nn.LazyLinear(out_features=120),
            nn.Linear(512, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )
        self.loss = nn.CrossEntropyLoss()
        # Annotation only: no value is assigned to `self.trainer` here.
        self.trainer: Trainer

    def forward(self, observations: ContinualSLSetting.Observations) -> Tensor:
        """Returns the logits for the given observation.

        Parameters
        ----------
        observations : ContinualSLSetting.Observations
            dataclass with (at least) the following attributes:
            - "x" (Tensor): the samples (images)
            - "task_labels" (Optional[Tensor]): Task labels, when applicable.

        Returns
        -------
        Tensor
            Classification logits for each class.
        """
        x: Tensor = observations.x
        # Task labels for each sample. We don't use them in this example.
        t: Optional[Tensor] = observations.task_labels
        h_x = self.features(x)
        logits = self.fc(h_x)
        return logits

    def training_step(
        self, batch: Tuple[Observations, Optional[Rewards]], batch_idx: int
    ) -> Tensor:
        """Returns the loss for a training batch (delegates to `shared_step` with stage "train")."""
        return self.shared_step(batch=batch, batch_idx=batch_idx, stage="train")

    def validation_step(
        self, batch: Tuple[Observations, Optional[Rewards]], batch_idx: int
    ) -> Tensor:
        """Returns the loss for a validation batch (delegates to `shared_step` with stage "val")."""
        return self.shared_step(batch=batch, batch_idx=batch_idx, stage="val")

    def test_step(self, batch: Tuple[Observations, Optional[Rewards]], batch_idx: int) -> Tensor:
        """Returns the loss for a test batch (delegates to `shared_step` with stage "test")."""
        return self.shared_step(batch=batch, batch_idx=batch_idx, stage="test")

    def shared_step(
        self,
        batch: Tuple[Observations, Optional[Rewards]],
        batch_idx: int,
        stage: str,
    ) -> Tensor:
        """Computes the loss for a batch, and logs the accuracy under "{stage}/accuracy".

        Parameters
        ----------
        batch : Tuple[Observations, Optional[Rewards]]
            The observations, and the rewards (labels) when they are available. When the rewards
            are None, they are obtained by sending the predictions to the environment.
        batch_idx : int
            Index of the batch. Not used in the body of this method.
        stage : str
            One of "train", "val" or "test" (the values passed by the `*_step` methods). Used as
            the prefix of the logged metric and to request the dataloader from the Trainer.

        Returns
        -------
        Tensor
            The cross-entropy loss between the logits and the labels.
        """
        observations, rewards = batch

        logits = self(observations)
        y_pred = logits.argmax(-1)
        # NOTE(review): this `actions` object is not used afterwards: it is only read in the
        # branch below, after having been overwritten with `y_pred`.
        actions = ClassificationActions(y_pred=y_pred, logits=logits)

        if rewards is None:
            environment: ContinualSLSetting.Environment
            # The rewards (image labels) might not be given at the same time as the
            # observations (images), for example during testing, or if we're being
            # evaluated based on our online performance during training!
            #
            # When that is the case, we need to send the "action" (predictions) to the
            # environment using `send()` to get the rewards.
            actions = y_pred
            # Get the current environment / dataloader from the Trainer.
            environment = self.trainer.request_dataloader(self, stage)
            rewards = environment.send(actions)
        y: Tensor = rewards.y

        # Fraction of the predictions of this batch that are equal to the labels.
        accuracy = (y_pred == y).int().sum() / len(y)
        self.log(f"{stage}/accuracy", accuracy, prog_bar=True)

        loss = self.loss(logits, y)
        return loss

    def configure_optimizers(self):
        """Returns an Adam optimizer over all the parameters, using the `learning_rate` of `self.hp`."""
        return Adam(self.parameters(), lr=self.hp.learning_rate)


class ExampleMethod(Method, target_setting=ContinualSLSetting):
    """Example of a Method that solves Continual SL Settings using PyTorch-Lightning.

    By targeting `ContinualSLSetting`, this Method declares that it is applicable to any
    `Setting` that inherits from it.

    NOTE: In Sequoia, Settings are subclasses of `LightningDataModule`. They create the
    training/validation/testing `Environment`s that the methods interact with. The `apply`
    method of a Setting acts as its "main loop": it decides when and on what data the Method
    gets trained, and how it gets evaluated, following the usual methodology for that setting
    in the literature.

    Importantly, Settings do NOT describe **how** the Method is trained: that is entirely up
    to the Method!
    """

    def __init__(self, hparams: Model.HParams = None):
        super().__init__()
        self.hparams = hparams or Model.HParams()
        self.current_task: Optional[int] = None
        # NOTE: The model is created in `configure`, and the trainer in `fit` (both below):
        self.model: Model
        self.trainer: Trainer

    def configure(self, setting: ContinualSLSetting):
        """Hook called by the Setting to let the Method set itself up before training.

        The model is created here, since the observation space (the types and shapes of the
        data) and the action space can be read from the Setting.

        Parameters
        ----------
        setting : ContinualSLSetting
            The research setting that this `Method` will be applied to.
        """
        boundaries_known = setting.known_task_boundaries_at_train_time
        if not boundaries_known:
            # Without access to the task boundaries, there is a single training environment,
            # which transitions between all the tasks and then closes itself. The number of
            # epochs per task is therefore limited to 1 in that case.
            self.hparams.max_epochs_per_task = 1

        self.model = Model(
            input_space=setting.observation_space,
            output_space=setting.action_space,
            hparams=self.hparams,
        )

    def fit(
        self,
        train_env: ContinualSLSetting.Environment,
        valid_env: ContinualSLSetting.Environment,
    ):
        """Hook called by the Setting to let the Method train itself.

        The environments inherit from both `DataLoader` and `gym.Env`: they yield
        `Observations` (with an `x` Tensor field, for instance) and give back `Rewards` when
        they receive `Actions`. RL and SL share this interface, which makes it easy to create
        methods that adapt to both domains.

        Parameters
        ----------
        train_env : ContinualSLSetting.Environment
            The Training environment. In a `ContinualSLSetting`, it smoothly transitions
            between the different tasks.
            NOTE: Whatever the exact type of `Setting` this Method is applied to, this
            environment is always a subclass of `ContinualSLSetting.Environment`, and the
            `Observations`, `Actions` and `Rewards` it produces follow the same hierarchy.
            This is what makes it possible to write a Method that also works in other
            settings, which add extra information to the observations (e.g. task labels)!

        valid_env : ContinualSLSetting.Environment
            The Validation environment.
        """
        # NOTE: The Trainer currently has to be 're-created' on each call to `fit`.
        n_gpus = torch.cuda.device_count()
        epochs = self.hparams.max_epochs_per_task
        self.trainer = Trainer(gpus=n_gpus, max_epochs=epochs)
        self.trainer.fit(self.model, train_dataloader=train_env, val_dataloaders=valid_env)

    def test(self, test_env: ContinualSLSetting.Environment):
        """Hook that lets the Method handle the test loop by itself.

        The `test_env` only gives back the rewards (y) after an action (y_pred) was sent to it
        through its `send` method.

        The test environment tracks the metrics of interest for its `Setting` (accuracy in this
        case), and reports them to the `Setting` once it is exhausted.

        NOTE: The test environment closes itself when it is done, which marks the end of the
        test period. From then on, `test_env.is_closed()` returns `True`.

        Currently always raises `NotImplementedError` (see the BUG note in the body).
        """
        # BUG: There is currently a bug with the test loop with Trainer: on_task_switch
        # doesn't get called properly.
        raise NotImplementedError
        # Intended implementation, once the bug above is fixed. (`ckpt_path=None` uses the
        # current weights, rather than the "best" ones):
        # self.trainer.test(self.model, ckpt_path=None, test_dataloaders=test_env)

    def get_actions(self, observations: Observations, action_space: spaces.MultiDiscrete):
        """Hook called by the Setting to ask for the predictions on some observations.

        This currently has to be implemented, even though `test` is used instead when it is
        implemented. Sorry if this isn't super clear.
        """
        self.model.eval()
        with torch.no_grad():
            device_observations = observations.to(self.model.device)
            predicted_classes = self.model(device_observations).argmax(-1)
        return Actions(y_pred=predicted_classes)

    def on_task_switch(self, task_id: Optional[int]) -> None:
        """Hook that the Setting can call when a task boundary is reached.

        It gets called when `setting.known_task_boundaries_at_[train/test]_time` is True
        (the train or test variant, depending on whether we are training or testing).

        When `setting.task_labels_at_[train/test]_time` is True, `task_id` is the identifier
        (index) of the next task. Otherwise, `task_id` is None.
        """
        if task_id == self.current_task:
            # Not an actual change of task: nothing to do.
            return
        phase = "training" if self.training else "testing"
        print(f"Switching tasks during {phase}: {self.current_task} -> {task_id}")
        self.current_task = task_id


def main():
    """Runs the example: applies the method on a Continual Supervised Learning Setting."""
    # Any of the SL settings could be used here, since the example method targets the most
    # general Continual SL Setting of Sequoia, `ContinualSLSetting`. For example:
    # from sequoia.settings.sl import ClassIncrementalSetting

    # The Setting.
    # NOTE: The model uses an adaptive pooling layer, so it should work on any dataset!
    sl_setting = ContinualSLSetting(dataset="mnist", monitor_training_performance=True)

    # The Method.
    example_method = ExampleMethod()

    # The config of the experiment (only used to set a few options for this example).
    experiment_config = Config(debug=True, log_dir="results/pl_example")

    # Run the experiment: the method gets trained and tested according to the chosen setting,
    # which gives back a Results object.
    outcome = sl_setting.apply(example_method, config=experiment_config)

    # Print the results, and show some plots!
    print(outcome.summary())
    plots = outcome.make_plots()
    for plot_name, plot in plots.items():
        print("Figure:", plot_name)
        plot.show()
        # plot.waitforbuttonpress(10)


if __name__ == "__main__":
    # Run the example when this file is executed as a script.
    main()


================================================
FILE: examples/basic/pl_example_packnet.py
================================================
from dataclasses import dataclass
from typing import Optional

import torch
from simple_parsing import mutable_field

from examples.basic.pl_example import ExampleMethod, Model
from sequoia.common import Config
from sequoia.methods import BaseModel
from sequoia.methods.packnet_method import PackNet
from sequoia.methods.trainer import Trainer, TrainerConfig
from sequoia.settings.sl import ContinualSLSetting, TaskIncrementalSLSetting


class ExamplePackNetMethod(ExampleMethod, target_setting=TaskIncrementalSLSetting):
    """Variant of the `ExampleMethod` that adds a `PackNet` callback to the Trainer.

    Targets the `TaskIncrementalSLSetting`.
    """

    def __init__(self, hparams: Model.HParams = None, packnet_hparams: PackNet.HParams = None):
        super().__init__(hparams=hparams)
        self.packnet_hparams = packnet_hparams or PackNet.HParams()
        # TODO: Modify `hparams.max_epochs_per_task` to at least be enough so that
        # PackNet will work.
        # For now, it is raised to the sum of PackNet's training and fine-tuning epochs when it
        # is below that sum.
        required_epochs = (
            self.packnet_hparams.train_epochs + self.packnet_hparams.fine_tune_epochs
        )
        if required_epochs > self.hparams.max_epochs_per_task:
            self.hparams.max_epochs_per_task = required_epochs
        # Created in `configure`.
        self.p_net: PackNet

    def configure(self, setting: TaskIncrementalSLSetting):
        """Creates the model (through the parent class), then the PackNet callback."""
        super().configure(setting)
        # TODO: Why does PackNet need access to the number of tasks again?
        packnet = PackNet(n_tasks=setting.nb_tasks, hparams=self.packnet_hparams)
        self.p_net = packnet
        # TODO: This could be set as default values in the PackNet constructor.
        packnet.current_task = -1
        packnet.config_instructions()

    def fit(
        self,
        train_env: TaskIncrementalSLSetting.Environment,
        valid_env: TaskIncrementalSLSetting.Environment,
    ):
        """Trains the model with a new Trainer that uses the PackNet callback."""
        # NOTE: PackNet is not compatible with EarlyStopping, thus we set max_epochs==min_epochs
        trainer_options = dict(
            gpus=torch.cuda.device_count(),
            min_epochs=self.p_net.total_epochs(),
            max_epochs=self.p_net.total_epochs(),
            callbacks=[self.p_net],
        )
        self.trainer = Trainer(**trainer_options)
        self.trainer.fit(self.model, train_dataloader=train_env, val_dataloaders=valid_env)

    def on_task_switch(self, task_id: Optional[int]):
        """Hook called when switching between tasks.

        Args:
            task_id (int, optional): the id of the new task. When None, we only know
            that there is a task boundary, without knowing which task we're switching
            to.
        """
        super().on_task_switch(task_id=task_id)
        # The final state and the evaluation mask of a task are only used when the task is known
        # and `p_net.masks` has an entry at that index.
        mask_available = task_id is not None and task_id < len(self.p_net.masks)
        if mask_available:
            self.p_net.load_final_state(model=self.model)
            self.p_net.apply_eval_mask(task_idx=task_id, model=self.model)
        self.p_net.current_task = task_id


def main():
    """Runs the example: applies the PackNet example method on a Task-Incremental
    Supervised Learning Setting, then prints and plots the results.
    """
    # The Setting: MNIST, split into 5 tasks, with training performance monitored.
    experiment_setting = TaskIncrementalSLSetting(
        dataset="mnist",
        nb_tasks=5,
        monitor_training_performance=True,
    )

    # The Method, created with its default hyper-parameters.
    packnet_method = ExamplePackNetMethod()

    # Experiment config (only used here to set a few options for this example).
    run_config = Config(debug=False, log_dir="results/pl_example_packnet")

    # Launch the experiment: `apply` trains and tests the method according to the
    # chosen setting, and returns a Results object.
    results = experiment_setting.apply(packnet_method, config=run_config)

    # Text summary first, then one figure per entry of `make_plots()`.
    print(results.summary())
    for plot_name, plot in results.make_plots().items():
        print("Figure:", plot_name)
        plot.show()
        # plot.waitforbuttonpress(10)

# Only run the example when this file is executed as a script (not when imported).
if __name__ == "__main__":
    main()


================================================
FILE: examples/basic/pl_example_test.py
================================================
""" Unit-tests for the PyTorch-Lightning Example.

Can be run like so:
```console
$ pytest examples/basic/pl_example_test.py
```
"""
from typing import Type

import pytest

from examples.basic.pl_example import ExampleMethod, Model
from sequoia.common.config import Config
from sequoia.common.metrics import ClassificationMetrics
from sequoia.methods import Method
from sequoia.methods.method_test import MethodTests, config, session_config  # type: ignore
from sequoia.settings import Results
from sequoia.settings.sl import ContinualSLSetting, IncrementalSLSetting


class TestPLExample(MethodTests):
    """Test-suite for the PyTorch-Lightning example Method.

    The `MethodTests` base class generates a `test_debug` test for us; this class
    supplies the `method` fixture and the `validate_results` hook it relies on.
    """

    Method: Type[Method] = ExampleMethod

    @pytest.fixture()
    def method(self, config: Config):
        """Required fixture: builds a Method that trains for a single epoch per
        task, so that the tests stay quick.
        """
        quick_hparams = Model.HParams(max_epochs_per_task=1)
        return ExampleMethod(hparams=quick_hparams)

    def validate_results(
        self, setting: ContinualSLSetting, method: ExampleMethod, results: Results
    ):
        """Checks that the results make sense for the given setting and method.

        This gets called by `test_debug`.
        """
        # NOTE: These checks are deliberately weak: they only verify that the average
        # final test accuracy and the average online accuracy are both non-zero.
        # A better version would branch further on the type of Setting, since each
        # setting can produce different types of results.
        print(results.summary())

        average_metrics: ClassificationMetrics
        online_metrics: ClassificationMetrics

        assert setting.monitor_training_performance

        # Placeholder threshold: any strictly positive accuracy is accepted.
        accuracy_floor = 0.0

        if not isinstance(setting, IncrementalSLSetting):
            # No clear 'tasks' to speak of in this case: the results are just
            # aggregated metrics for each test batch.
            assert isinstance(results, ContinualSLSetting.Results)
            average_metrics = results.average_metrics
            online_metrics = results.online_performance_metrics

            assert average_metrics.accuracy > accuracy_floor
            assert online_metrics.accuracy > accuracy_floor
            return

        # Incremental settings: the results include the entire nb_tasks x nb_tasks
        # transfer matrix.
        assert isinstance(results, IncrementalSLSetting.Results)
        average_metrics = results.average_final_performance
        online_metrics = results.average_online_performance

        if setting.stationary_context:
            # Example: Should expect better performance if the data is i.i.d!
            assert average_metrics.accuracy > accuracy_floor
        else:
            assert average_metrics.accuracy > accuracy_floor

        if setting.monitor_training_performance:
            assert online_metrics.accuracy > accuracy_floor


================================================
FILE: examples/basic/quick_demo.ipynb
================================================
{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python38364bitpy38conda80a8f432976e4e99926307fddceb6e0b",
   "display_name": "Python 3.8.3 64-bit ('py38': conda)",
   "language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# Quick Demo (Notebook version)\n",
    "\n",
    "(I hate notebooks.)\n",
    "\n",
    "In this demo, we will create a simple method and apply it to various Continual Learning settings.\n",
    "\n",
    "For the purposes of this demo, we will restrict ourselves to classification problems on the mnist and fashion-mnist datasets."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports:\n",
    "import sys\n",
    "from dataclasses import dataclass\n",
    "from typing import Dict, Optional, Tuple, Type\n",
    "\n",
    "import gym\n",
    "import torch\n",
    "from gym import spaces\n",
    "from torch import Tensor, nn\n",
    "from simple_parsing import ArgumentParser\n",
    "\n",
    "sys.path.extend([\".\", \"..\"])\n",
    "from sequoia.settings import Method, Setting\n",
    "from sequoia.settings.sl.class_incremental import ClassIncrementalSetting, DomainIncrementalSetting\n",
    "from sequoia.settings.sl.class_incremental.objects import (\n",
    "    Actions,\n",
    "    Environment,\n",
    "    Observations,\n",
    "    PassiveEnvironment,\n",
    "    Results,\n",
    "    Rewards,\n",
    ")"
   ]
  },
  {
   "source": [
    "# Basic Model:"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class MyModel(nn.Module):\n",
    "    \"\"\" Simple classification model without any CL-related mechanism.\n",
    "\n",
    "    To keep things simple, this demo model is designed for supervised\n",
    "    (classification) settings where observations have shape [3, 28, 28] (ie the\n",
    "    MNIST variants: Mnist, FashionMnist, RotatedMnist, EMnist, etc.)\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 observation_space: gym.Space,\n",
    "                 action_space: gym.Space,\n",
    "                 reward_space: gym.Space):\n",
    "        super().__init__()\n",
    "        image_shape = observation_space["x"].shape\n",
    "        assert image_shape == (3, 28, 28)\n",
    "        assert isinstance(action_space, spaces.Discrete)\n",
    "        assert action_space == reward_space\n",
    "        n_classes = action_space.n\n",
    "        image_channels = image_shape[0]\n",
    "\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channels, 6, 5),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2),\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2),\n",
    "        )\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(256, 120),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(120, 84),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(84, n_classes),\n",
    "        )\n",
    "        self.loss = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, observations: Observations) -> Tensor:\n",
    "        # NOTE: here we don't make use of the task labels.\n",
    "        x = observations.x\n",
    "        task_labels = observations.task_labels\n",
    "        features = self.encoder(x)\n",
    "        logits = self.classifier(features)\n",
    "        return logits\n",
    "\n",
    "    def shared_step(\n",
    "        self, batch: Tuple[Observations, Optional[Rewards]], environment: Environment\n",
    "    ) -> Tuple[Tensor, Dict]:\n",
    "        \"\"\"Shared step used for both training and validation.\n",
    "                \n",
    "        Parameters\n",
    "        ----------\n",
    "        batch : Tuple[Observations, Optional[Rewards]]\n",
    "            Batch containing Observations, and optional Rewards. When the Rewards are\n",
    "            None, it means that we'll need to provide the Environment with actions\n",
    "            before we can get the Rewards (e.g. image labels) back.\n",
    "            \n",
    "            This happens for example when being applied in a Setting which cares about\n",
    "            sample efficiency or training performance, for example.\n",
    "            \n",
    "        environment : Environment\n",
    "            The environment we're currently interacting with. Used to provide the\n",
    "            rewards when they aren't already part of the batch (as mentioned above).\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        Tuple[Tensor, Dict]\n",
    "            The Loss tensor, and a dict of metrics to be logged.\n",
    "        \"\"\"\n",
    "        # Since we're training on a Passive environment, we will get both observations\n",
    "        # and rewards, unless we're being evaluated based on our training performance,\n",
    "        # in which case we will need to send actions to the environments before we can\n",
    "        # get the corresponding rewards (image labels).\n",
    "        observations: Observations = batch[0]\n",
    "        rewards: Optional[Rewards] = batch[1]\n",
    "        # Get the predictions:\n",
    "        logits = self(observations)\n",
    "        y_pred = logits.argmax(-1)\n",
    "\n",
    "        if rewards is None:\n",
    "            # If the rewards in the batch is None, it means we're expected to give\n",
    "            # actions before we can get rewards back from the environment.\n",
    "            rewards = environment.send(Actions(y_pred))\n",
    "\n",
    "        assert rewards is not None\n",
    "        image_labels = rewards.y\n",
    "\n",
    "        loss = self.loss(logits, image_labels)\n",
    "\n",
    "        accuracy = (y_pred == image_labels).sum().float() / len(image_labels)\n",
    "        metrics_dict = {\"accuracy\": accuracy.item()}\n",
    "        return loss, metrics_dict\n"
   ]
  },
  {
   "source": [
    "## Creating our Method\n",
    "\n",
    "Here, by subclassing `Method` and passing in a `target_setting`, we indicate that we are creating a new method, and that it will work on any Setting that is an instance of `ClassIncrementalSetting` or one of its subclasses."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class DemoMethod(Method, target_setting=ClassIncrementalSetting):\n",
    "    \"\"\" Minimal example of a Method targetting the Class-Incremental CL setting.\n",
    "    \n",
    "    For a quick intro to dataclasses, see examples/dataclasses_example.py    \n",
    "    \"\"\"\n",
    "\n",
    "    @dataclass\n",
    "    class HParams:\n",
    "        \"\"\" Hyper-parameters of the demo model. \"\"\"\n",
    "        # Learning rate of the optimizer.\n",
    "        learning_rate: float = 0.001\n",
    "    \n",
    "    def __init__(self, hparams: HParams):\n",
    "        self.hparams: DemoMethod.HParams = hparams\n",
    "        self.max_epochs: int = 1\n",
    "        self.early_stop_patience: int = 2\n",
    "\n",
    "        # We will create those when `configure` will be called, before training.\n",
    "        self.model: MyModel\n",
    "        self.optimizer: torch.optim.Optimizer\n",
    "\n",
    "    def configure(self, setting: ClassIncrementalSetting):\n",
    "        \"\"\" Called before the method is applied on a setting (before training). \n",
    "\n",
    "        You can use this to instantiate your model, for instance, since this is\n",
    "        where you get access to the observation & action spaces.\n",
    "        \"\"\"\n",
    "        self.model = MyModel(\n",
    "            observation_space=setting.observation_space,\n",
    "            action_space=setting.action_space,\n",
    "            reward_space=setting.reward_space,\n",
    "        )\n",
    "        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.learning_rate)\n",
    "\n",
    "    def fit(self, train_env: PassiveEnvironment, valid_env: PassiveEnvironment):\n",
    "        # configure() will have been called by the setting before we get here.\n",
    "        import tqdm\n",
    "        from numpy import inf\n",
    "        best_val_loss = inf\n",
    "        best_epoch = 0\n",
    "        for epoch in range(self.max_epochs):\n",
    "            self.model.train()\n",
    "            # Training loop:\n",
    "            with tqdm.tqdm(train_env) as train_pbar:\n",
    "                train_pbar.set_description(f\"Training Epoch {epoch}\")\n",
    "                for i, batch in enumerate(train_pbar):\n",
    "                    loss, metrics_dict = self.model.shared_step(batch, environment=train_env)\n",
    "                    self.optimizer.zero_grad()\n",
    "                    loss.backward()\n",
    "                    self.optimizer.step()\n",
    "                    train_pbar.set_postfix(**metrics_dict)\n",
    "\n",
    "            # Validation loop:\n",
    "            self.model.eval()\n",
    "            torch.set_grad_enabled(False)\n",
    "            with tqdm.tqdm(valid_env) as val_pbar:\n",
    "                val_pbar.set_description(f\"Validation Epoch {epoch}\")\n",
    "                epoch_val_loss = 0.\n",
    "\n",
    "                for i, batch in enumerate(val_pbar):\n",
    "                    batch_val_loss, metrics_dict = self.model.shared_step(batch, environment=valid_env)\n",
    "                    epoch_val_loss += batch_val_loss\n",
    "                    val_pbar.set_postfix(**metrics_dict, val_loss=epoch_val_loss)\n",
    "            torch.set_grad_enabled(True)\n",
    "\n",
    "            if epoch_val_loss < best_val_loss:\n",
    "                best_val_loss = valid_env\n",
    "                best_epoch = epoch\n",
    "            if epoch - best_epoch > self.early_stop_patience:\n",
    "                print(f\"Early stopping at epoch {i}.\")\n",
    "                break\n",
    "\n",
    "    def get_actions(self, observations: Observations, action_space: gym.Space) -> Actions:\n",
    "        \"\"\" Get a batch of predictions (aka actions) for these observations. \"\"\" \n",
    "        with torch.no_grad():\n",
    "            logits = self.model(observations)\n",
    "        # Get the predicted classes\n",
    "        y_pred = logits.argmax(dim=-1)\n",
    "        return self.target_setting.Actions(y_pred)\n",
    "    \n",
    "    @classmethod\n",
    "    def add_argparse_args(cls, parser: ArgumentParser, dest: str = \"\"):\n",
    "        \"\"\"Adds command-line arguments for this Method to an argument parser.\"\"\"\n",
    "        parser.add_arguments(cls.HParams, \"hparams\")\n",
    "\n",
    "    @classmethod\n",
    "    def from_argparse_args(cls, args, dest: str = \"\"):\n",
    "        \"\"\"Creates an instance of this Method from the parsed arguments.\"\"\"\n",
    "        hparams: cls.HParams = args.hparams\n",
    "        return cls(hparams=hparams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-02-25:17:29:01,958 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:184] Starting training on task 0.\n",
      "2021-02-25:17:29:01,959 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:148] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:02,13 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/passive/cl/class_incremental_setting.py:433] Number of train tasks: 5.\n",
      "2021-02-25:17:29:02,14 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/passive/cl/class_incremental_setting.py:434] Number of test tasks: 5.\n",
      "Training Epoch 0: 100%|██████████| 300/300 [00:04<00:00, 64.17it/s, accuracy=1]\n",
      "Validation Epoch 0: 100%|██████████| 75/75 [00:00<00:00, 155.53it/s, accuracy=1, val_loss=tensor(3.1905)]\n",
      "2021-02-25:17:29:07,205 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:212] Finished Training on task 0.\n",
      "2021-02-25:17:29:07,246 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/passive/cl/class_incremental_setting.py:433] Number of train tasks: 5.\n",
      "2021-02-25:17:29:07,246 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/passive/cl/class_incremental_setting.py:434] Number of test tasks: 5.\n",
      "2021-02-25:17:29:07,274 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:347] Will query the method for actions at each step, since it doesn't implement a `test` method.\n",
      "Test:   0%|          | 0/312 [00:00<?, ?it/s]2021-02-25:17:29:07,361 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:07,365 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:07,373 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:07,382 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:07,394 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "Test: 100%|██████████| 312/312 [00:01<00:00, 232.18it/s]\n",
      "2021-02-25:17:29:08,713 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:217] Resulting objective of Test Loop: 0.626102\n",
      "2021-02-25:17:29:08,713 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:184] Starting training on task 1.\n",
      "2021-02-25:17:29:08,714 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:148] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "Training Epoch 0: 100%|██████████| 300/300 [00:03<00:00, 79.71it/s, accuracy=0.969]\n",
      "Validation Epoch 0: 100%|██████████| 75/75 [00:00<00:00, 170.55it/s, accuracy=0.969, val_loss=tensor(5.7692)]\n",
      "2021-02-25:17:29:12,923 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:212] Finished Training on task 1.\n",
      "2021-02-25:17:29:12,926 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:347] Will query the method for actions at each step, since it doesn't implement a `test` method.\n",
      "Test:   0%|          | 0/312 [00:00<?, ?it/s]2021-02-25:17:29:13,14 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:13,19 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:13,27 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:13,36 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:13,46 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "Test: 100%|██████████| 312/312 [00:01<00:00, 248.27it/s]\n",
      "2021-02-25:17:29:14,276 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:217] Resulting objective of Test Loop: 0.568409\n",
      "2021-02-25:17:29:14,277 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:184] Starting training on task 2.\n",
      "2021-02-25:17:29:14,278 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:148] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "Training Epoch 0: 100%|██████████| 300/300 [00:03<00:00, 86.51it/s, accuracy=1]\n",
      "Validation Epoch 0: 100%|██████████| 75/75 [00:00<00:00, 152.03it/s, accuracy=1, val_loss=tensor(0.0980)]\n",
      "2021-02-25:17:29:18,245 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:212] Finished Training on task 2.\n",
      "2021-02-25:17:29:18,249 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:347] Will query the method for actions at each step, since it doesn't implement a `test` method.\n",
      "Test:   0%|          | 0/312 [00:00<?, ?it/s]2021-02-25:17:29:18,339 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:18,343 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:18,356 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:18,362 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "2021-02-25:17:29:18,371 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:305] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't know about it! \n",
      "Test: 100%|██████████| 312/312 [00:01<00:00, 243.46it/s]\n",
      "2021-02-25:17:29:19,632 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:217] Resulting objective of Test Loop: 0.757212\n",
      "2021-02-25:17:29:19,632 INFO     [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:184] Starting training on task 3.\n",
      "2021-02-25:17:29:19,633 WARNING  [/home/fabrice/repos/Sequoia/sequoia/settings/assumptions/incremental.py:148] On a task boundary, but since your method doesn't have an `on_task_switch` method, it won't
Download .txt
gitextract_c6gc35b2/

├── .dockerignore
├── .gitattributes
├── .gitignore
├── .gitmodules
├── .travis.yml
├── LICENSE
├── MANIFEST.in
├── README.md
├── dockers/
│   ├── .gitignore
│   ├── base/
│   │   ├── Dockerfile
│   │   └── build.sh
│   └── branch/
│       ├── Dockerfile
│       └── build.sh
├── docs/
│   └── diagrams/
│       └── src/
│           ├── gym.puml
│           ├── pytorch_lightning.puml
│           └── seq_diagram.puml
├── examples/
│   ├── README.md
│   ├── __init__.py
│   ├── advanced/
│   │   ├── RL_and_SL_demo.py
│   │   ├── continual_rl_demo.py
│   │   ├── ewc_in_rl.py
│   │   ├── hat_demo.py
│   │   ├── hparam_tuning.py
│   │   ├── pnn/
│   │   │   ├── __init__.py
│   │   │   ├── layers.py
│   │   │   ├── model_rl.py
│   │   │   ├── model_sl.py
│   │   │   └── pnn_method.py
│   │   └── procgen_example.py
│   ├── basic/
│   │   ├── __init__.py
│   │   ├── base_method_demo.py
│   │   ├── pl_example.py
│   │   ├── pl_example_packnet.py
│   │   ├── pl_example_test.py
│   │   ├── quick_demo.ipynb
│   │   ├── quick_demo.py
│   │   ├── quick_demo_ewc.py
│   │   ├── quick_demo_packnet.py
│   │   └── quick_demo_test.py
│   ├── clcomp21/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── a2c_example.py
│   │   ├── a2c_example_test.py
│   │   ├── classifier.py
│   │   ├── classifier_test.py
│   │   ├── conftest.py
│   │   ├── dummy_method.py
│   │   ├── dummy_method_test.py
│   │   ├── multihead_classifier.py
│   │   ├── multihead_classifier_test.py
│   │   ├── regularization_example.py
│   │   ├── regularization_example_test.py
│   │   ├── sb3_example.py
│   │   └── sb3_example_test.py
│   ├── demo_utils.py
│   └── prerequisites/
│       └── dataclasses_example.py
├── mypy.ini
├── pytest.ini
├── requirements.txt
├── scripts/
│   ├── eai/
│   │   ├── cancel_all_queuing.sh
│   │   ├── cancel_all_running.sh
│   │   ├── job.sh
│   │   ├── rl_sweep.sh
│   │   ├── shell_job.sh
│   │   └── sl_sweep.sh
│   └── slurm/
│       ├── launch_many_sweeps.sh
│       ├── run.sh
│       └── sweep.sh
├── sequoia/
│   ├── README.md
│   ├── __init__.py
│   ├── _version.py
│   ├── client/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── env.proto
│   │   ├── env_proxy.py
│   │   ├── env_proxy_test.py
│   │   ├── server.py
│   │   ├── setting_proxy.py
│   │   └── setting_proxy_test.py
│   ├── common/
│   │   ├── __init__.py
│   │   ├── batch.py
│   │   ├── batch_test.py
│   │   ├── callbacks/
│   │   │   ├── __init__.py
│   │   │   ├── knn_callback.py
│   │   │   └── vae_callback.py
│   │   ├── config/
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── wandb_config.py
│   │   ├── gym_wrappers/
│   │   │   ├── __init__.py
│   │   │   ├── action_limit.py
│   │   │   ├── action_limit_test.py
│   │   │   ├── add_done.py
│   │   │   ├── add_info.py
│   │   │   ├── convert_tensors.py
│   │   │   ├── convert_tensors_test.py
│   │   │   ├── env_dataset.py
│   │   │   ├── env_dataset_test.py
│   │   │   ├── episode_limit.py
│   │   │   ├── episode_limit_test.py
│   │   │   ├── measure_performance.py
│   │   │   ├── multi_task_environment.py
│   │   │   ├── multi_task_environment_test.py
│   │   │   ├── observation_limit.py
│   │   │   ├── observation_limit_test.py
│   │   │   ├── pixel_observation.py
│   │   │   ├── pixel_observation_test.py
│   │   │   ├── policy_env.py
│   │   │   ├── policy_env_test.py
│   │   │   ├── smooth_environment.py
│   │   │   ├── smooth_environment_test.py
│   │   │   ├── step_callback_wrapper.py
│   │   │   ├── step_callback_wrapper_test.py
│   │   │   ├── transform_wrappers.py
│   │   │   ├── transform_wrappers_test.py
│   │   │   ├── utils.py
│   │   │   └── utils_test.py
│   │   ├── hparams/
│   │   │   └── __init__.py
│   │   ├── layers.py
│   │   ├── loss.py
│   │   ├── loss_test.py
│   │   ├── metrics/
│   │   │   ├── __init__.py
│   │   │   ├── classification.py
│   │   │   ├── classification_test.py
│   │   │   ├── get_metrics.py
│   │   │   ├── metrics.py
│   │   │   ├── metrics_utils.py
│   │   │   ├── metrics_utils_test.py
│   │   │   ├── regression.py
│   │   │   └── rl_metrics.py
│   │   ├── replay.py
│   │   ├── spaces/
│   │   │   ├── __init__.py
│   │   │   ├── image.py
│   │   │   ├── named_tuple.py
│   │   │   ├── named_tuple_test.py
│   │   │   ├── space.py
│   │   │   ├── sparse.py
│   │   │   ├── sparse_test.py
│   │   │   ├── tensor_spaces.py
│   │   │   ├── tensor_spaces_test.py
│   │   │   ├── typed_dict.py
│   │   │   └── typed_dict_test.py
│   │   ├── task.py
│   │   └── transforms/
│   │       ├── __init__.py
│   │       ├── channels.py
│   │       ├── compose.py
│   │       ├── resize.py
│   │       ├── split_batch.py
│   │       ├── to_tensor.py
│   │       ├── transform.py
│   │       ├── transform_enum.py
│   │       ├── transforms_test.py
│   │       └── utils.py
│   ├── common.puml
│   ├── conftest.py
│   ├── experiments/
│   │   ├── __init__.py
│   │   ├── experiment.py
│   │   ├── experiment_test.py
│   │   ├── hpo_sweep.py
│   │   └── hpo_sweep_test.py
│   ├── main.py
│   ├── methods/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── aux_tasks/
│   │   │   ├── __init__.py
│   │   │   ├── auxiliary_task.py
│   │   │   ├── ewc.py
│   │   │   ├── reconstruction/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── ae.py
│   │   │   │   ├── decoder_for_dataset.py
│   │   │   │   ├── decoders.py
│   │   │   │   └── vae.py
│   │   │   └── transformation_based/
│   │   │       ├── __init__.py
│   │   │       ├── bases.py
│   │   │       └── rotation.py
│   │   ├── avalanche_methods/
│   │   │   ├── __init__.py
│   │   │   ├── agem.py
│   │   │   ├── agem_test.py
│   │   │   ├── ar1.py
│   │   │   ├── ar1_test.py
│   │   │   ├── base.py
│   │   │   ├── base_test.py
│   │   │   ├── conftest.py
│   │   │   ├── cwr_star.py
│   │   │   ├── cwr_star_test.py
│   │   │   ├── ewc.py
│   │   │   ├── ewc_test.py
│   │   │   ├── experience.py
│   │   │   ├── gdumb.py
│   │   │   ├── gdumb_test.py
│   │   │   ├── gem.py
│   │   │   ├── gem_test.py
│   │   │   ├── lwf.py
│   │   │   ├── lwf_test.py
│   │   │   ├── naive.py
│   │   │   ├── naive_test.py
│   │   │   ├── patched_models.py
│   │   │   ├── plugins.py
│   │   │   ├── replay.py
│   │   │   ├── replay_test.py
│   │   │   ├── synaptic_intelligence.py
│   │   │   └── synaptic_intelligence_test.py
│   │   ├── base_method.py
│   │   ├── base_method_test.py
│   │   ├── conftest.py
│   │   ├── d3rlpy_methods/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── base_test.py
│   │   ├── ewc_method.py
│   │   ├── ewc_method_test.py
│   │   ├── experience_replay.py
│   │   ├── experience_replay_test.py
│   │   ├── hat.py
│   │   ├── method_test.py
│   │   ├── models/
│   │   │   ├── __init__.py
│   │   │   ├── base_model/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_model.py
│   │   │   │   ├── model.py
│   │   │   │   ├── multihead_model.py
│   │   │   │   ├── multihead_model_test.py
│   │   │   │   ├── self_supervised_model.py
│   │   │   │   ├── self_supervised_model_test.py
│   │   │   │   └── semi_supervised_model.py
│   │   │   ├── baseline_model.puml
│   │   │   ├── fcnet.py
│   │   │   ├── forward_pass.py
│   │   │   ├── output_heads/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── classification_head.py
│   │   │   │   ├── output_head.py
│   │   │   │   ├── regression_head.py
│   │   │   │   └── rl/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── actor_critic_head.py
│   │   │   │       ├── episodic_a2c.py
│   │   │   │       ├── episodic_a2c_test.py
│   │   │   │       ├── policy_head.py
│   │   │   │       ├── policy_head_test.py
│   │   │   │       └── wasted_steps_calc.py
│   │   │   ├── output_heads.puml
│   │   │   └── simple_convnet.py
│   │   ├── models.puml
│   │   ├── packnet_method.py
│   │   ├── packnet_method_test.py
│   │   ├── pl_bolts_methods/
│   │   │   └── __init__.py
│   │   ├── pl_dqn.py
│   │   ├── pnn/
│   │   │   ├── __init__.py
│   │   │   ├── layers.py
│   │   │   ├── model_rl.py
│   │   │   ├── model_sl.py
│   │   │   └── pnn_method.py
│   │   ├── random_baseline.py
│   │   ├── random_baseline_test.py
│   │   ├── stable_baselines3_methods/
│   │   │   ├── __init__.py
│   │   │   ├── a2c.py
│   │   │   ├── a2c_test.py
│   │   │   ├── base.py
│   │   │   ├── base_test.py
│   │   │   ├── ddpg.py
│   │   │   ├── ddpg_test.py
│   │   │   ├── dqn.py
│   │   │   ├── dqn_test.py
│   │   │   ├── off_policy_method.py
│   │   │   ├── off_policy_method_test.py
│   │   │   ├── on_policy_method.py
│   │   │   ├── policy_wrapper.py
│   │   │   ├── ppo.py
│   │   │   ├── ppo_test.py
│   │   │   ├── sac.py
│   │   │   ├── sac_test.py
│   │   │   ├── td3.py
│   │   │   └── td3_test.py
│   │   └── trainer.py
│   ├── methods.puml
│   ├── sequoia.puml
│   ├── settings/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── assumptions/
│   │   │   ├── __init__.py
│   │   │   ├── assumptions.puml
│   │   │   ├── base.py
│   │   │   ├── classification.py
│   │   │   ├── context_discreteness.py
│   │   │   ├── context_visibility.py
│   │   │   ├── continual.py
│   │   │   ├── discrete_results.py
│   │   │   ├── iid.py
│   │   │   ├── iid_results.py
│   │   │   ├── incremental.py
│   │   │   ├── incremental_results.py
│   │   │   ├── incremental_test.py
│   │   │   ├── task_incremental.py
│   │   │   └── task_type.py
│   │   ├── base/
│   │   │   ├── __init__.py
│   │   │   ├── base.puml
│   │   │   ├── bases.py
│   │   │   ├── environment.py
│   │   │   ├── objects.py
│   │   │   ├── results.py
│   │   │   ├── setting.py
│   │   │   ├── setting_meta.py
│   │   │   └── setting_test.py
│   │   ├── offline_rl/
│   │   │   └── setting.py
│   │   ├── presets/
│   │   │   ├── __init__.py
│   │   │   ├── cartpole_pixels.yaml
│   │   │   ├── cartpole_state.yaml
│   │   │   ├── cifar10.yaml
│   │   │   ├── cifar100.yaml
│   │   │   ├── classic_control/
│   │   │   │   ├── cartpole.yaml
│   │   │   │   └── mountaincar_continuous.yaml
│   │   │   ├── fashion_mnist.yaml
│   │   │   ├── mnist.yaml
│   │   │   ├── monsterkong/
│   │   │   │   ├── monsterkong_3each.yaml
│   │   │   │   ├── monsterkong_4each.yaml
│   │   │   │   ├── monsterkong_5each.yaml
│   │   │   │   ├── monsterkong_all.yaml
│   │   │   │   ├── monsterkong_jumps.yaml
│   │   │   │   ├── monsterkong_jumps_and_ladders.yaml
│   │   │   │   ├── monsterkong_ladders.yaml
│   │   │   │   └── monsterkong_mix.yaml
│   │   │   ├── mujoco/
│   │   │   │   └── half_cheetah.yaml
│   │   │   ├── rl_track.yaml
│   │   │   └── sl_track.yaml
│   │   ├── rl/
│   │   │   ├── __init__.py
│   │   │   ├── continual/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── environment.py
│   │   │   │   ├── environment_test.py
│   │   │   │   ├── make_env.py
│   │   │   │   ├── make_env_test.py
│   │   │   │   ├── objects.py
│   │   │   │   ├── results.py
│   │   │   │   ├── setting.py
│   │   │   │   ├── setting_test.py
│   │   │   │   ├── tasks.py
│   │   │   │   ├── tasks_test.py
│   │   │   │   └── test_environment.py
│   │   │   ├── discrete/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── multienv_wrappers.py
│   │   │   │   ├── multienv_wrappers_test.py
│   │   │   │   ├── results.py
│   │   │   │   ├── setting.py
│   │   │   │   ├── setting_test.py
│   │   │   │   ├── tasks.py
│   │   │   │   ├── tasks_test.py
│   │   │   │   └── test_environment.py
│   │   │   ├── environment.py
│   │   │   ├── environment_test.py
│   │   │   ├── envs/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── classic_control.py
│   │   │   │   ├── monsterkong.py
│   │   │   │   ├── mujoco/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── half_cheetah.py
│   │   │   │   │   ├── half_cheetah_test.py
│   │   │   │   │   ├── hopper.py
│   │   │   │   │   ├── hopper_test.py
│   │   │   │   │   ├── modified_friction.py
│   │   │   │   │   ├── modified_friction_test.py
│   │   │   │   │   ├── modified_gravity.py
│   │   │   │   │   ├── modified_gravity_test.py
│   │   │   │   │   ├── modified_mass.py
│   │   │   │   │   ├── modified_mass_test.py
│   │   │   │   │   ├── modified_size.py
│   │   │   │   │   ├── modified_size_test.py
│   │   │   │   │   ├── modified_wall.py
│   │   │   │   │   ├── mujoco_model_utils.py
│   │   │   │   │   ├── walker2d.py
│   │   │   │   │   └── walker2d_test.py
│   │   │   │   └── variant_spec.py
│   │   │   ├── incremental/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── objects.py
│   │   │   │   ├── results.py
│   │   │   │   ├── setting.py
│   │   │   │   ├── setting_test.py
│   │   │   │   └── tasks.py
│   │   │   ├── multi_task/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── setting.py
│   │   │   │   └── setting_test.py
│   │   │   ├── objects.py
│   │   │   ├── setting.py
│   │   │   ├── setting_test.py
│   │   │   ├── task_incremental/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── setting.py
│   │   │   │   ├── setting_test.py
│   │   │   │   └── tasks.py
│   │   │   ├── traditional/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── setting.py
│   │   │   │   └── setting_test.py
│   │   │   └── wrappers/
│   │   │       ├── __init__.py
│   │   │       ├── measure_performance.py
│   │   │       ├── measure_performance_test.py
│   │   │       ├── no_typed_objects.py
│   │   │       ├── task_labels.py
│   │   │       └── typed_objects.py
│   │   ├── settings.puml
│   │   └── sl/
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── continual/
│   │       │   ├── __init__.py
│   │       │   ├── environment.py
│   │       │   ├── environment_test.py
│   │       │   ├── envs.py
│   │       │   ├── objects.py
│   │       │   ├── results.py
│   │       │   ├── setting.py
│   │       │   ├── setting_test.py
│   │       │   └── wrappers.py
│   │       ├── discrete/
│   │       │   ├── __init__.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       ├── domain_incremental/
│   │       │   ├── __init__.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       ├── environment.py
│   │       ├── environment_test.py
│   │       ├── incremental/
│   │       │   ├── __init__.py
│   │       │   ├── environment.py
│   │       │   ├── environment_test.py
│   │       │   ├── objects.py
│   │       │   ├── results.py
│   │       │   ├── setting.py
│   │       │   ├── setting_test.py
│   │       │   └── unused_batch_transforms.py
│   │       ├── multi_task/
│   │       │   ├── __init__.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       ├── setting.py
│   │       ├── task_incremental/
│   │       │   ├── __init__.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       ├── traditional/
│   │       │   ├── __init__.py
│   │       │   ├── results.py
│   │       │   ├── setting.py
│   │       │   └── setting_test.py
│   │       └── wrappers/
│   │           ├── __init__.py
│   │           ├── measure_performance.py
│   │           └── measure_performance_test.py
│   ├── settings.puml
│   └── utils/
│       ├── __init__.py
│       ├── categorical.py
│       ├── data_utils.py
│       ├── encode.py
│       ├── generic_functions/
│       │   ├── __init__.py
│       │   ├── _namedtuple.py
│       │   ├── _namedtuple_test.py
│       │   ├── concatenate.py
│       │   ├── detach.py
│       │   ├── move.py
│       │   ├── replace.py
│       │   ├── replace_test.py
│       │   ├── singledispatchmethod.py
│       │   ├── slicing.py
│       │   ├── slicing_test.py
│       │   ├── stack.py
│       │   └── to_from_tensor.py
│       ├── logging_utils.py
│       ├── module_dict.py
│       ├── parseable.py
│       ├── plotting.py
│       ├── pretrained_utils.py
│       ├── readme.py
│       ├── serialization.py
│       └── utils.py
├── setup.cfg
├── setup.py
└── versioneer.py
Download .txt
Showing preview only (269K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (3006 symbols across 331 files)

FILE: examples/advanced/RL_and_SL_demo.py
  class SimpleRegularizationAuxTask (line 34) | class SimpleRegularizationAuxTask(AuxiliaryTask):
    class Options (line 49) | class Options(AuxiliaryTask.Options):
    method __init__ (line 61) | def __init__(
    method get_loss (line 75) | def get_loss(self, forward_pass: ForwardPass, y: Tensor = None) -> Loss:
    method on_task_switch (line 100) | def on_task_switch(self, task_id: int) -> None:
  class CustomizedBaselineModel (line 120) | class CustomizedBaselineModel(BaseModel):
    class HParams (line 122) | class HParams(BaseModel.HParams):
    method __init__ (line 130) | def __init__(
  class CustomMethod (line 149) | class CustomMethod(BaseMethod, target_setting=Setting):
    method __init__ (line 167) | def __init__(
    method create_model (line 181) | def create_model(self, setting: Setting) -> CustomizedBaselineModel:
    method configure (line 185) | def configure(self, setting: Setting):
    method fit (line 196) | def fit(self, train_env: Environment, valid_env: Environment):
    method add_argparse_args (line 207) | def add_argparse_args(cls, parser: ArgumentParser):
    method from_argparse_args (line 222) | def from_argparse_args(cls, args: Namespace):
  function demo_manual (line 232) | def demo_manual():
  function demo_command_line (line 273) | def demo_command_line():

FILE: examples/advanced/ewc_in_rl.py
  class NormRegularizer (line 33) | class NormRegularizer(PolicyWrapper[Policy]):
    method __init__ (line 41) | def __init__(self: Policy, *args, reg_coefficient: float = 1.0, ewc_p_...
    method on_task_switch (line 51) | def on_task_switch(self: Policy, task_id: Optional[int], *args, **kwar...
    method get_loss (line 69) | def get_loss(self: Policy) -> Union[float, Tensor]:
    method after_zero_grad (line 77) | def after_zero_grad(self: Policy):
    method before_optimizer_step (line 88) | def before_optimizer_step(self: Policy):
    method ewc_loss (line 93) | def ewc_loss(self: Policy) -> Union[float, Tensor]:
  class EWCPolicy (line 114) | class EWCPolicy(NormRegularizer):
    method __init__ (line 119) | def __init__(
    method consolidate (line 132) | def consolidate(self, new_fims: List[PMatAbstract], task: int) -> None:
    method _consolidate_fims (line 149) | def _consolidate_fims(
    method on_task_switch (line 162) | def on_task_switch(
    method ewc_loss (line 215) | def ewc_loss(self: Policy) -> Union[float, Tensor]:
  class ExampleRegularizationMethod (line 239) | class ExampleRegularizationMethod(StableBaselines3Method):
    method create_model (line 261) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> Base...
    method on_task_switch (line 272) | def on_task_switch(self, task_id: Optional[int]) -> None:
  class EWCExampleMethod (line 287) | class EWCExampleMethod(StableBaselines3Method):
    method create_model (line 300) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> Base...
    method on_task_switch (line 311) | def on_task_switch(self, task_id: Optional[int]) -> None:

FILE: examples/advanced/hat_demo.py
  class Masks (line 24) | class Masks(NamedTuple):
  class HatNet (line 34) | class HatNet(torch.nn.Module):
    method __init__ (line 48) | def __init__(self, image_space: Image, n_classes_per_task: Dict[int, i...
    method forward (line 96) | def forward(self, observations: TaskIncrementalSLSetting.Observations)...
    method mask (line 133) | def mask(self, t: Tensor, s_hat: float) -> Masks:
    method shared_step (line 141) | def shared_step(
  function compute_conv_output_size (line 192) | def compute_conv_output_size(
  class HatDemoMethod (line 199) | class HatDemoMethod(Method, target_setting=TaskIncrementalSLSetting):
    class HParams (line 209) | class HParams:
    method __init__ (line 221) | def __init__(self, hparams: HParams = None):
    method configure (line 228) | def configure(self, setting: TaskIncrementalSLSetting):
    method fit (line 252) | def fit(self, train_env: PassiveEnvironment, valid_env: PassiveEnviron...
    method get_actions (line 307) | def get_actions(self, observations: Observations, action_space: gym.Sp...
    method on_task_switch (line 315) | def on_task_switch(self, task_id: Optional[int]):
    method add_argparse_args (line 323) | def add_argparse_args(cls, parser: ArgumentParser) -> None:
    method from_argparse_args (line 329) | def from_argparse_args(cls, args: Namespace) -> "HatDemoMethod":

FILE: examples/advanced/pnn/layers.py
  class PNNConvLayer (line 10) | class PNNConvLayer(nn.Module):
    method __init__ (line 11) | def __init__(self, col, depth, n_in, n_out, kernel_size=3):
    method forward (line 22) | def forward(self, inputs):
  class PNNLinearBlock (line 32) | class PNNLinearBlock(nn.Module):
    method __init__ (line 33) | def __init__(self, col: int, depth: int, n_in: int, n_out: int):
    method forward (line 41) | def forward(self, inputs):

FILE: examples/advanced/pnn/model_rl.py
  class PnnA2CAgent (line 10) | class PnnA2CAgent(nn.Module):
    method __init__ (line 20) | def __init__(self, arch="mlp", hidden_size=256):
    method forward (line 38) | def forward(self, observations):
    method new_task (line 91) | def new_task(self, device, num_inputs, num_actions=5):
    method unfreeze_columns (line 131) | def unfreeze_columns(self):
    method freeze_columns (line 143) | def freeze_columns(self, skip=None):
    method parameters (line 164) | def parameters(self, task_id):
    method transfor_img (line 177) | def transfor_img(self, img):

FILE: examples/advanced/pnn/model_sl.py
  class PnnClassifier (line 15) | class PnnClassifier(nn.Module):
    method __init__ (line 25) | def __init__(self, n_layers):
    method forward (line 35) | def forward(self, observations):
    method new_task (line 68) | def new_task(self, device, sizes: List[int]):
    method freeze_columns (line 87) | def freeze_columns(self, skip=None):
    method shared_step (line 102) | def shared_step(
    method parameters (line 153) | def parameters(self, task_id):

FILE: examples/advanced/pnn/pnn_method.py
  class PnnMethod (line 36) | class PnnMethod(Method, target_setting=Setting):
    class HParams (line 48) | class HParams:
    method __init__ (line 64) | def __init__(self, hparams: HParams = None):
    method configure (line 72) | def configure(self, setting: Setting):
    method on_task_switch (line 154) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method set_optimizer (line 178) | def set_optimizer(self):
    method get_actions (line 184) | def get_actions(self, observations: Observations, action_space: spaces...
    method fit (line 203) | def fit(self, train_env: Environment, valid_env: Environment):
    method fit_rl (line 215) | def fit_rl(self, train_env: gym.Env, valid_env: gym.Env):
    method fit_sl (line 291) | def fit_sl(self, train_env: PassiveEnvironment, valid_env: PassiveEnvi...
    method add_argparse_args (line 337) | def add_argparse_args(cls, parser: ArgumentParser) -> None:
    method from_argparse_args (line 341) | def from_argparse_args(cls, args: Namespace) -> "PnnMethod":
  function main_rl (line 347) | def main_rl():
  function main_sl (line 380) | def main_sl():

FILE: examples/advanced/procgen_example.py
  class ProcGenConfig (line 22) | class ProcGenConfig:
    method make_env (line 69) | def make_env(self) -> gym.Env:
  class SequoiaProcGenAdapterWrapper (line 103) | class SequoiaProcGenAdapterWrapper(gym.Wrapper):
    method __init__ (line 112) | def __init__(self, env):
    method step (line 115) | def step(self, action):
    method seed (line 121) | def seed(self, seed: Optional[int] = None) -> List[int]:
    method render (line 126) | def render(self, mode: str = "rgb_array"):
  function make_procgen_setting (line 156) | def make_procgen_setting(
  function main_simple (line 253) | def main_simple():
  function main_using_other_setting (line 262) | def main_using_other_setting():

FILE: examples/basic/base_method_demo.py
  function baseline_demo_simple (line 12) | def baseline_demo_simple():
  function baseline_demo_command_line (line 33) | def baseline_demo_command_line():

FILE: examples/basic/pl_example.py
  class Model (line 31) | class Model(LightningModule):
    class HParams (line 38) | class HParams:
    method __init__ (line 50) | def __init__(
    method forward (line 107) | def forward(self, observations: ContinualSLSetting.Observations) -> Te...
    method training_step (line 129) | def training_step(
    method validation_step (line 134) | def validation_step(
    method test_step (line 139) | def test_step(self, batch: Tuple[Observations, Optional[Rewards]], bat...
    method shared_step (line 142) | def shared_step(
    method configure_optimizers (line 174) | def configure_optimizers(self):
  class ExampleMethod (line 178) | class ExampleMethod(Method, target_setting=ContinualSLSetting):
    method __init__ (line 194) | def __init__(self, hparams: Model.HParams = None):
    method configure (line 202) | def configure(self, setting: ContinualSLSetting):
    method fit (line 226) | def fit(
    method test (line 263) | def test(self, test_env: ContinualSLSetting.Environment):
    method get_actions (line 282) | def get_actions(self, observations: Observations, action_space: spaces...
    method on_task_switch (line 294) | def on_task_switch(self, task_id: Optional[int]) -> None:
  function main (line 310) | def main():

FILE: examples/basic/pl_example_packnet.py
  class ExamplePackNetMethod (line 15) | class ExamplePackNetMethod(ExampleMethod, target_setting=TaskIncremental...
    method __init__ (line 16) | def __init__(self, hparams: Model.HParams = None, packnet_hparams: Pac...
    method configure (line 26) | def configure(self, setting: TaskIncrementalSLSetting):
    method fit (line 37) | def fit(
    method on_task_switch (line 52) | def on_task_switch(self, task_id: Optional[int]):
  function main (line 67) | def main():

FILE: examples/basic/pl_example_test.py
  class TestPLExample (line 21) | class TestPLExample(MethodTests):
    method method (line 30) | def method(self, config: Config):
    method validate_results (line 34) | def validate_results(

FILE: examples/basic/quick_demo.py
  class MyModel (line 28) | class MyModel(nn.Module):
    method __init__ (line 39) | def __init__(
    method forward (line 72) | def forward(self, observations: Observations) -> Tensor:
    method shared_step (line 80) | def shared_step(
  class DemoMethod (line 129) | class DemoMethod(Method, target_setting=DomainIncrementalSLSetting):
    class HParams (line 136) | class HParams:
    method __init__ (line 142) | def __init__(self, hparams: HParams = None):
    method configure (line 151) | def configure(self, setting: DomainIncrementalSLSetting):
    method fit (line 167) | def fit(self, train_env: PassiveEnvironment, valid_env: PassiveEnviron...
    method get_actions (line 215) | def get_actions(self, observations: Observations, action_space: gym.Sp...
    method add_argparse_args (line 224) | def add_argparse_args(cls, parser: ArgumentParser):
    method from_argparse_args (line 229) | def from_argparse_args(cls, args: Namespace):
  function demo_simple (line 235) | def demo_simple():
  function demo_command_line (line 254) | def demo_command_line():

FILE: examples/basic/quick_demo_ewc.py
  class MyImprovedModel (line 23) | class MyImprovedModel(MyModel):
    method __init__ (line 26) | def __init__(
    method shared_step (line 47) | def shared_step(self, batch: Tuple[Observations, Rewards], *args, **kw...
    method on_task_switch (line 53) | def on_task_switch(self, task_id: int) -> None:
    method ewc_loss (line 70) | def ewc_loss(self) -> Tensor:
  class ImprovedDemoMethod (line 90) | class ImprovedDemoMethod(DemoMethod):
    class HParams (line 97) | class HParams(DemoMethod.HParams):
    method __init__ (line 105) | def __init__(self, hparams: HParams = None):
    method configure (line 108) | def configure(self, setting: DomainIncrementalSLSetting):
    method on_task_switch (line 122) | def on_task_switch(self, task_id: Optional[int]):
  function demo_ewc (line 126) | def demo_ewc():

FILE: examples/basic/quick_demo_test.py
  function test_quick_demo (line 13) | def test_quick_demo(monkeypatch):

FILE: examples/clcomp21/a2c_example.py
  class ActorCritic (line 28) | class ActorCritic(nn.Module):
    method __init__ (line 29) | def __init__(
    method forward (line 95) | def forward(self, observation: RLSetting.Observations) -> Tuple[Tensor...
  class ExampleA2CMethod (line 125) | class ExampleA2CMethod(Method, target_setting=RLSetting):
    class HParams (line 133) | class HParams(HyperParameters):
    method __init__ (line 150) | def __init__(self, hparams: HParams = None, render: bool = False):
    method configure (line 157) | def configure(self, setting: RLSetting):
    method fit (line 169) | def fit(self, train_env: ActiveEnvironment, valid_env: ActiveEnvironme...
    method get_actions (line 290) | def get_actions(
    method on_task_switch (line 300) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method add_argparse_args (line 314) | def add_argparse_args(cls, parser: ArgumentParser):
    method from_argparse_args (line 318) | def from_argparse_args(cls, args: Namespace):
    method get_search_space (line 322) | def get_search_space(self, setting: RLSetting) -> Dict:
    method adapt_to_new_hparams (line 325) | def adapt_to_new_hparams(self, new_hparams: Dict) -> None:

FILE: examples/clcomp21/a2c_example_test.py
  function test_cartpole_state (line 14) | def test_cartpole_state(cartpole_state_setting: SettingProxy[RLSetting]):
  function test_incremental_cartpole_state (line 30) | def test_incremental_cartpole_state(
  function test_RL_track (line 46) | def test_RL_track(rl_track_setting: SettingProxy[IncrementalRLSetting]):

FILE: examples/clcomp21/classifier.py
  class HParams (line 29) | class HParams(HyperParameters):
  class Classifier (line 43) | class Classifier(nn.Module):
    method __init__ (line 51) | def __init__(
    method create_output_head (line 74) | def create_output_head(self) -> nn.Module:
    method configure_optimizers (line 77) | def configure_optimizers(self) -> Optimizer:
    method create_encoder (line 84) | def create_encoder(self, image_space: Image) -> Tuple[nn.Module, int]:
    method forward (line 133) | def forward(self, observations: Observations) -> Tensor:
    method shared_step (line 142) | def shared_step(
  class ExampleMethod (line 191) | class ExampleMethod(Method, target_setting=ClassIncrementalSetting):
    method __init__ (line 199) | def __init__(self, hparams: HParams = None):
    method configure (line 206) | def configure(self, setting: ClassIncrementalSetting):
    method fit (line 219) | def fit(self, train_env: PassiveEnvironment, valid_env: PassiveEnviron...
    method get_actions (line 270) | def get_actions(self, observations: Observations, action_space: gym.Sp...
    method add_argparse_args (line 279) | def add_argparse_args(cls, parser: ArgumentParser):
    method from_argparse_args (line 284) | def from_argparse_args(cls, args: Namespace):

FILE: examples/clcomp21/classifier_test.py
  function test_mnist (line 11) | def test_mnist(mnist_setting: SettingProxy[ClassIncrementalSetting]):
  function test_SL_track (line 24) | def test_SL_track(sl_track_setting: SettingProxy[ClassIncrementalSetting]):

FILE: examples/clcomp21/conftest.py
  function mnist_setting (line 9) | def mnist_setting():
  function task_incremental_mnist_setting (line 18) | def task_incremental_mnist_setting():
  function fashion_mnist_setting (line 27) | def fashion_mnist_setting():
  function sl_track_setting (line 36) | def sl_track_setting():
  function cartpole_state_setting (line 49) | def cartpole_state_setting():
  function incremental_cartpole_state_setting (line 61) | def incremental_cartpole_state_setting():
  function rl_track_setting (line 73) | def rl_track_setting(tmp_path):

FILE: examples/clcomp21/dummy_method.py
  class DummyMethod (line 13) | class DummyMethod(Method, target_setting=Setting):
    method __init__ (line 16) | def __init__(self):
    method configure (line 19) | def configure(self, setting: Setting):
    method fit (line 30) | def fit(self, train_env: Environment, valid_env: Environment):
    method get_actions (line 74) | def get_actions(self, observations: Observations, action_space: gym.Sp...

FILE: examples/clcomp21/dummy_method_test.py
  function test_mnist (line 12) | def test_mnist(mnist_setting: SettingProxy[ClassIncrementalSetting]):
  function test_SL_track (line 25) | def test_SL_track(sl_track_setting: SettingProxy[ClassIncrementalSetting]):
  function test_RL_track (line 41) | def test_RL_track(rl_track_setting: SettingProxy[IncrementalRLSetting]):

FILE: examples/clcomp21/multihead_classifier.py
  class MultiHeadClassifier (line 23) | class MultiHeadClassifier(Classifier):
    class HParams (line 25) | class HParams(Classifier.HParams):
    method __init__ (line 28) | def __init__(
    method configure_optimizers (line 46) | def configure_optimizers(self) -> Optimizer:
    method create_output_head (line 50) | def create_output_head(self) -> nn.Module:
    method get_or_create_output_head (line 53) | def get_or_create_output_head(self, task_id: int) -> nn.Module:
    method forward (line 70) | def forward(self, observations: Observations) -> Tensor:
    method split_forward_pass (line 128) | def split_forward_pass(self, observations: Observations) -> Tensor:
    method task_inference_forward_pass (line 184) | def task_inference_forward_pass(self, observations: Observations) -> T...
    method on_task_switch (line 267) | def on_task_switch(self, task_id: Optional[int]):
  class ExampleTaskInferenceMethod (line 275) | class ExampleTaskInferenceMethod(ExampleMethod):
    method __init__ (line 279) | def __init__(self, hparams: MultiHeadClassifier.HParams = None):
    method configure (line 283) | def configure(self, setting: ClassIncrementalSetting):
    method on_task_switch (line 300) | def on_task_switch(self, task_id: Optional[int]):
    method get_actions (line 303) | def get_actions(self, observations, action_space):

FILE: examples/clcomp21/multihead_classifier_test.py
  function test_task_incremental_mnist (line 11) | def test_task_incremental_mnist(
  function test_mnist (line 27) | def test_mnist(mnist_setting: SettingProxy[ClassIncrementalSetting]):
  function test_SL_track (line 41) | def test_SL_track(sl_track_setting: SettingProxy[ClassIncrementalSetting]):

FILE: examples/clcomp21/regularization_example.py
  class RegularizedClassifier (line 23) | class RegularizedClassifier(MultiHeadClassifier):
    class HParams (line 29) | class HParams(MultiHeadClassifier.HParams):
    method __init__ (line 45) | def __init__(
    method shared_step (line 66) | def shared_step(self, batch: Tuple[Observations, Rewards], *args, **kw...
    method on_task_switch (line 72) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method ewc_loss (line 90) | def ewc_loss(self) -> Tensor:
  class ExampleRegMethod (line 110) | class ExampleRegMethod(ExampleTaskInferenceMethod):
    method __init__ (line 115) | def __init__(self, hparams: HParams = None):
    method configure (line 118) | def configure(self, setting: DomainIncrementalSLSetting):
    method on_task_switch (line 128) | def on_task_switch(self, task_id: Optional[int]):

FILE: examples/clcomp21/regularization_example_test.py
  function test_mnist (line 11) | def test_mnist(mnist_setting: SettingProxy[ClassIncrementalSetting]):
  function test_SL_track (line 25) | def test_SL_track(sl_track_setting: SettingProxy[ClassIncrementalSetting]):

FILE: examples/clcomp21/sb3_example.py
  class CustomPPOModel (line 16) | class CustomPPOModel(PPOModel):
    class HParams (line 18) | class HParams(PPOModel.HParams):
  class CustomPPOMethod (line 23) | class CustomPPOMethod(PPOMethod):
    method configure (line 28) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 31) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> PPOM...
    method fit (line 34) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 37) | def get_actions(
    method on_task_switch (line 45) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method get_search_space (line 55) | def get_search_space(self, setting: ContinualRLSetting) -> Mapping[str...

FILE: examples/clcomp21/sb3_example_test.py
  function test_cartpole_state (line 12) | def test_cartpole_state(cartpole_state_setting: SettingProxy[RLSetting]):
  function test_incremental_cartpole_state (line 25) | def test_incremental_cartpole_state(
  function test_RL_track (line 40) | def test_RL_track(rl_track_setting: SettingProxy[IncrementalRLSetting]):

FILE: examples/demo_utils.py
  function demo_all_settings (line 12) | def demo_all_settings(
  function make_result_dataframe (line 77) | def make_result_dataframe(all_results):
  function compare_results (line 109) | def compare_results(
  function make_comparison_dataframe (line 135) | def make_comparison_dataframe(

FILE: examples/prerequisites/dataclasses_example.py
  class Point (line 9) | class Point:
  class HParams (line 33) | class HParams:

FILE: sequoia/_version.py
  function get_keywords (line 19) | def get_keywords():
  class VersioneerConfig (line 32) | class VersioneerConfig:
  function get_config (line 36) | def get_config():
  class NotThisMethod (line 50) | class NotThisMethod(Exception):
  function register_vcs_handler (line 58) | def register_vcs_handler(vcs, method):  # decorator
  function run_command (line 71) | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=Fal...
  function versions_from_parentdir (line 108) | def versions_from_parentdir(parentdir_prefix, root, verbose):
  function git_get_keywords (line 140) | def git_get_keywords(versionfile_abs):
  function git_versions_from_keywords (line 169) | def git_versions_from_keywords(keywords, tag_prefix, verbose):
  function git_pieces_from_vcs (line 235) | def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_comma...
  function plus_or_dot (line 325) | def plus_or_dot(pieces):
  function render_pep440 (line 332) | def render_pep440(pieces):
  function render_pep440_pre (line 356) | def render_pep440_pre(pieces):
  function render_pep440_post (line 372) | def render_pep440_post(pieces):
  function render_pep440_old (line 399) | def render_pep440_old(pieces):
  function render_git_describe (line 421) | def render_git_describe(pieces):
  function render_git_describe_long (line 441) | def render_git_describe_long(pieces):
  function render (line 461) | def render(pieces, style):
  function get_versions (line 499) | def get_versions():

FILE: sequoia/client/env_proxy.py
  class EnvironmentProxy (line 26) | class EnvironmentProxy(Environment[ObservationType, ActionType, RewardTy...
    method __init__ (line 27) | def __init__(self, env_fn, setting_type: Type[Setting]):
    method get_attribute (line 51) | def get_attribute(self, name: str, default: Any = MISSING) -> Any:
    method reset (line 58) | def reset(self) -> ObservationType:
    method __len__ (line 62) | def __len__(self) -> int:
    method step (line 65) | def step(
    method __iter__ (line 91) | def __iter__(self):
    method __next__ (line 94) | def __next__(self) -> ObservationType:
    method send (line 97) | def send(self, actions: ActionType):
    method close (line 110) | def close(self):
    method is_closed (line 114) | def is_closed(self) -> bool:
    method render (line 117) | def render(self, *args, **kwargs):
    method get_results (line 120) | def get_results(self) -> Results:
    method get_online_performance (line 123) | def get_online_performance(self) -> List[Metrics]:
    method get_average_online_performance (line 126) | def get_average_online_performance(self) -> Metrics:
    method __getattr__ (line 129) | def __getattr__(self, name: str):

FILE: sequoia/client/env_proxy_test.py
  function wrap_type_with_proxy (line 30) | def wrap_type_with_proxy(env_type: Type[EnvType]) -> EnvType:
  class TestEnvironmentProxy (line 44) | class TestEnvironmentProxy(_TestEnvDataset, _TestPassiveEnvironment, _Te...
  function test_sanity_check (line 56) | def test_sanity_check():
  function test_is_proxy_to (line 63) | def test_is_proxy_to(use_wrapper: bool):
  function test_issue_204 (line 101) | def test_issue_204():
  function test_interaction_with_test_environment (line 191) | def test_interaction_with_test_environment():

FILE: sequoia/client/server.py
  function server (line 1) | def server(grpc_host: str, grpc_port: int):

FILE: sequoia/client/setting_proxy.py
  class SettingProxy (line 35) | class SettingProxy(SettingABC, Generic[SettingType]):
    method __init__ (line 58) | def __init__(
    method observation_space (line 82) | def observation_space(self) -> gym.Space:
    method action_space (line 87) | def action_space(self) -> gym.Space:
    method reward_space (line 91) | def reward_space(self) -> gym.Space:
    method train_env (line 95) | def train_env(self) -> EnvironmentProxy:
    method val_env (line 99) | def val_env(self) -> EnvironmentProxy:
    method test_env (line 103) | def test_env(self) -> EnvironmentProxy:
    method test_env (line 109) | def test_env(self, value) -> None:
    method _temp_make_readable (line 114) | def _temp_make_readable(self, attribute: str) -> None:
    method config (line 119) | def config(self) -> Config:
    method config (line 123) | def config(self, value: Config) -> None:
    method prepare_data (line 126) | def prepare_data(self, *args, **kwargs):
    method setup (line 129) | def setup(self, stage: str = None):
    method get_name (line 132) | def get_name(self):
    method _is_readable (line 135) | def _is_readable(self, attribute: str) -> bool:
    method _is_writeable (line 147) | def _is_writeable(self, attribute: str) -> bool:
    method batch_size (line 160) | def batch_size(self) -> Optional[int]:
    method batch_size (line 164) | def batch_size(self, value: Optional[int]) -> None:
    method train_transforms (line 168) | def train_transforms(self) -> List[Callable]:
    method train_transforms (line 172) | def train_transforms(self, value: List[Callable]):
    method val_transforms (line 176) | def val_transforms(self) -> List[Callable]:
    method val_transforms (line 180) | def val_transforms(self, value: List[Callable]):
    method test_transforms (line 184) | def test_transforms(self) -> List[Callable]:
    method test_transforms (line 188) | def test_transforms(self, value: List[Callable]):
    method apply (line 191) | def apply(self, method: Method, config: Config = None) -> Results:
    method get_attribute (line 211) | def get_attribute(self, name: str) -> Any:
    method set_attribute (line 225) | def set_attribute(self, name: str, value: Any) -> None:
    method train_dataloader (line 228) | def train_dataloader(self, batch_size: int = None, num_workers: int = ...
    method val_dataloader (line 255) | def val_dataloader(self, batch_size: int = None, num_workers: int = No...
    method test_dataloader (line 280) | def test_dataloader(self, batch_size: int = None, num_workers: int = N...
    method __test_dataloader (line 296) | def __test_dataloader(
    method main_loop (line 315) | def main_loop(self, method: Method) -> Results:
    method test_loop (line 394) | def test_loop(self, method: Method) -> "IncrementalAssumption.Results":
    method __getattr__ (line 497) | def __getattr__(self, name: str):

FILE: sequoia/client/setting_proxy_test.py
  function test_spaces_match (line 34) | def test_spaces_match(setting_type: Type[Setting]):
  function test_transforms_get_propagated (line 42) | def test_transforms_get_propagated():
  class TestContinualSLSettingProxy (line 56) | class TestContinualSLSettingProxy(ContinualSLSettingTests):
  class TestContinualRLSettingProxy (line 60) | class TestContinualRLSettingProxy(ContinualRLSettingTests):
  function test_random_baseline (line 65) | def test_random_baseline(config):
  function test_random_baseline_rl (line 74) | def test_random_baseline_rl():
  function test_random_baseline_SL_track (line 102) | def test_random_baseline_SL_track():
  function test_baseline_SL_track (line 111) | def test_baseline_SL_track(config):
  function test_rl_track_setting_is_correct (line 136) | def test_rl_track_setting_is_correct():
  function test_sl_track_setting_is_correct (line 184) | def test_sl_track_setting_is_correct():

FILE: sequoia/common/batch.py
  function hasmethod (line 48) | def hasmethod(obj: Any, method_name: str) -> bool:
  class Batch (line 53) | class Batch(ABC, Mapping[str, T]):
    method __init_subclass__ (line 157) | def __init_subclass__(cls, *args, **kwargs):
    method __post_init__ (line 166) | def __post_init__(self):
    method __iter__ (line 177) | def __iter__(self) -> Iterator[str]:
    method __len__ (line 181) | def __len__(self) -> int:
    method __eq__ (line 185) | def __eq__(self, other: Union["Batch", Any]) -> bool:
    method __getitem__ (line 201) | def __getitem__(self, index: Any) -> T:
    method _getitem_none (line 208) | def _getitem_none(self, index: None) -> "Batch":
    method _getitem_by_name (line 216) | def _getitem_by_name(self, index: str) -> Union[Tensor, Any]:
    method _getitem_by_index (line 220) | def _getitem_by_index(self, index: int) -> Union[Tensor, Any]:
    method _getitem_with_slice (line 224) | def _getitem_with_slice(self, index: slice) -> "Batch":
    method _ (line 234) | def _(self: B, index) -> B:
    method _getitem_with_array (line 239) | def _getitem_with_array(self, index: np.ndarray) -> B:
    method _getitem_with_tuple (line 248) | def _getitem_with_tuple(self, index: Tuple[Union[slice, Tensor, np.nda...
    method slice (line 300) | def slice(self: B, index: Union[int, slice, np.ndarray, Tensor]) -> B:
    method __setitem__ (line 325) | def __setitem__(self, index: Union[int, str], value: Any):
    method keys (line 341) | def keys(self) -> KeysView[str]:
    method values (line 344) | def values(self) -> Tuple[T, ...]:
    method items (line 347) | def items(self) -> Iterable[Tuple[str, T]]:
    method devices (line 352) | def devices(self) -> Dict[str, Union[Optional[torch.device], Dict]]:
    method device (line 363) | def device(self) -> Optional[torch.device]:
    method dtypes (line 392) | def dtypes(self) -> Dict[str, Union[Optional[torch.dtype], Dict]]:
    method dtype (line 403) | def dtype(self) -> Tuple[Optional[torch.dtype]]:
    method as_namedtuple (line 423) | def as_namedtuple(self) -> Tuple[T, ...]:
    method as_list_of_tuples (line 426) | def as_list_of_tuples(self) -> Iterable[Tuple[T, ...]]:
    method as_tuple (line 441) | def as_tuple(self) -> Tuple[T, ...]:
    method to (line 458) | def to(self, *args, **kwargs):
    method float (line 466) | def float(self, dtype=torch.float):
    method float32 (line 469) | def float32(self, dtype=torch.float32):
    method int (line 472) | def int(self, dtype=torch.int):
    method double (line 475) | def double(self, dtype=torch.double):
    method numpy (line 478) | def numpy(self):
    method detach (line 499) | def detach(self):
    method cpu (line 515) | def cpu(self, **kwargs):
    method cuda (line 526) | def cuda(self, device=None, **kwargs):
    method shapes (line 538) | def shapes(self) -> Dict[str, Union[torch.Size, Dict]]:
    method batch_size (line 549) | def batch_size(self) -> Optional[int]:
    method with_batch_dimension (line 579) | def with_batch_dimension(self: B) -> B:
    method remove_batch_dimension (line 602) | def remove_batch_dimension(self: B) -> B:
    method split (line 611) | def split(self: B) -> List[B]:
    method stack (line 620) | def stack(cls: Type[B], items: List[B]) -> B:
    method concatenate (line 629) | def concatenate(cls: Type[B], items: List[B], **kwargs) -> B:
    method torch (line 636) | def torch(self, device: Union[str, torch.device] = None, dtype: torch....
    method _map (line 651) | def _map(self: B, func: Callable, *args, recursive: bool = True, **kwa...
    method _apply (line 670) | def _apply(
  function _replace_batch_items (line 689) | def _replace_batch_items(obj: Batch, **items) -> Batch:
  function _get_batch_slice (line 699) | def _get_batch_slice(value: Batch, indices: Sequence[int]) -> Batch:
  function set_batch_slice (line 709) | def set_batch_slice(target: Batch, indices: Sequence[int], values: Batch...

FILE: sequoia/common/batch_test.py
  class Observations (line 19) | class Observations(Batch):
  class Actions (line 25) | class Actions(Batch):
  class RLActions (line 30) | class RLActions(Actions):
  class Rewards (line 35) | class Rewards(Batch):
  function test_batch_behaves_like_a_dict (line 51) | def test_batch_behaves_like_a_dict(batch_type, items_dict):
  function test_to (line 81) | def test_to(batch_type: Type[Batch], items_dict: Dict[str, Tensor]):
  function test_tuple_indexing (line 154) | def test_tuple_indexing(
  function test_masking (line 189) | def test_masking():
  function test_newaxis (line 215) | def test_newaxis():
  function test_single_index (line 231) | def test_single_index():
  function test_remove_batch_dim (line 240) | def test_remove_batch_dim():
  function test_remove_batch_dim_with_nested_objects (line 271) | def test_remove_batch_dim_with_nested_objects():
  function test_split (line 301) | def test_split():
  function test_stack (line 354) | def test_stack(items: List[Batch], expected: Batch):
  function test_stack_with_none_values (line 400) | def test_stack_with_none_values(items: List[Batch], expected: Batch):
  function test_concatenate (line 464) | def test_concatenate(items: List[Batch], expected: Batch):
  function test_convert_between_ndarrays_and_tensors (line 488) | def test_convert_between_ndarrays_and_tensors(numpy_batch: Batch, torch_...
  class ForwardPass (line 507) | class ForwardPass(Batch):
  function test_nesting (line 513) | def test_nesting():
  function test_slicing_with_one_item (line 536) | def test_slicing_with_one_item():

FILE: sequoia/common/callbacks/knn_callback.py
  class KnnClassifierOptions (line 35) | class KnnClassifierOptions:
  class KnnCallback (line 47) | class KnnCallback(Callback):
    method __post_init__ (line 66) | def __post_init__(self):
    method on_train_start (line 72) | def on_train_start(self, trainer, pl_module):
    method setup (line 78) | def setup(self, trainer, pl_module, stage: str):
    method on_epoch_end (line 82) | def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
    method log (line 110) | def log(self, loss_object: Loss):
    method get_dataloaders (line 114) | def get_dataloaders(self, model: LightningModule, mode: str) -> List[D...
    method evaluate_knn (line 126) | def evaluate_knn(self, model: LightningModule) -> Tuple[Loss, Loss]:
  function evaluate (line 239) | def evaluate(
  function get_hidden_codes_array (line 290) | def get_hidden_codes_array(
  function fit_knn (line 316) | def fit_knn(
  function get_knn_performance (line 340) | def get_knn_performance(

FILE: sequoia/common/callbacks/vae_callback.py
  class SaveVaeSamplesCallback (line 17) | class SaveVaeSamplesCallback(Callback):
    method __post_init__ (line 24) | def __post_init__(self, *args, **kwargs):
    method setup (line 31) | def setup(self, trainer, pl_module, stage: str):
    method on_train_start (line 35) | def on_train_start(self, trainer, pl_module):
    method on_train_epoch_end (line 52) | def on_train_epoch_end(self, trainer: Trainer, pl_module: BaseModel):
    method reconstruct_samples (line 65) | def reconstruct_samples(self, data: Tensor):
    method generate_samples (line 85) | def generate_samples(self):

FILE: sequoia/common/config/config.py
  class Config (line 27) | class Config(Serializable, Parseable):
    method __post_init__ (line 54) | def __post_init__(self):
    method __del__ (line 61) | def __del__(self):
    method get_display (line 65) | def get_display(self) -> Optional[Display]:
    method seed_everything (line 86) | def seed_everything(self) -> None:

FILE: sequoia/common/config/wandb_config.py
  function patched_monitor (line 17) | def patched_monitor():
  class WandbConfig (line 58) | class WandbConfig(Serializable):
    method log_dir (line 126) | def log_dir(self):
    method wandb_login (line 134) | def wandb_login(self) -> bool:
    method wandb_init_kwargs (line 156) | def wandb_init_kwargs(self) -> Dict:
    method wandb_init (line 179) | def wandb_init(self, config_dict: Dict = None) -> wandb.wandb_run.Run:
    method make_logger (line 217) | def make_logger(self, wandb_parent_dir: Path = None) -> WandbLogger:

FILE: sequoia/common/gym_wrappers/action_limit.py
  class ActionCounter (line 13) | class ActionCounter(IterableWrapper):
    method __init__ (line 18) | def __init__(self, env: gym.Env):
    method step_count (line 22) | def step_count(self) -> int:
    method action_count (line 25) | def action_count(self) -> int:
    method step (line 28) | def step(self, action):
  class ActionLimit (line 34) | class ActionLimit(ActionCounter):
    method __init__ (line 42) | def __init__(self, env: gym.Env, max_steps: int):
    method max_steps (line 50) | def max_steps(self) -> int:
    method __len__ (line 53) | def __len__(self):
    method closed_error_message (line 56) | def closed_error_message(self) -> str:
    method step (line 59) | def step(self, action):

FILE: sequoia/common/gym_wrappers/action_limit_test.py
  function test_basics (line 12) | def test_basics():
  function test_EnvDataset_of_ActionLimit (line 17) | def test_EnvDataset_of_ActionLimit():
  function test_ActionLimit_of_EnvDataset (line 51) | def test_ActionLimit_of_EnvDataset():
  function test_delayed_EnvDataset_of_ActionLimit (line 86) | def test_delayed_EnvDataset_of_ActionLimit():

FILE: sequoia/common/gym_wrappers/add_done.py
  function add_done (line 24) | def add_done(observation: Any, done: Any) -> Any:
  function _add_done_to_array_obs (line 40) | def _add_done_to_array_obs(observation: T, done: bool) -> Dict[str, Unio...
  function _add_done_to_tuple_obs (line 46) | def _add_done_to_tuple_obs(observation: Tuple, done: bool) -> Tuple:
  function _add_done_to_dict_obs (line 51) | def _add_done_to_dict_obs(observation: Dict[K, V], done: bool) -> Dict[K...
  function add_done_to_space (line 58) | def add_done_to_space(observation: Space, done: Space) -> Space:
  function _add_done_to_box_space (line 74) | def _add_done_to_box_space(observation: Space, done: Space) -> spaces.Dict:
  function _add_done_to_tuple_space (line 83) | def _add_done_to_tuple_space(observation: spaces.Tuple, done: Space) -> ...
  function _add_done_to_dict_space (line 93) | def _add_done_to_dict_space(observation: spaces.Dict, done: Space) -> sp...
  class AddDoneToObservation (line 100) | class AddDoneToObservation(IterableWrapper):
    method __init__ (line 113) | def __init__(self, env: gym.Env, done_space: Space = None):
    method reset (line 124) | def reset(self, **kwargs):
    method step (line 132) | def step(self, action):

FILE: sequoia/common/gym_wrappers/add_info.py
  function add_info (line 22) | def add_info(observation, info):
  function _add_info_to_array_obs (line 39) | def _add_info_to_array_obs(observation: np.ndarray, info: Info) -> Tuple...
  function _add_info_to_tuple_obs (line 44) | def _add_info_to_tuple_obs(observation: Tuple, info: Info) -> Tuple:
  function _add_info_to_dict_obs (line 49) | def _add_info_to_dict_obs(observation: Dict[K, V], info: Info) -> Dict[K...
  function add_info_to_space (line 56) | def add_info_to_space(observation: Space, info: Space) -> Space:
  function _add_info_to_box_space (line 66) | def _add_info_to_box_space(observation: spaces.Box, info: Space) -> spac...
  function _add_info_to_tuple_space (line 76) | def _add_info_to_tuple_space(observation: spaces.Tuple, info: Space) -> ...
  function _add_info_to_dict_space (line 86) | def _add_info_to_dict_space(observation: spaces.Dict, info: Space) -> sp...
  class AddInfoToObservation (line 93) | class AddInfoToObservation(IterableWrapper):
    method __init__ (line 101) | def __init__(self, env: gym.Env, info_space: spaces.Space = None):
    method reset (line 114) | def reset(self, **kwargs):
    method step (line 122) | def step(self, action):

FILE: sequoia/common/gym_wrappers/convert_tensors.py
  function to_tensor (line 23) | def to_tensor(v, device: torch.device = None) -> Union[Tensor, Any]:
  function _ (line 43) | def _(
  function _ (line 53) | def _(v: Dict, device: torch.device = None) -> Dict:
  class ConvertToFromTensors (line 65) | class ConvertToFromTensors(IterableWrapper):
    method __init__ (line 81) | def __init__(self, env: gym.Env, device: Union[torch.device, str] = No...
    method reset (line 101) | def reset(self, *args, **kwargs):
    method observation (line 105) | def observation(self, observation):
    method action (line 108) | def action(self, action):
    method reward (line 117) | def reward(self, reward):
    method step (line 120) | def step(self, action):
  function supports_tensors (line 136) | def supports_tensors(space: S) -> bool:
  function has_tensor_support (line 141) | def has_tensor_support(space: S) -> bool:
  function _mark_supports_tensors (line 145) | def _mark_supports_tensors(space: S) -> None:
  function add_tensor_support (line 151) | def add_tensor_support(space: S, device: torch.device = None) -> S:
  function _ (line 187) | def _(space: Image, device: torch.device = None) -> Image:
  function _ (line 195) | def _(space: spaces.Dict, device: torch.device = None) -> spaces.Dict:
  function _ (line 205) | def _(space: TypedDictSpace, device: torch.device = None) -> TypedDictSp...
  function _ (line 215) | def _(space: Dict, device: torch.device = None) -> Dict:
  function _ (line 225) | def _(space: Dict, device: torch.device = None) -> Dict:
  function _ (line 236) | def _(space: spaces.Box, device: torch.device = None) -> spaces.Box:
  function _ (line 243) | def _(space: spaces.Discrete, device: torch.device = None) -> spaces.Box:
  function _ (line 250) | def _(space: spaces.MultiDiscrete, device: torch.device = None) -> space...

FILE: sequoia/common/gym_wrappers/convert_tensors_test.py
  function test_convert_tensors_wrapper (line 26) | def test_convert_tensors_wrapper(device: Union[str, torch.device]):
  class Foo (line 55) | class Foo(Batch):
  function test_preserves_dtype_of_namedtuple_space (line 60) | def test_preserves_dtype_of_namedtuple_space():
  function test_preserves_dtype_of_typeddict_space (line 71) | def test_preserves_dtype_of_typeddict_space():

FILE: sequoia/common/gym_wrappers/env_dataset.py
  class EnvDataset (line 24) | class EnvDataset(
    method __init__ (line 38) | def __init__(
    method reset_counters (line 97) | def reset_counters(self):
    method observation (line 103) | def observation(self, observation):
    method action (line 106) | def action(self, action):
    method reward (line 109) | def reward(self, reward):
    method step (line 112) | def step(self, action) -> StepResult:
    method __next__ (line 142) | def __next__(
    method send (line 191) | def send(self, action: ActionType) -> RewardType:
    method __iter__ (line 203) | def __iter__(self) -> Iterator[ObservationType]:
    method reached_step_limit (line 294) | def reached_step_limit(self) -> bool:
    method reached_episode_limit (line 300) | def reached_episode_limit(self) -> bool:
    method reached_episode_length_limit (line 306) | def reached_episode_length_limit(self) -> bool:
    method done_is_true (line 312) | def done_is_true(self) -> bool:
    method reset (line 340) | def reset(self, **kwargs) -> ObservationType:
    method close (line 348) | def close(self) -> None:
    method __add__ (line 365) | def __add__(self, other):

FILE: sequoia/common/gym_wrappers/env_dataset_test.py
  class TestEnvDataset (line 17) | class TestEnvDataset:
    method dummy_env_fn (line 24) | def dummy_env_fn(self):
    method test_step_normally_works_fine (line 27) | def test_step_normally_works_fine(self, dummy_env_fn: Type[DummyEnviro...
    method test_iterating_with_send (line 57) | def test_iterating_with_send(self, dummy_env_fn: Type[DummyEnvironment]):
    method test_raise_error_when_missing_action (line 85) | def test_raise_error_when_missing_action(self, dummy_env_fn: Type[Dumm...
    method test_doesnt_raise_error_when_action_sent (line 95) | def test_doesnt_raise_error_when_action_sent(self, dummy_env_fn: Type[...
    method test_max_episodes (line 105) | def test_max_episodes(self):
    method test_max_steps (line 127) | def test_max_steps(self):
    method test_max_steps_per_episode (line 158) | def test_max_steps_per_episode(self):
    method test_not_setting_max_steps_per_episode_with_vector_env_raises_warning (line 178) | def test_not_setting_max_steps_per_episode_with_vector_env_raises_warn...
    method test_observation_wrapper_applies_to_yielded_objects (line 192) | def test_observation_wrapper_applies_to_yielded_objects(self):
    method test_iteration_with_more_than_one_wrapper (line 250) | def test_iteration_with_more_than_one_wrapper(self):

FILE: sequoia/common/gym_wrappers/episode_limit.py
  class EpisodeCounter (line 18) | class EpisodeCounter(IterableWrapper):
    method __init__ (line 26) | def __init__(self, env: gym.Env):
    method episode_count (line 34) | def episode_count(self) -> int:
    method reset (line 37) | def reset(self):
    method step (line 61) | def step(self, action):
  class EpisodeLimit (line 76) | class EpisodeLimit(EpisodeCounter):
    method __init__ (line 84) | def __init__(self, env: gym.Env, max_episodes: int):
    method max_episodes (line 89) | def max_episodes(self) -> int:
    method closed_error_message (line 92) | def closed_error_message(self) -> str:
    method reset (line 101) | def reset(self):
    method __iter__ (line 126) | def __iter__(self):
    method step (line 129) | def step(self, action):

FILE: sequoia/common/gym_wrappers/episode_limit_test.py
  function test_basics (line 15) | def test_basics():
  function test_episode_limit_with_single_env (line 43) | def test_episode_limit_with_single_env(env_name: str):
  function test_episode_limit_with_single_env_dataset (line 84) | def test_episode_limit_with_single_env_dataset(env_name: str):
  function test_episode_limit_with_vectorized_env (line 115) | def test_episode_limit_with_vectorized_env(batch_size):
  function test_episode_limit_with_vectorized_env_dataset (line 164) | def test_episode_limit_with_vectorized_env_dataset(batch_size):
  function test_reset_vectorenv_with_unfinished_episodes_raises_warning (line 207) | def test_reset_vectorenv_with_unfinished_episodes_raises_warning(batch_s...

FILE: sequoia/common/gym_wrappers/measure_performance.py
  class MeasurePerformanceWrapper (line 14) | class MeasurePerformanceWrapper(IterableWrapper[EnvType], Generic[EnvTyp...
    method __init__ (line 15) | def __init__(self, env: Environment):
    method get_online_performance (line 19) | def get_online_performance(self) -> Dict[int, List[MetricsType]]:
    method get_average_online_performance (line 29) | def get_average_online_performance(self) -> Optional[MetricsType]:

FILE: sequoia/common/gym_wrappers/multi_task_environment.py
  function make_env_attributes_task (line 30) | def make_env_attributes_task(
  function add_task_labels (line 72) | def add_task_labels(observation: Any, task_labels: Any) -> Any:
  function _add_task_labels_to_single_obs (line 80) | def _add_task_labels_to_single_obs(observation: X, task_labels: T) -> Tu...
  function _add_task_labels_to_batch (line 92) | def _add_task_labels_to_batch(observation: Batch, task_labels: T) -> Batch:
  function _add_task_labels_to_space (line 100) | def _add_task_labels_to_space(observation: spaces.Space, task_labels: T)...
  function _add_task_labels_to_namedtuple (line 112) | def _add_task_labels_to_namedtuple(
  function _add_task_labels_to_tuple (line 123) | def _add_task_labels_to_tuple(observation: Tuple, task_labels: T) -> Tuple:
  function _add_task_labels_to_dict_space (line 128) | def _add_task_labels_to_dict_space(observation: spaces.Dict, task_labels...
  function _add_task_labels_to_typed_dict_space (line 136) | def _add_task_labels_to_typed_dict_space(
  function _add_task_labels_to_dict (line 149) | def _add_task_labels_to_dict(observation: Dict[str, V], task_labels: T) ...
  class MultiTaskEnvironment (line 157) | class MultiTaskEnvironment(MayCloseEarly):
    method __init__ (line 194) | def __init__(
    method current_task_id (line 300) | def current_task_id(self) -> int:
    method current_task_id (line 315) | def current_task_id(self, value: int) -> None:
    method set_on_task_switch_callback (line 318) | def set_on_task_switch_callback(self, callback: Callable[[int], None])...
    method on_task_switch (line 321) | def on_task_switch(self, task_id: int):
    method step (line 329) | def step(self, *args, **kwargs):
    method reset (line 358) | def reset(self, new_random_task: bool = None, **kwargs):
    method steps (line 394) | def steps(self) -> int:
    method steps (line 398) | def steps(self, value: int) -> None:
    method current_task (line 412) | def current_task(self) -> Dict[str, Any]:
    method current_task (line 439) | def current_task(self, task: Union[Dict[str, float], Sequence[float], ...
    method random_task (line 477) | def random_task(self) -> Dict:
    method update_task (line 512) | def update_task(self, values: Dict = None, **kwargs):
    method seed (line 532) | def seed(self, seed: Optional[int] = None) -> List[int]:
    method task_dict (line 538) | def task_dict(self, task_array: np.ndarray) -> Dict[str, float]:
    method task_schedule (line 545) | def task_schedule(self) -> Dict:
    method task_schedule (line 549) | def task_schedule(self, value: Dict[str, Any]):

FILE: sequoia/common/gym_wrappers/multi_task_environment_test.py
  function test_task_schedule (line 20) | def test_task_schedule():
  function test_multi_task (line 52) | def test_multi_task(environment_name: str):
  function test_monitor_env (line 72) | def test_monitor_env(environment_name):
  function test_update_task (line 106) | def test_update_task():
  function test_add_task_dict_to_info (line 126) | def test_add_task_dict_to_info():
  function test_add_task_id_to_obs (line 166) | def test_add_task_id_to_obs():
  function test_starting_step_and_max_step (line 219) | def test_starting_step_and_max_step():
  function test_task_id_is_added_even_when_no_known_task_schedule (line 284) | def test_task_id_is_added_even_when_no_known_task_schedule():
  function test_task_schedule_monsterkong (line 315) | def test_task_schedule_monsterkong():
  function test_task_schedule_with_callables (line 363) | def test_task_schedule_with_callables():
  function test_random_task_on_each_episode (line 413) | def test_random_task_on_each_episode():
  function test_random_task_on_each_episode_and_only_one_task_in_schedule (line 454) | def test_random_task_on_each_episode_and_only_one_task_in_schedule():
  function env_fn_monsterkong (line 486) | def env_fn_monsterkong() -> gym.Env:
  function env_fn_cartpole (line 504) | def env_fn_cartpole() -> gym.Env:
  function test_task_sequence_is_reproducible (line 523) | def test_task_sequence_is_reproducible(env_id: str):
  function test_iteration (line 576) | def test_iteration():

FILE: sequoia/common/gym_wrappers/observation_limit.py
  class ObservationLimit (line 14) | class ObservationLimit(IterableWrapper):
    method __init__ (line 22) | def __init__(self, env: gym.Env, max_steps: int):
    method reset (line 29) | def reset(self):
    method is_closed (line 49) | def is_closed(self) -> bool:
    method step (line 52) | def step(self, action):
    method close (line 71) | def close(self):

FILE: sequoia/common/gym_wrappers/observation_limit_test.py
  function test_step_limit_with_single_env (line 14) | def test_step_limit_with_single_env(env_name: str):
  function test_step_limit_with_single_env_dataset (line 40) | def test_step_limit_with_single_env_dataset(env_name: str):
  function test_step_limit_with_vectorized_env (line 71) | def test_step_limit_with_vectorized_env(batch_size):
  function test_step_limit_with_vectorized_env_partial_final_batch (line 99) | def test_step_limit_with_vectorized_env_partial_final_batch(batch_size):

FILE: sequoia/common/gym_wrappers/pixel_observation.py
  class PixelObservationWrapper (line 13) | class PixelObservationWrapper(PixelObservationWrapper_):
    method __init__ (line 24) | def __init__(self, env: Union[str, gym.Env]):
    method step (line 41) | def step(self, *args, **kwargs):
    method reset (line 47) | def reset(self, *args, **kwargs):
    method render (line 52) | def render(self, mode: str = "human", **kwargs):
    method to_array (line 57) | def to_array(self, image) -> np.ndarray:
  class ImageObservations (line 70) | class ImageObservations(IterableWrapper):
    method __init__ (line 71) | def __init__(self, env: gym.Env):

FILE: sequoia/common/gym_wrappers/pixel_observation_test.py
  function test_passing_string_to_constructor (line 10) | def test_passing_string_to_constructor():
  function test_observation_space (line 15) | def test_observation_space():
  function test_reset_gives_pixels (line 20) | def test_reset_gives_pixels():
  function test_step_obs_is_pixels (line 27) | def test_step_obs_is_pixels():
  function test_state_attribute_is_pixels (line 35) | def test_state_attribute_is_pixels():
  function test_render_rgb_array (line 42) | def test_render_rgb_array():
  function test_render_with_human_mode (line 54) | def test_render_with_human_mode():
  function test_render_with_human_mode_with_env_dataset (line 66) | def test_render_with_human_mode_with_env_dataset():

FILE: sequoia/common/gym_wrappers/policy_env.py
  class Environment (line 27) | class Environment(gym.Env, Generic[ObservationType, ActionType, RewardTy...
    method step (line 28) | def step(self, action: ActionType) -> Tuple[ObservationType, RewardTyp...
    method reset (line 31) | def reset(self) -> ObservationType:
  class StateTransition (line 53) | class StateTransition(Batch, Generic[ObservationType, ActionType]):
    method state (line 62) | def state(self) -> ObservationType:
    method next_state (line 66) | def next_state(self) -> ObservationType:
  function default_dataset_item_creator (line 74) | def default_dataset_item_creator(
  class PolicyEnv (line 113) | class PolicyEnv(gym.Wrapper, IterableDataset, Iterable[DatasetItem]):
    method __init__ (line 128) | def __init__(
    method set_policy (line 146) | def set_policy(self, policy: Callable[[ObservationType, gym.Space], Ac...
    method step (line 150) | def step(self, action: Optional[Any] = None) -> StepResult:
    method close (line 168) | def close(self) -> None:
    method reset (line 174) | def reset(self, *args, **kwargs) -> None:
    method __iter__ (line 180) | def __iter__(self) -> Iterator[DatasetItem]:

FILE: sequoia/common/gym_wrappers/policy_env_test.py
  function test_iterating_with_policy (line 8) | def test_iterating_with_policy():

FILE: sequoia/common/gym_wrappers/smooth_environment.py
  class SmoothTransitions (line 26) | class SmoothTransitions(MultiTaskEnvironment):
    method __init__ (line 51) | def __init__(
    method step (line 136) | def step(self, *args, **kwargs):
    method reset (line 142) | def reset(self, **kwargs):
    method current_task_id (line 149) | def current_task_id(self) -> Optional[int]:
    method task_array (line 156) | def task_array(self, task: Dict[str, float]) -> np.ndarray:
    method smooth_update (line 159) | def smooth_update(self) -> None:

FILE: sequoia/common/gym_wrappers/smooth_environment_test.py
  function test_task_schedule (line 10) | def test_task_schedule():
  function test_update_only_on_reset (line 64) | def test_update_only_on_reset():
  function test_task_id_is_always_None (line 92) | def test_task_id_is_always_None():

FILE: sequoia/common/gym_wrappers/step_callback_wrapper.py
  class Callback (line 11) | class Callback(Callable[[int, gym.Env], None], ABC):
    method __call__ (line 13) | def __call__(self, step: int, env: gym.Env, step_results: Tuple) -> None:
  class StepCallback (line 17) | class StepCallback(Callback, ABC):
    method __init__ (line 18) | def __init__(self, step: int, func: Callable[[int, gym.Env, Tuple], No...
    method __call__ (line 22) | def __call__(self, step: int, env: gym.Env, step_results: Tuple) -> None:
  class PeriodicCallback (line 28) | class PeriodicCallback(Callback):
    method __init__ (line 29) | def __init__(self, period: int, offset: int = 0, func: Callable[[int, ...
    method __call__ (line 34) | def __call__(self, step: int, env: gym.Env, step_results: Tuple) -> None:
  class StepCallbackWrapper (line 40) | class StepCallbackWrapper(IterableWrapper):
    method __init__ (line 43) | def __init__(
    method add_callback (line 52) | def add_callback(self, callback: Union[Callback]) -> None:
    method add_step_callback (line 55) | def add_step_callback(self, step: int, callback: Callable[[int, gym.En...
    method add_periodic_callback (line 62) | def add_periodic_callback(self, period: int, callback: StepCallback, o...
    method step (line 70) | def step(self, action):

FILE: sequoia/common/gym_wrappers/step_callback_wrapper_test.py
  function increment_i (line 10) | def increment_i(step: int, env: gym.Env, step_results: Tuple):
  function decrement_i (line 16) | def decrement_i(step: int, env: gym.Env, step_results: Tuple):
  function test_step_callback (line 22) | def test_step_callback():
  function test_periodic_callback (line 40) | def test_periodic_callback():

FILE: sequoia/common/gym_wrappers/transform_wrappers.py
  class TransformObservation (line 22) | class TransformObservation(TransformObservation_, IterableWrapper):
    method __init__ (line 23) | def __init__(self, env: gym.Env, f: Union[Callable, Compose]):
    method __call__ (line 39) | def __call__(self, *args, **kwargs):
    method __iter__ (line 42) | def __iter__(self):
  class TransformReward (line 52) | class TransformReward(TransformReward_, IterableWrapper):
    method __init__ (line 53) | def __init__(self, env: gym.Env, f: Union[Callable, Compose]):
  class TransformAction (line 80) | class TransformAction(IterableWrapper):
    method __init__ (line 81) | def __init__(self, env: gym.Env, f: Callable[[Union[gym.Env, Space]], ...
    method step (line 93) | def step(self, action):
    method action (line 96) | def action(self, action):

FILE: sequoia/common/gym_wrappers/transform_wrappers_test.py
  function test_compose_on_image_space (line 12) | def test_compose_on_image_space():
  function test_move_wrapper_and_iteration (line 34) | def test_move_wrapper_and_iteration():

FILE: sequoia/common/gym_wrappers/utils.py
  function is_classic_control_env (line 54) | def is_classic_control_env(env: Union[str, gym.Env, Type[gym.Env]]) -> b...
  function is_proxy_to (line 108) | def is_proxy_to(env, env_type_or_types: Union[Type[gym.Env], Tuple[Type[...
  function is_atari_env (line 117) | def is_atari_env(env: Union[str, gym.Env]) -> bool:
  function get_env_class (line 198) | def get_env_class(env: Union[str, gym.Env, Type[gym.Env], Callable[[], g...
  function is_monsterkong_env (line 214) | def is_monsterkong_env(env: Union[str, gym.Env, Callable[[], gym.Env]]) ...
  class StepResult (line 237) | class StepResult(NamedTuple):
  function has_wrapper (line 244) | def has_wrapper(
  class MayCloseEarly (line 265) | class MayCloseEarly(gym.Wrapper, ABC):
    method __init__ (line 272) | def __init__(self, env: gym.Env):
    method is_closed (line 276) | def is_closed(self) -> bool:
    method closed_error_message (line 284) | def closed_error_message(self) -> str:
    method reset (line 293) | def reset(self, **kwargs):
    method step (line 300) | def step(self, action):
    method close (line 307) | def close(self) -> None:
  class IterableWrapper (line 319) | class IterableWrapper(MayCloseEarly, IterableDataset, Generic[EnvType], ...
    method __init__ (line 332) | def __init__(self, env: gym.Env):
    method is_vectorized (line 339) | def is_vectorized(self) -> bool:
    method __next__ (line 343) | def __next__(self):
    method observation (line 366) | def observation(self, observation):
    method action (line 370) | def action(self, action):
    method reward (line 373) | def reward(self, reward):
    method get_length (line 379) | def get_length(self) -> Optional[int]:
    method send (line 409) | def send(self, action):
    method __iter__ (line 446) | def __iter__(self) -> Iterator:
  class RenderEnvWrapper (line 583) | class RenderEnvWrapper(IterableWrapper):
    method __init__ (line 586) | def __init__(self, env: gym.Env, display: Any = None):
    method step (line 590) | def step(self, action):
  function tile_images (line 595) | def tile_images(img_nhwc):

FILE: sequoia/common/gym_wrappers/utils_test.py
  function test_has_wrapper (line 27) | def test_has_wrapper(env, wrapper_type, result):

FILE: sequoia/common/layers.py
  class Lambda (line 16) | class Lambda(nn.Module):
    method __init__ (line 17) | def __init__(self, func: Callable):
    method forward (line 21) | def forward(self, x):
  class Reshape (line 25) | class Reshape(nn.Module):
    method __init__ (line 26) | def __init__(self, target_shape: Union[List[int], Tuple[int, ...]]):
    method forward (line 30) | def forward(self, inputs):
  class ConvBlock (line 34) | class ConvBlock(nn.Module):
    method __init__ (line 35) | def __init__(
    method forward (line 53) | def forward(self, x):
  class DeConvBlock (line 60) | class DeConvBlock(nn.Module):
    method __init__ (line 71) | def __init__(
    method forward (line 106) | def forward(self, x):
  function n_output_features (line 118) | def n_output_features(
  class Conv2d (line 125) | class Conv2d(nn.Conv2d):
    method forward (line 127) | def forward(self, input: Union[Image, Tensor]) -> Union[Tensor, Image]:
    method _ (line 131) | def _(self, input: Image) -> Image:
  class MaxPool2d (line 167) | class MaxPool2d(nn.MaxPool2d):
    method forward (line 169) | def forward(self, input: Union[Image, Tensor]) -> Union[Tensor, Image]:
    method _ (line 173) | def _(self, input: Image) -> Image:
  class Sequential (line 205) | class Sequential(nn.Sequential):
    method forward (line 211) | def forward(self, input):

FILE: sequoia/common/loss.py
  class Loss (line 67) | class Loss(Serializable, MappingABC):
    method __post_init__ (line 102) | def __post_init__(
    method __contains__ (line 137) | def __contains__(self, key: str) -> bool:
    method __getitem__ (line 142) | def __getitem__(self, key: str) -> Any:
    method __iter__ (line 147) | def __iter__(self) -> Iterable[str]:
    method __len__ (line 150) | def __len__(self) -> int:
    method total_loss (line 154) | def total_loss(self) -> Tensor:
    method requires_grad (line 158) | def requires_grad(self) -> bool:
    method backward (line 162) | def backward(self, *args, **kwargs):
    method metric (line 167) | def metric(self) -> Optional[Metrics]:
    method metric (line 176) | def metric(self, value: Metrics) -> None:
    method accuracy (line 188) | def accuracy(self) -> float:
    method mse (line 193) | def mse(self) -> Tensor:
    method __add__ (line 197) | def __add__(self, other: Union["Loss", Any]) -> "Loss":
    method __iadd__ (line 245) | def __iadd__(self, other: Union["Loss", Any]) -> "Loss":
    method __radd__ (line 276) | def __radd__(self, other: Any):
    method __mul__ (line 291) | def __mul__(self, factor: Union[float, Tensor]) -> "Loss":
    method __rmul__ (line 309) | def __rmul__(self, factor: Union[float, Tensor]) -> "Loss":
    method __truediv__ (line 313) | def __truediv__(self, coefficient: Union[float, Tensor]) -> "Loss":
    method unscaled_losses (line 317) | def unscaled_losses(self):
    method to_log_dict (line 324) | def to_log_dict(self, verbose: bool = False) -> Dict[str, Union[str, f...
    method to_pbar_message (line 370) | def to_pbar_message(self) -> Dict[str, float]:
    method clear_tensors (line 389) | def clear_tensors(self) -> None:
    method absorb (line 401) | def absorb(self, other: "Loss") -> None:
    method all_metrics (line 417) | def all_metrics(self) -> Dict[str, Metrics]:

FILE: sequoia/common/loss_test.py
  function test_demo (line 7) | def test_demo():
  function test_all_metrics (line 24) | def test_all_metrics():
  function test_to_log_dict_order (line 35) | def test_to_log_dict_order():

FILE: sequoia/common/metrics/classification.py
  class ClassificationMetrics (line 39) | class ClassificationMetrics(Metrics):
    method __post_init__ (line 59) | def __post_init__(
    method objective_name (line 87) | def objective_name(self) -> str:
    method __add__ (line 90) | def __add__(self, other: "ClassificationMetrics") -> "ClassificationMe...
    method to_log_dict (line 113) | def to_log_dict(self, verbose=False):
    method to_pbar_message (line 127) | def to_pbar_message(self) -> Dict[str, Union[str, float]]:
    method detach (line 132) | def detach(self) -> "ClassificationMetrics":
    method to (line 140) | def to(self, device: Union[str, torch.device]) -> "ClassificationMetri...
    method objective (line 150) | def objective(self) -> float:

FILE: sequoia/common/metrics/classification_test.py
  function test_classification_metrics_add_properly (line 8) | def test_classification_metrics_add_properly():
  function test_metrics_from_tensors (line 55) | def test_metrics_from_tensors():

FILE: sequoia/common/metrics/get_metrics.py
  function to_optional_tensor (line 22) | def to_optional_tensor(x: Optional[Union[Tensor, np.ndarray, List]]) -> ...
  function get_metrics (line 28) | def get_metrics(

FILE: sequoia/common/metrics/metrics.py
  class Metrics (line 19) | class Metrics(Serializable):
    method __post_init__ (line 26) | def __post_init__(self, **tensors):
    method __add__ (line 43) | def __add__(self, other):
    method __radd__ (line 48) | def __radd__(self, other):
    method __mul__ (line 58) | def __mul__(self, factor: Union[float, Tensor]) -> "Loss":
    method __rmul__ (line 63) | def __rmul__(self, factor: Union[float, Tensor]) -> "Loss":
    method __truediv__ (line 68) | def __truediv__(self, coefficient: Union[float, Tensor]) -> "Metrics":
    method to_log_dict (line 73) | def to_log_dict(self, verbose: bool = False) -> Dict:
    method to_pbar_message (line 104) | def to_pbar_message(self) -> Dict[str, Union[str, float]]:
    method numpy (line 107) | def numpy(self):
    method objective (line 120) | def objective(self) -> float:
    method objective_name (line 132) | def objective_name(self) -> str:

FILE: sequoia/common/metrics/metrics_utils.py
  function get_confusion_matrix (line 10) | def get_confusion_matrix(
  function accuracy (line 59) | def accuracy(y_pred: Union[Tensor, np.ndarray], y: Union[Tensor, np.ndar...
  function get_accuracy (line 68) | def get_accuracy(confusion_matrix: Union[Tensor, np.ndarray]) -> float:
  function class_accuracy (line 77) | def class_accuracy(y_pred: Tensor, y: Tensor) -> Tensor:
  function get_class_accuracy (line 83) | def get_class_accuracy(confusion_matrix: Tensor) -> Tensor:

FILE: sequoia/common/metrics/metrics_utils_test.py
  function test_accuracy (line 7) | def test_accuracy():
  function test_per_class_accuracy_perfect (line 25) | def test_per_class_accuracy_perfect():
  function test_per_class_accuracy_zero (line 47) | def test_per_class_accuracy_zero():
  function test_confusion_matrix (line 69) | def test_confusion_matrix():
  function test_per_class_accuracy_realistic (line 95) | def test_per_class_accuracy_realistic():

FILE: sequoia/common/metrics/regression.py
  class RegressionMetrics (line 24) | class RegressionMetrics(Metrics):
    method __post_init__ (line 35) | def __post_init__(
    method objective (line 55) | def objective(self) -> float:
    method __add__ (line 58) | def __add__(self, other: "RegressionMetrics") -> "RegressionMetrics":
    method to_pbar_message (line 81) | def to_pbar_message(self) -> Dict[str, Union[str, float]]:
    method to_log_dict (line 87) | def to_log_dict(self, verbose=False):
    method __mul__ (line 93) | def __mul__(self, factor: Union[float, Tensor]) -> "Loss":
    method __rmul__ (line 101) | def __rmul__(self, factor: Union[float, Tensor]) -> "Loss":
    method __truediv__ (line 106) | def __truediv__(self, coefficient: Union[float, Tensor]) -> "Regressio...
    method __lt__ (line 114) | def __lt__(self, other: Union["RegressionMetrics", Any]) -> bool:
    method __ge__ (line 119) | def __ge__(self, other: Union["RegressionMetrics", Any]) -> bool:

FILE: sequoia/common/metrics/rl_metrics.py
  class EpisodeMetrics (line 8) | class EpisodeMetrics(Metrics):
    method n_episodes (line 21) | def n_episodes(self) -> int:
    method objective_name (line 25) | def objective_name(self) -> str:
    method mean_reward_per_step (line 36) | def mean_reward_per_step(self) -> float:
    method __add__ (line 39) | def __add__(self, other: Union["EpisodeMetrics", Any]):
    method total_reward (line 65) | def total_reward(self) -> float:
    method total_steps (line 69) | def total_steps(self) -> int:
    method to_pbar_message (line 72) | def to_pbar_message(self) -> Dict[str, Union[str, float]]:
    method objective (line 76) | def objective(self) -> float:
    method to_log_dict (line 79) | def to_log_dict(self, verbose: bool = False):
    method episodes (line 96) | def episodes(self) -> int:
    method mean_reward_per_episode (line 100) | def mean_reward_per_episode(self) -> float:
  class GradientUsageMetric (line 135) | class GradientUsageMetric(Metrics):
    method __post_init__ (line 144) | def __post_init__(self):
    method __add__ (line 149) | def __add__(self, other: Union["GradientUsageMetric", Any]) -> "Gradie...
    method to_pbar_message (line 157) | def to_pbar_message(self) -> Dict[str, Union[str, float]]:

FILE: sequoia/common/replay.py
  class ReplayBuffer (line 22) | class ReplayBuffer(deque, Deque[T], Pickleable):
    method __init__ (line 29) | def __init__(self, capacity: int):
    method as_dataset (line 38) | def as_dataset(self) -> TensorDataset:
    method _push_and_sample (line 42) | def _push_and_sample(self, *values: T, size: int) -> List[T]:
    method _sample (line 64) | def _sample(self, size: int) -> List[T]:
    method full (line 71) | def full(self) -> bool:
  class UnlabeledReplayBuffer (line 75) | class UnlabeledReplayBuffer(ReplayBuffer[Tensor]):
    method sample_batch (line 76) | def sample_batch(self, size: int) -> Tensor:
    method push (line 80) | def push(self, x_batch: Tensor, y_batch: Tensor = None) -> None:
    method push_and_sample (line 83) | def push_and_sample(self, x_batch: Tensor, y_batch: Tensor = None, siz...
  class LabeledReplayBuffer (line 88) | class LabeledReplayBuffer(ReplayBuffer[Tuple[Tensor, Tensor]]):
    method sample (line 89) | def sample(self, size: int) -> Tuple[Tensor, Tensor]:
    method push (line 94) | def push(self, x_batch: Tensor, y_batch: Tensor) -> None:
    method push_and_sample (line 97) | def push_and_sample(
    method samples_per_class (line 105) | def samples_per_class(self) -> Dict[int, int]:
  class SemiSupervisedReplayBuffer (line 111) | class SemiSupervisedReplayBuffer(object):
    method __init__ (line 112) | def __init__(self, labeled_capacity: int, unlabeled_capacity: int = 0):
    method sample (line 133) | def sample(self, size: int) -> Tuple[Tensor, Tensor]:
    method sample_unlabeled (line 148) | def sample_unlabeled(self, size: int, take_from_labeled_buffer_first: ...
    method push_and_sample (line 205) | def push_and_sample(self, x: Tensor, y: Tensor, size: int = None) -> T...
    method push_and_sample_unlabeled (line 210) | def push_and_sample_unlabeled(self, x: Tensor, y: Tensor = None, size:...
    method clear (line 216) | def clear(self):
  class ReplayOptions (line 222) | class ReplayOptions(Serializable):
    method enabled (line 237) | def enabled(self) -> bool:

FILE: sequoia/common/spaces/image.py
  function could_become_image (line 14) | def could_become_image(space: spaces.Space) -> bool:
  class Image (line 23) | class Image(spaces.Box, Space[T]):
    method __init__ (line 30) | def __init__(
    method channels (line 80) | def channels(self) -> int:
    method height (line 84) | def height(self) -> int:
    method width (line 88) | def width(self) -> int:
    method batch_size (line 92) | def batch_size(self) -> Optional[int]:
    method from_box (line 96) | def from_box(cls, box_space: spaces.Box):
    method wrap (line 100) | def wrap(cls, space: Union["Image", spaces.Box]):
    method channels_last (line 108) | def channels_last(self) -> bool:
    method __repr__ (line 111) | def __repr__(self):
    method sample (line 114) | def sample(self) -> T:
  class ImageTensorSpace (line 118) | class ImageTensorSpace(Image, TensorBox):
    method from_box (line 120) | def from_box(cls, box_space: TensorBox, device: torch.device = None):
    method __repr__ (line 124) | def __repr__(self):
    method sample (line 127) | def sample(self):
  function _batch_image_space (line 143) | def _batch_image_space(space: Image, n: int = 1) -> Union[Image, spaces....

FILE: sequoia/common/spaces/named_tuple.py
  class NamedTupleSpace (line 14) | class NamedTupleSpace(spaces.Tuple):
    method __init__ (line 28) | def __init__(
    method __getitem__ (line 61) | def __getitem__(self, index: Union[int, str]) -> Space:
    method __getattr__ (line 66) | def __getattr__(self, attr: str) -> Space:
    method __repr__ (line 73) | def __repr__(self):
    method _replace (line 83) | def _replace(self, **kwargs):
    method __eq__ (line 92) | def __eq__(self, other: Union["NamedTupleSpace", Any]) -> bool:
    method sample (line 95) | def sample(self):
    method contains (line 98) | def contains(self, x) -> bool:
    method keys (line 107) | def keys(self) -> List[str]:
    method values (line 110) | def values(self) -> List[Space]:
    method items (line 113) | def items(self) -> Iterable[Tuple[str, Space]]:
  function __eq__ (line 118) | def __eq__(self, other: Union["NamedTupleSpace", Any]) -> bool:
  function batch_namedtuple_space (line 133) | def batch_namedtuple_space(space: NamedTupleSpace, n: int = 1):
  function flatten_namedtuple_space_sample (line 140) | def flatten_namedtuple_space_sample(space: NamedTupleSpace, x: NamedTuple):

FILE: sequoia/common/spaces/named_tuple_test.py
  function test_basic (line 14) | def test_basic():
  class StateTransition (line 37) | class StateTransition(NamedTuple):
  function test_basic_with_dtype (line 43) | def test_basic_with_dtype():
  function test_isinstance_namedtuple (line 66) | def test_isinstance_namedtuple():
  function test_equals_tuple_space_with_same_items (line 77) | def test_equals_tuple_space_with_same_items():
  function test_batch_objets_considered_valid_samples (line 98) | def test_batch_objets_considered_valid_samples():
  function test_batch_space (line 127) | def test_batch_space():

FILE: sequoia/common/spaces/space.py
  class Space (line 9) | class Space(_Space, Generic[T]):
    method sample (line 10) | def sample(self) -> T:
    method __contains__ (line 13) | def __contains__(self, x: Union[T, Any]) -> bool:
    method contains (line 16) | def contains(self, v: Union[T, Any]) -> bool:

FILE: sequoia/common/spaces/sparse.py
  class Sparse (line 28) | class Sparse(Space[Optional[T]]):
    method __init__ (line 37) | def __init__(self, base: Space[T], sparsity: float = 0.0):
    method sparsity (line 47) | def sparsity(self) -> float:
    method seed (line 53) | def seed(self, seed=None):
    method sample (line 57) | def sample(self) -> Optional[T]:
    method contains (line 68) | def contains(self, x: Union[Optional[T], Any]) -> bool:
    method __repr__ (line 75) | def __repr__(self):
    method __eq__ (line 78) | def __eq__(self, other: Any):
    method to_jsonable (line 83) | def to_jsonable(self, sample_n):
    method from_jsonable (line 92) | def from_jsonable(self, sample_n):
  function _is_singledispatch (line 110) | def _is_singledispatch(module_function):
  function register_sparse_variant (line 114) | def register_sparse_variant(module, module_fn_name: str):
  function flatdim_sparse (line 137) | def flatdim_sparse(space: Sparse) -> int:
  function flatten_sparse (line 142) | def flatten_sparse(space: Sparse[T], x: Optional[T]) -> Optional[np.ndar...
  function flatten_sparse_space (line 147) | def flatten_sparse_space(space: Sparse[T]) -> Optional[np.ndarray]:
  function unflatten_sparse (line 154) | def unflatten_sparse(space: Sparse[T], x: np.ndarray) -> Optional[T]:
  function create_empty_array_sparse (line 162) | def create_empty_array_sparse(space: Sparse, n=1, fn=np.zeros) -> np.nda...
  function create_shared_memory_for_sparse_space (line 167) | def create_shared_memory_for_sparse_space(space: Sparse, n: int = 1, ctx...
  function write_to_shared_memory (line 182) | def write_to_shared_memory(
  function read_from_shared_memory (line 213) | def read_from_shared_memory(
  function batch_sparse_space (line 237) | def batch_sparse_space(space: Sparse, n: int = 1) -> gym.Space:
  function concatenate_sparse_items (line 299) | def concatenate_sparse_items(
  function sparse_sample_to_tensor (line 330) | def sparse_sample_to_tensor(

FILE: sequoia/common/spaces/sparse_test.py
  function equals (line 33) | def equals(value, expected) -> bool:
  function is_sparse (line 53) | def is_sparse(iterable: Iterable[bool]) -> bool:
  function test_sample (line 73) | def test_sample(base_space: gym.Space):
  function test_contains (line 91) | def test_contains(base_space: gym.Space, sparsity: float):
  function test_batching_works (line 101) | def test_batching_works(base_space: gym.Space, n: int = 3):
  function test_batching_works (line 116) | def test_batching_works(base_space: gym.Space, sparsity: float, n: int =...
  function test_change_doesnt_persist_after_import (line 157) | def test_change_doesnt_persist_after_import():
  function test_change_persists_after_full_import (line 165) | def test_change_persists_after_full_import():
  function test_flatdim (line 174) | def test_flatdim(base_space: gym.Space):
  function test_flatdim (line 184) | def test_flatdim(base_space: gym.Space):
  function test_seeding_works (line 198) | def test_seeding_works(base_space: gym.Space):
  function test_flatten (line 211) | def test_flatten(base_space: gym.Space):
  function test_equality (line 225) | def test_equality(base_space: gym.Space):

FILE: sequoia/common/spaces/tensor_spaces.py
  function get_numpy_dtype_equivalent_to (line 29) | def get_numpy_dtype_equivalent_to(torch_dtype: torch.dtype) -> np.dtype:
  function get_torch_dtype_equivalent_to (line 43) | def get_torch_dtype_equivalent_to(numpy_dtype: np.dtype) -> torch.dtype:
  function is_numpy_dtype (line 61) | def is_numpy_dtype(dtype: Any) -> bool:
  function is_torch_dtype (line 65) | def is_torch_dtype(dtype: Any) -> bool:
  function supports_tensors (line 72) | def supports_tensors(space: gym.Space) -> bool:
  class TensorSpace (line 77) | class TensorSpace(gym.Space, ABC):
    method __init__ (line 82) | def __init__(self, *args, device: torch.device = None, **kwargs):
  class TensorBox (line 113) | class TensorBox(TensorSpace, spaces.Box):
    method __init__ (line 116) | def __init__(self, low, high, shape=None, dtype=np.float32, device: to...
    method sample (line 122) | def sample(self):
    method contains (line 128) | def contains(self, x: Union[list, np.ndarray, Tensor]) -> bool:
    method __repr__ (line 144) | def __repr__(self):
    method from_box (line 153) | def from_box(cls, box: spaces.Box, device: torch.device = None):
  class TensorDiscrete (line 163) | class TensorDiscrete(TensorSpace, spaces.Discrete):
    method contains (line 164) | def contains(self, v: Union[int, Tensor]) -> bool:
    method sample (line 169) | def sample(self):
  class TensorMultiDiscrete (line 176) | class TensorMultiDiscrete(TensorSpace, spaces.MultiDiscrete):
    method contains (line 177) | def contains(self, v: Tensor) -> bool:
    method sample (line 184) | def sample(self):
  function _batch_discrete_space (line 195) | def _batch_discrete_space(space: TensorDiscrete, n: int = 1) -> TensorMu...

FILE: sequoia/common/spaces/tensor_spaces_test.py
  function test_tensor_box (line 10) | def test_tensor_box(np_dtype: np.dtype):

FILE: sequoia/common/spaces/typed_dict.py
  class TypedDictSpace (line 44) | class TypedDictSpace(spaces.Dict, Space[M]):
    method __init__ (line 127) | def __init__(self, spaces: Mapping[str, Space] = None, dtype: Type[M] ...
    method keys (line 237) | def keys(self) -> Sequence[str]:
    method items (line 240) | def items(self) -> Iterable[Tuple[str, Space]]:
    method values (line 243) | def values(self) -> Sequence[Space]:
    method sample (line 246) | def sample(self) -> M:
    method __getattr__ (line 251) | def __getattr__(self, attr: str) -> Space:
    method __getitem__ (line 257) | def __getitem__(self, key: Union[str, int]) -> Space:
    method __len__ (line 266) | def __len__(self) -> int:
    method contains (line 272) | def contains(self, x: Union[M, Mapping[str, Space]]) -> bool:
    method __repr__ (line 295) | def __repr__(self) -> str:
    method __eq__ (line 303) | def __eq__(self, other):
  function _batch_typed_dict_space (line 310) | def _batch_typed_dict_space(space: TypedDictSpace, n: int = 1) -> spaces...
  function _concatenate_typed_dicts (line 318) | def _concatenate_typed_dicts(
  function _ (line 337) | def _(space: TypedDictSpace, sample: Union[T, Mapping]) -> T:
  function _ (line 347) | def _(

FILE: sequoia/common/spaces/typed_dict_test.py
  function test_basic (line 15) | def test_basic():
  function test_supports_dataclasses (line 36) | def test_supports_dataclasses():
  class StateTransition (line 64) | class StateTransition(Mapping[str, T]):
    method __post_init__ (line 69) | def __post_init__(self):
    method __len__ (line 72) | def __len__(self) -> int:
    method __getitem__ (line 75) | def __getitem__(self, attr: str) -> T:
    method __iter__ (line 80) | def __iter__(self) -> Iterable[str]:
  function test_basic_with_dtype (line 84) | def test_basic_with_dtype():
  function test_isinstance (line 106) | def test_isinstance():
  function test_equals_dict_space_with_same_items (line 117) | def test_equals_dict_space_with_same_items():
  function test_batch_objets_considered_valid_samples (line 136) | def test_batch_objets_considered_valid_samples():
  function test_batch_space (line 165) | def test_batch_space():
  function test_batch_space_preserves_dtype (line 180) | def test_batch_space_preserves_dtype():
  class DummyDictEnv (line 241) | class DummyDictEnv(gym.Env):
    method __init__ (line 242) | def __init__(self):
    method reset (line 252) | def reset(self):
    method step (line 255) | def step(self, action):
    method seed (line 258) | def seed(self, seed=None):
  function test_vector_env (line 266) | def test_vector_env():
  function test_object_with_extra_keys_fits (line 282) | def test_object_with_extra_keys_fits():
  function test_order_of_keys_is_same_in_samples (line 303) | def test_order_of_keys_is_same_in_samples():
  function test_debugging (line 316) | def test_debugging():
  function test_equality (line 327) | def test_equality():

FILE: sequoia/common/task.py
  class Task (line 15) | class Task(Serializable):

FILE: sequoia/common/transforms/channels.py
  function has_channels_last (line 22) | def has_channels_last(img_or_shape: Union[Img, Tuple[int, ...], spaces.B...
  function has_channels_first (line 30) | def has_channels_first(img_or_shape: Union[Img, Tuple[int, ...], spaces....
  function channels_last_if_needed (line 43) | def channels_last_if_needed(x: Any) -> Any:
  function channels_first_if_needed (line 51) | def channels_first_if_needed(x: Any) -> Any:
  class NamedDimensions (line 59) | class NamedDimensions(Transform[Tensor, Tensor]):
    method __init__ (line 64) | def __init__(self, names: Iterable[str]):
    method __call__ (line 67) | def __call__(self, tensor: Tensor) -> Tensor:
  function three_channels (line 72) | def three_channels(x: Any) -> Any:
  function _ (line 87) | def _(x: Tensor) -> Tensor:
  function _ (line 116) | def _(x: np.ndarray) -> np.ndarray:
  function _ (line 139) | def _(x: spaces.Box) -> spaces.Box:
  function _ (line 145) | def _(x: Tuple[int, ...]) -> Tuple[int, ...]:
  function _three_channels (line 163) | def _three_channels(x: Any) -> Any:
  function _three_channels (line 172) | def _three_channels(x: Any) -> Any:
  function _three_channels (line 179) | def _three_channels(x: TypedDictSpace) -> TypedDictSpace:
  class ThreeChannels (line 187) | class ThreeChannels(Transform[Tensor, Tensor]):
    method __call__ (line 199) | def __call__(self, x: Tensor) -> Tensor:
  function channels_first (line 204) | def channels_first(x: Any) -> Any:
  function _ (line 215) | def _(x: Tensor) -> Tensor:
  function _ (line 228) | def _(x: Tuple[int, ...]) -> Tuple[int, ...]:
  function _ (line 238) | def _(x: spaces.Box) -> spaces.Box:
  function _ (line 248) | def _(x: Tuple[int, ...]) -> Tuple[int, ...]:
  function _ (line 257) | def _(x: spaces.Box) -> spaces.Box:
  class ChannelsFirst (line 266) | class ChannelsFirst(Transform[Union[np.ndarray, Tensor], Tensor]):
    method __call__ (line 274) | def __call__(self, x: Tensor) -> Tensor:
    method apply (line 278) | def apply(cls, x: Tensor) -> Tensor:
  class ChannelsFirstIfNeeded (line 308) | class ChannelsFirstIfNeeded(ChannelsFirst):
    method apply (line 312) | def apply(cls, x: Tensor) -> Tensor:
  function channels_last (line 325) | def channels_last(x: Any) -> Any:
  function _ (line 330) | def _(x: Tensor) -> Tensor:
  function _ (line 342) | def _(x: Tuple[int, ...]) -> Tuple[int, ...]:
  function _ (line 352) | def _(x: np.ndarray) -> np.ndarray:
  function _ (line 361) | def _(x: spaces.Box) -> spaces.Box:
  class ChannelsLast (line 370) | class ChannelsLast(Transform[Tensor, Tensor]):
    method __call__ (line 371) | def __call__(self, x: Tensor) -> Tensor:
    method apply (line 375) | def apply(cls, x: Tensor) -> Tensor:
  class ChannelsLastIfNeeded (line 380) | class ChannelsLastIfNeeded(ChannelsLast):
    method apply (line 384) | def apply(cls, x: Tensor) -> Tensor:

FILE: sequoia/common/transforms/compose.py
  class Compose (line 15) | class Compose(List[T], ComposeBase, Transform[InputType, OutputType]):
    method __init__ (line 30) | def __init__(self, *args, **kwargs):
    method __call__ (line 34) | def __call__(self, img):

FILE: sequoia/common/transforms/resize.py
  function resize (line 28) | def resize(x: Img, size: Tuple[int, ...], **kwargs) -> Img:
  function _ (line 34) | def _(x: Image.Image, size: Tuple[int, ...], **kwargs) -> Image.Image:
  function _resize_array_or_tensor (line 40) | def _resize_array_or_tensor(x: np.ndarray, size: Tuple[int, ...], **kwar...
  function _resize_namedtuple_space (line 69) | def _resize_namedtuple_space(
  function _resize_namedtuple (line 84) | def _resize_namedtuple(x: Dict, size: Tuple[int, ...], **kwargs) -> Dict:
  function _resize_typed_dict (line 97) | def _resize_typed_dict(x: TypedDictSpace, size: Tuple[int, ...], **kwarg...
  function _resize_image_shape (line 111) | def _resize_image_shape(x: Tuple[int, ...], size: Tuple[int, ...], **kwa...
  function _resize_space (line 136) | def _resize_space(x: spaces.Box, size: Tuple[int, ...], **kwargs) -> spa...
  class Resize (line 149) | class Resize(Resize_, Transform[Img, Img]):
    method __init__ (line 150) | def __init__(self, size: Tuple[int, ...], interpolation=InterpolationM...
    method __call__ (line 155) | def __call__(self, img):
    method forward (line 164) | def forward(self, img: Img) -> Img:

FILE: sequoia/common/transforms/split_batch.py
  class SplitBatch (line 15) | class SplitBatch(Transform[Any, Tuple[ObservationType, RewardType]]):
    method __init__ (line 51) | def __init__(self, observation_type: Type[ObservationType], reward_typ...
    method __call__ (line 56) | def __call__(self, batch: Any) -> Tuple[ObservationType, RewardType]:
  function split_batch (line 60) | def split_batch(
  function n_fields (line 207) | def n_fields(batch_type: Type[Batch]) -> int:
  function n_required_fields (line 224) | def n_required_fields(batch_type: Type) -> int:

FILE: sequoia/common/transforms/to_tensor.py
  function copy_if_negative_strides (line 31) | def copy_if_negative_strides(image: Img) -> Img:
  function image_to_tensor (line 52) | def image_to_tensor(image: Union[Img, Sequence[Img], gym.Space]) -> Unio...
  function _ (line 80) | def _(image: Union[Image, np.ndarray]) -> Tensor:
  function _list_of_images_to_tensor (line 110) | def _list_of_images_to_tensor(image: Sequence[Img]) -> Tensor:
  function _to_tensor_effect_on_image_shape (line 115) | def _to_tensor_effect_on_image_shape(image: Tuple[int, ...]) -> Tuple[in...
  function _ (line 125) | def _(image: spaces.Box) -> spaces.Box:
  function _ (line 144) | def _(space: Dict, device: torch.device = None) -> Dict:
  function _space_with_images_to_tensor (line 158) | def _space_with_images_to_tensor(space: Dict, device: torch.device = Non...
  function _space_with_images_to_tensor (line 170) | def _space_with_images_to_tensor(
  class ToTensor (line 211) | class ToTensor(ToTensor_, Transform):
    method __call__ (line 212) | def __call__(self, image):

FILE: sequoia/common/transforms/transform.py
  class Transform (line 17) | class Transform(Generic[InputType, OutputType]):
    method __call__ (line 21) | def __call__(self, input: InputType) -> OutputType:
    method __call__ (line 25) | def __call__(self, input: Shape) -> Shape:
    method __call__ (line 29) | def __call__(self, input: Space) -> Space:
    method __call__ (line 33) | def __call__(self, input: Union[InputType, Space, Shape]) -> Union[Out...

FILE: sequoia/common/transforms/transform_enum.py
  class Transforms (line 38) | class Transforms(Enum):
    method __call__ (line 60) | def __call__(self, x):
    method _missing_ (line 64) | def _missing_(cls, value: Any):
    method shape_change (line 74) | def shape_change(self, input_shape: Union[Tuple[int, ...], torch.Size]...
    method space_change (line 79) | def space_change(self, input_space: gym.Space) -> gym.Space:
  class Compose (line 88) | class Compose(List[T], ComposeBase):
    method __init__ (line 112) | def __init__(self, *args, **kwargs):
  function encode_transforms (line 133) | def encode_transforms(v: Transforms) -> str:
  function decode_transforms (line 138) | def decode_transforms(v: str) -> Transforms:

FILE: sequoia/common/transforms/transforms_test.py
  function test_transform (line 68) | def test_transform(transform: Transforms, input_shape, output_shape):
  function test_to_tensor (line 103) | def test_to_tensor(transform: Transforms, input_shape, output_shape):
  function test_applying_transforms_on_weird_input_raises_error (line 123) | def test_applying_transforms_on_weird_input_raises_error(
  function test_compose_applied_on_shape (line 137) | def test_compose_applied_on_shape():
  function test_channels_first_transform_on_gym_env (line 152) | def test_channels_first_transform_on_gym_env():
  function test_preserves_device_when_possible (line 171) | def test_preserves_device_when_possible():

FILE: sequoia/common/transforms/utils.py
  function is_image (line 11) | def is_image(v: Any) -> bool:

FILE: sequoia/conftest.py
  function xfail_param (line 42) | def xfail_param(*args, reason: str):
  function skip_param (line 46) | def skip_param(*args, reason: str):
  function skipif_param (line 50) | def skipif_param(condition, *args, reason: str):
  function add_np (line 55) | def add_np(doctest_namespace):
  function trainer_config (line 60) | def trainer_config(tmp_path_factory):
  function config (line 72) | def config(tmp_path: Path):
  function session_config (line 80) | def session_config(tmp_path_factory: Path):
  function id_fn (line 86) | def id_fn(params: Any) -> str:
  function get_all_dataset_names (line 103) | def get_all_dataset_names(method_class: Type[Method] = None) -> List[str]:
  function get_dataset_params (line 114) | def get_dataset_params(
  function pytest_addoption (line 134) | def pytest_addoption(parser):
  function slow_param (line 145) | def slow_param(*args):
  function find_class_under_test (line 150) | def find_class_under_test(
  function parametrize_test_datasets (line 174) | def parametrize_test_datasets(metafunc):
  function pytest_generate_tests (line 232) | def pytest_generate_tests(metafunc):
  class DummyEnvironment (line 240) | class DummyEnvironment(gym.Env):
    method __init__ (line 250) | def __init__(self, start: int = 0, target: int = 5, max_value: int = N...
    method step (line 266) | def step(self, action: int):
    method reset (line 281) | def reset(self):
    method seed (line 286) | def seed(self, seed: Optional[int]) -> List[int]:
  function param_requires_monsterkong (line 298) | def param_requires_monsterkong(*args):
  function param_requires_atari_py (line 311) | def param_requires_atari_py(*args):
  function param_requires_mtenv (line 322) | def param_requires_mtenv(*args):
  function param_requires_metaworld (line 336) | def param_requires_metaworld(*args):
  function param_requires_mujoco (line 349) | def param_requires_mujoco(*args):
  function param_requires_pyglet (line 370) | def param_requires_pyglet(*args):

FILE: sequoia/experiments/experiment.py
  function get_method_names (line 26) | def get_method_names() -> Dict[str, Type[Method]]:
  class Experiment (line 32) | class Experiment(Parseable, Serializable):
    method __post_init__ (line 65) | def __post_init__(self):
    method run_experiment (line 170) | def run_experiment(
    method launch (line 221) | def launch(
    method main (line 270) | def main(
  function launch_batch_of_runs (line 326) | def launch_batch_of_runs(
  function parse_setting_and_method_instances (line 407) | def parse_setting_and_method_instances(
  function get_class_with_name (line 440) | def get_class_with_name(
  function check_has_descendants (line 464) | def check_has_descendants(potential_classes: List[Type[Method]]) -> List...
  function main (line 481) | def main():

FILE: sequoia/experiments/experiment_test.py
  function test_no_collisions_in_method_names (line 24) | def test_no_collisions_in_method_names():
  function test_no_collisions_in_setting_names (line 29) | def test_no_collisions_in_setting_names():
  function test_applicable_methods (line 33) | def test_applicable_methods():
  function mock_apply (line 40) | def mock_apply(self: Setting, method: Method, config: Config) -> Results:
  function set_argv_for_debug (line 51) | def set_argv_for_debug(monkeypatch):
  function method_type (line 56) | def method_type(request, monkeypatch, set_argv_for_debug):
  function setting_type (line 62) | def setting_type(request, monkeypatch, set_argv_for_debug):
  function test_experiment_from_args (line 70) | def test_experiment_from_args(
  function test_launch_experiment_with_constructor (line 90) | def test_launch_experiment_with_constructor(
  function test_none_setting (line 105) | def test_none_setting(method_type: Optional[Type[Method]], tmp_path: Pat...
  function test_none_method (line 124) | def test_none_method(setting_type: Optional[Type[Setting]]):

FILE: sequoia/experiments/hpo_sweep.py
  class HPOSweep (line 16) | class HPOSweep(Experiment):
    method __post_init__ (line 44) | def __post_init__(self):
    method launch (line 51) | def launch(self, argv: Union[str, List[str]] = None, strict_args: bool...
    method main (line 99) | def main(
  function main (line 140) | def main():

FILE: sequoia/experiments/hpo_sweep_test.py
  class MockResults (line 19) | class MockResults(Results):
    method __init__ (line 20) | def __init__(self, hparams):
    method objective (line 25) | def objective(self) -> float:
    method make_plots (line 28) | def make_plots(self):
    method to_log_dict (line 31) | def to_log_dict(self, verbose: bool = False):
    method summary (line 39) | def summary(self):
  function mock_apply (line 43) | def mock_apply(self: Setting, method: Method, config: Config = None) -> ...
  function set_argv_for_debug (line 56) | def set_argv_for_debug(monkeypatch):
  function method_type (line 61) | def method_type(request, monkeypatch, set_argv_for_debug):
  function setting_type (line 67) | def setting_type(request, monkeypatch, set_argv_for_debug):
  function test_launch_sweep_with_constructor (line 78) | def test_launch_sweep_with_constructor(

FILE: sequoia/main.py
  function main (line 26) | def main():
  function add_run_command (line 79) | def add_run_command(command_subparsers: _SubParsersAction) -> None:
  function run (line 92) | def run(setting: Setting, method: Method, config: Config) -> Results:
  class SweepConfig (line 111) | class SweepConfig(Config):
  function sweep (line 135) | def sweep(setting: Setting, method: Method, config: SweepConfig) -> Sett...
  function add_sweep_command (line 168) | def add_sweep_command(command_subparsers: _SubParsersAction) -> None:
  function add_info_command (line 180) | def add_info_command(command_subparsers: _SubParsersAction) -> None:
  function info (line 221) | def info(component: Union[Type[Setting], Type[Method]] = None) -> None:
  function get_help (line 247) | def get_help(component: Type[Setting]) -> str:
  function add_args_for_settings_and_methods (line 276) | def add_args_for_settings_and_methods(command_subparser: ArgumentParser):

FILE: sequoia/methods/__init__.py
  function register_method (line 59) | def register_method(
  function get_external_methods (line 108) | def get_external_methods() -> Dict[str, Type[Method]]:
  function add_external_methods (line 204) | def add_external_methods(all_methods: List[Type[Method]]) -> List[Type[M...
  function get_all_methods (line 212) | def get_all_methods() -> List[Type[Method]]:

FILE: sequoia/methods/aux_tasks/auxiliary_task.py
  class AuxiliaryTask (line 17) | class AuxiliaryTask(nn.Module):
    class Options (line 40) | class Options(HyperParameters):
    method __init__ (line 46) | def __init__(self, *args, options: Options = None, name: str = None, *...
    method encode (line 77) | def encode(self, x: Tensor) -> Tensor:
    method logits (line 81) | def logits(self, h_x: Tensor) -> Tensor:
    method get_loss (line 85) | def get_loss(self, forward_pass: Dict[str, Tensor], y: Tensor = None) ...
    method coefficient (line 118) | def coefficient(self) -> float:
    method coefficient (line 122) | def coefficient(self, value: float) -> None:
    method enable (line 129) | def enable(self) -> None:
    method disable (line 138) | def disable(self) -> None:
    method enabled (line 145) | def enabled(self) -> bool:
    method disabled (line 149) | def disabled(self) -> bool:
    method on_task_switch (line 152) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method model (line 156) | def model(self) -> LightningModule:
    method set_model (line 160) | def set_model(model: "Model") -> None:
    method shared_modules (line 163) | def shared_modules(self) -> Dict[str, nn.Module]:

FILE: sequoia/methods/aux_tasks/ewc.py
  class EWCTask (line 35) | class EWCTask(AuxiliaryTask):
    class Options (line 58) | class Options(AuxiliaryTask.Options):
    method __init__ (line 76) | def __init__(self, *args, name: str = None, options: "EWCTask.Options"...
    method get_loss (line 108) | def get_loss(self, forward_pass: ForwardPass, y: Tensor = None) -> Loss:
    method on_task_switch (line 127) | def on_task_switch(self, task_id: Optional[int]):
    method update_anchor_weights (line 172) | def update_anchor_weights(self, new_task_id: int) -> None:
    method _ignoring_task_boundaries (line 257) | def _ignoring_task_boundaries(self):
    method consolidate (line 263) | def consolidate(self, new_fims: List[PMatAbstract], task: Optional[int...
    method get_current_model_weights (line 306) | def get_current_model_weights(self) -> PVector:

FILE: sequoia/methods/aux_tasks/reconstruction/ae.py
  class AEReconstructionTask (line 14) | class AEReconstructionTask(AuxiliaryTask):
    method __init__ (line 24) | def __init__(self, coefficient: float = None, options: AuxiliaryTask.O...
    method create_decoder (line 34) | def create_decoder(self, input_shape: Union[torch.Size, Tuple[int, ......
    method get_loss (line 49) | def get_loss(self, forward_pass: Dict[str, Tensor], y: Tensor = None) ...
    method forward (line 65) | def forward(self, h_x: Tensor) -> Tensor:  # type: ignore
    method reconstruct (line 70) | def reconstruct(self, x: Tensor) -> Tensor:
    method reconstruction_loss (line 75) | def reconstruction_loss(self, recon_x: Tensor, x: Tensor) -> Tensor:

FILE: sequoia/methods/aux_tasks/reconstruction/decoder_for_dataset.py
  function get_decoder_class_for_dataset (line 16) | def get_decoder_class_for_dataset(input_shape: Union[Tuple[int, int, int...

FILE: sequoia/methods/aux_tasks/reconstruction/decoders.py
  class Decoder (line 9) | class Decoder(nn.Sequential, ABC):
  class MnistDecoder (line 16) | class MnistDecoder(Decoder):
    method __init__ (line 19) | def __init__(self, code_size: int, out_channels: int = 3):
  class CifarDecoder (line 38) | class CifarDecoder(Decoder):
    method __init__ (line 41) | def __init__(self, code_size: int):
  class ImageNetDecoder (line 55) | class ImageNetDecoder(Decoder):
    method __init__ (line 58) | def __init__(self, code_size: int):

FILE: sequoia/methods/aux_tasks/reconstruction/vae.py
  class VAEReconstructionTask (line 14) | class VAEReconstructionTask(AEReconstructionTask):
    class Options (line 25) | class Options(AEReconstructionTask.Options):
    method __init__ (line 31) | def __init__(self, coefficient: float = None, options: "VAEReconstruct...
    method forward (line 43) | def forward(self, h_x: Tensor) -> Tensor:  # type: ignore
    method reparameterize (line 50) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method get_loss (line 56) | def get_loss(self, forward_pass: Dict[str, Tensor], y: Tensor = None) ...
    method generate (line 71) | def generate(self, z: Tensor) -> Tensor:
    method kl_divergence_loss (line 76) | def kl_divergence_loss(mu: Tensor, logvar: Tensor) -> Tensor:

FILE: sequoia/methods/aux_tasks/transformation_based/bases.py
  function wrap_pil_transform (line 19) | def wrap_pil_transform(function: Callable):
  class TransformationBasedTask (line 32) | class TransformationBasedTask(AuxiliaryTask):
    class Options (line 46) | class Options(AuxiliaryTask.Options):
    method __init__ (line 53) | def __init__(
    method get_loss (line 101) | def get_loss(self, x: Tensor, h_x: Tensor, y_pred: Tensor = None, y: T...
    method get_loss_for_arg (line 124) | def get_loss_for_arg(self, x: Tensor, h_x: Tensor, fn_arg: Any, alpha:...
  class ClassifyTransformationTask (line 156) | class ClassifyTransformationTask(TransformationBasedTask):
    method __init__ (line 166) | def __init__(
    method get_loss (line 182) | def get_loss(self, x: Tensor, h_x: Tensor, y_pred: Tensor = None, y: T...
  class RegressTransformationTask (line 188) | class RegressTransformationTask(TransformationBasedTask):
    method __init__ (line 200) | def __init__(
    method get_function_args (line 240) | def get_function_args(self) -> Tensor:
    method get_loss (line 247) | def get_loss(self, x: Tensor, h_x: Tensor, y_pred: Tensor = None, y: T...
  class ScaleToRange (line 256) | class ScaleToRange(nn.Module):
    method __init__ (line 257) | def __init__(self, arg_min: float, arg_amp: float):
    method forward (line 262) | def forward(self, x: Tensor) -> Tensor:

FILE: sequoia/methods/aux_tasks/transformation_based/rotation.py
  function rotate (line 8) | def rotate(x: Tensor, angle: int) -> Tensor:
  class RotationTask (line 55) | class RotationTask(ClassifyTransformationTask):
    class Options (line 57) | class Options(ClassifyTransformationTask.Options):
    method __init__ (line 65) | def __init__(self, name="rotation", options: "RotationTask.Options" = ...

FILE: sequoia/methods/avalanche_methods/agem.py
  class AGEMMethod (line 22) | class AGEMMethod(AvalancheMethod[AGEM]):

FILE: sequoia/methods/avalanche_methods/agem_test.py
  class TestAGEMMethod (line 12) | class TestAGEMMethod(_TestAvalancheMethod):

FILE: sequoia/methods/avalanche_methods/ar1.py
  class AR1Method (line 19) | class AR1Method(AvalancheMethod[AR1]):

FILE: sequoia/methods/avalanche_methods/ar1_test.py
  class TestAR1Method (line 22) | class TestAR1Method(_TestAvalancheMethod):
    method test_short_task_incremental_setting (line 44) | def test_short_task_incremental_setting(

FILE: sequoia/methods/avalanche_methods/base.py
  class WandBLogger (line 54) | class WandBLogger(_WandBLogger):
    method import_wandb (line 64) | def import_wandb(self):
    method args_parse (line 71) | def args_parse(self):
    method before_run (line 76) | def before_run(self):
  class AvalancheMethod (line 88) | class AvalancheMethod(
    method __post_init__ (line 163) | def __post_init__(self):
    method configure (line 171) | def configure(self, setting: ClassIncrementalSetting) -> None:
    method create_cl_strategy (line 236) | def create_cl_strategy(self, setting: ClassIncrementalSetting) -> Stra...
    method create_model (line 247) | def create_model(self, setting: ClassIncrementalSetting) -> Module:
    method make_optimizer (line 305) | def make_optimizer(self) -> Optimizer:
    method fit (line 316) | def fit(self, train_env: PassiveEnvironment, valid_env: PassiveEnviron...
    method get_actions (line 321) | def get_actions(
    method set_testing (line 342) | def set_testing(self):
    method on_task_switch (line 346) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method get_search_space (line 367) | def get_search_space(self, setting: ClassIncrementalSetting):
    method adapt_to_new_hparams (line 370) | def adapt_to_new_hparams(self, new_hparams: Dict):
    method environment_to_experience (line 376) | def environment_to_experience(self, env: PassiveEnvironment, setting: ...
  function test_epoch (line 450) | def test_epoch(strategy, test_env: ContinualSLTestEnvironment, **kwargs):
  function test_epoch_gym_env (line 469) | def test_epoch_gym_env(strategy: BaseStrategy, test_env: ContinualSLTest...

FILE: sequoia/methods/avalanche_methods/base_test.py
  class _TestAvalancheMethod (line 22) | class _TestAvalancheMethod(MethodTests):
    method method (line 39) | def method(cls, config: Config, request) -> AvalancheMethod:
    method test_hparams_have_same_defaults_as_in_avalanche (line 44) | def test_hparams_have_same_defaults_as_in_avalanche(self):
    method validate_results (line 71) | def validate_results(
    method test_short_sl_track (line 84) | def test_short_sl_track(
  function test_warning_if_environment_to_experience_isnt_overwritten (line 101) | def test_warning_if_environment_to_experience_isnt_overwritten(short_sl_...
  class MyDummyMethod (line 109) | class MyDummyMethod(AvalancheMethod):
    method environment_to_experience (line 110) | def environment_to_experience(self, env, setting):
  function test_no_warning_if_environment_to_experience_is_overwritten (line 170) | def test_no_warning_if_environment_to_experience_is_overwritten(short_sl...

FILE: sequoia/methods/avalanche_methods/conftest.py
  function config (line 22) | def config(tmp_path_factory):
  function fast_scenario (line 28) | def fast_scenario(use_task_labels=False, shuffle=True):

FILE: sequoia/methods/avalanche_methods/cwr_star.py
  class CWRStarMethod (line 19) | class CWRStarMethod(AvalancheMethod[CWRStar]):

FILE: sequoia/methods/avalanche_methods/cwr_star_test.py
  class TestCWRStarMethod (line 12) | class TestCWRStarMethod(_TestAvalancheMethod):

FILE: sequoia/methods/avalanche_methods/ewc.py
  class EWCMethod (line 24) | class EWCMethod(AvalancheMethod[EWC]):

FILE: sequoia/methods/avalanche_methods/ewc_test.py
  class TestEWCMethod (line 21) | class TestEWCMethod(_TestAvalancheMethod):
    method method (line 50) | def method(cls, config: Config, request) -> AvalancheMethod:
    method test_short_task_incremental_setting (line 79) | def test_short_task_incremental_setting(
    method test_short_class_incremental_setting (line 112) | def test_short_class_incremental_setting(

FILE: sequoia/methods/avalanche_methods/experience.py
  class SequoiaExperience (line 17) | class SequoiaExperience(IterableWrapper, Experience):
    method __init__ (line 18) | def __init__(
    method dataset (line 121) | def dataset(self) -> AvalancheDataset:
    method dataset (line 125) | def dataset(self, value: AvalancheDataset) -> None:
    method task_label (line 129) | def task_label(self):
    method task_labels (line 145) | def task_labels(self):
    method current_experience (line 149) | def current_experience(self):
    method origin_stream (line 154) | def origin_stream(self) -> SLSetting:

FILE: sequoia/methods/avalanche_methods/gdumb.py
  class GDumbPlugin (line 32) | class GDumbPlugin(_GDumbPlugin):
    method __init__ (line 43) | def __init__(self, mem_size: int = 200):
    method after_train_dataset_adaptation (line 49) | def after_train_dataset_adaptation(self, strategy: BaseStrategy, **kwa...
  class GDumbMethod (line 122) | class GDumbMethod(AvalancheMethod[GDumb]):
    method create_cl_strategy (line 140) | def create_cl_strategy(self, setting: ClassIncrementalSetting) -> GDumb:

FILE: sequoia/methods/avalanche_methods/gdumb_test.py
  class TestGDumbMethod (line 12) | class TestGDumbMethod(_TestAvalancheMethod):

FILE: sequoia/methods/avalanche_methods/gem.py
  class GEMMethod (line 21) | class GEMMethod(AvalancheMethod[GEM]):

FILE: sequoia/methods/avalanche_methods/gem_test.py
  class TestGEMMethod (line 12) | class TestGEMMethod(_TestAvalancheMethod):

FILE: sequoia/methods/avalanche_methods/lwf.py
  class LwFPlugin (line 20) | class LwFPlugin(LwFPlugin_):
    method _distillation_loss (line 25) | def _distillation_loss(self, out: Tensor, prev_out: Tensor) -> Tensor:
  class LwFMethod (line 45) | class LwFMethod(AvalancheMethod[LwF]):
    method create_cl_strategy (line 66) | def create_cl_strategy(self, setting: SLSetting) -> LwF:

FILE: sequoia/methods/avalanche_methods/lwf_test.py
  class TestLwFMethod (line 12) | class TestLwFMethod(_TestAvalancheMethod):

FILE: sequoia/methods/avalanche_methods/naive.py
  class NaiveMethod (line 14) | class NaiveMethod(AvalancheMethod[Naive]):

FILE: sequoia/methods/avalanche_methods/naive_test.py
  class TestNaiveMethod (line 12) | class TestNaiveMethod(_TestAvalancheMethod):

FILE: sequoia/methods/avalanche_methods/patched_models.py
  class PatchedMultiTaskModule (line 21) | class PatchedMultiTaskModule(MultiTaskModule):
    method known_task_ids (line 24) | def known_task_ids(self) -> List[Any]:
    method task_inference_forward_pass (line 27) | def task_inference_forward_pass(self, x: Tensor) -> Tensor:
  class MultiHeadClassifier (line 127) | class MultiHeadClassifier(_MultiHeadClassifier):
    method __init__ (line 128) | def __init__(self, in_features: int, initial_out_features: int = 2):
    method adaptation (line 140) | def adaptation(self, dataset: AvalancheDataset):
    method forward (line 148) | def forward(self, x: Tensor, task_labels: Optional[Tensor]) -> Tensor:
    method forward_single_task (line 157) | def forward_single_task(self, x: Tensor, task_label: Optional[Tensor]):
  class MTSimpleCNN (line 187) | class MTSimpleCNN(_MTSimpleCNN, PatchedMultiTaskModule):
    method __init__ (line 188) | def __init__(self):
    method forward (line 192) | def forward(self, x: Tensor, task_labels: Optional[Tensor] = None) -> ...
    method known_task_ids (line 208) | def known_task_ids(self) -> List[Any]:
  class MTSimpleMLP (line 212) | class MTSimpleMLP(_MTSimpleMLP, PatchedMultiTaskModule):
    method __init__ (line 213) | def __init__(self, input_size: int = 28 * 28, hidden_size: int = 512):
    method forward (line 220) | def forward(self, x: Tensor, task_labels: Optional[Tensor] = None) -> ...
    method known_task_ids (line 230) | def known_task_ids(self) -> List[Any]:

FILE: sequoia/methods/avalanche_methods/plugins.py
  class GatherDataset (line 14) | class GatherDataset(StrategyPlugin):
    method __init__ (line 19) | def __init__(self):
    method after_forward (line 31) | def after_forward(self, strategy, **kwargs):
    method after_training_epoch (line 38) | def after_training_epoch(self, strategy, **kwargs):
    method after_eval_forward (line 47) | def after_eval_forward(self, strategy, **kwargs):
    method after_eval_exp (line 54) | def after_eval_exp(self, strategy, **kwargs):
    method train (line 66) | def train(self):
    method eval (line 69) | def eval(self):
    method after_training_exp (line 72) | def after_training_exp(self, strategy: "BaseStrategy", **kwargs):
  class OnlineAccuracyPlugin (line 88) | class OnlineAccuracyPlugin(StrategyPlugin):
    method __init__ (line 89) | def __init__(self):
    method _calc_accuracy (line 94) | def _calc_accuracy(self, strategy: "BaseStrategy") -> float:
    method after_forward (line 100) | def after_forward(self, strategy: "BaseStrategy", **kwargs):
    method after_training_epoch (line 107) | def after_training_epoch(self, strategy, **kwargs):

FILE: sequoia/methods/avalanche_methods/replay.py
  class ReplayPlugin (line 24) | class ReplayPlugin(ReplayPlugin_):
    method __init__ (line 25) | def __init__(self, mem_size: int = 200, storage_policy: Optional["Stor...
  class ExperienceBalancedStoragePolicy (line 37) | class ExperienceBalancedStoragePolicy(ExperienceBalancedStoragePolicy_):
    method __call__ (line 38) | def __call__(self, strategy: BaseStrategy, **kwargs):
  class ReplayMethod (line 73) | class ReplayMethod(AvalancheMethod[Replay]):
    method create_cl_strategy (line 86) | def create_cl_strategy(self, setting: SLSetting) -> Replay:

FILE: sequoia/methods/avalanche_methods/replay_test.py
  class TestReplayMethod (line 12) | class TestReplayMethod(_TestAvalancheMethod):

FILE: sequoia/methods/avalanche_methods/synaptic_intelligence.py
  class SynapticIntelligencePlugin (line 28) | class SynapticIntelligencePlugin(SynapticIntelligencePlugin_):
    method extract_weights (line 33) | def extract_weights(model: Module, target: ParamDict, excluded_paramet...
    method extract_grad (line 52) | def extract_grad(model, target: ParamDict, excluded_parameters: Set[st...
    method compute_ewc_loss (line 62) | def compute_ewc_loss(
    method post_update (line 103) | def post_update(model, syn_data: SynDataType, excluded_parameters: Set...
    method update_ewc_data (line 143) | def update_ewc_data(
  class SynapticIntelligenceMethod (line 214) | class SynapticIntelligenceMethod(AvalancheMethod[SynapticIntelligence]):
    method create_cl_strategy (line 237) | def create_cl_strategy(self, setting: SLSetting) -> SynapticIntelligence:

FILE: sequoia/methods/avalanche_methods/synaptic_intelligence_test.py
  class TestSynapticIntelligenceMethod (line 12) | class TestSynapticIntelligenceMethod(_TestAvalancheMethod):

FILE: sequoia/methods/base_method.py
  class BaseMethod (line 47) | class BaseMethod(Method, Serializable, Parseable, target_setting=SLSetti...
    method __init__ (line 67) | def __init__(
    method configure (line 162) | def configure(self, setting: SettingType) -> None:
    method fit (line 270) | def fit(
    method get_actions (line 303) | def get_actions(self, observations: Observations, action_space: gym.Sp...
    method create_model (line 324) | def create_model(self, setting: SettingType) -> BaseModel[SettingType]:
    method create_trainer (line 344) | def create_trainer(self, setting: SettingType) -> Trainer:
    method get_experiment_name (line 369) | def get_experiment_name(self, setting: Setting, experiment_id: str = N...
    method get_search_space (line 401) | def get_search_space(self, setting: Setting) -> Mapping[str, Union[str...
    method adapt_to_new_hparams (line 420) | def adapt_to_new_hparams(self, new_hparams: Dict[str, Any]) -> None:
    method hparam_sweep (line 439) | def hparam_sweep(
    method receive_results (line 470) | def receive_results(self, setting: Setting, results: Results):
    method configure_callbacks (line 476) | def configure_callbacks(self, setting: SettingType = None) -> List[Cal...
    method apply_all (line 501) | def apply_all(self, argv: Union[str, List[str]] = None) -> Dict[Type[S...
    method __init_subclass__ (line 525) | def __init_subclass__(cls, target_setting: Type[SettingType] = Setting...
    method on_task_switch (line 546) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method setup_wandb (line 556) | def setup_wandb(self, run: Run) -> None:

FILE: sequoia/methods/base_method_test.py
  class TestBaseMethod (line 21) | class TestBaseMethod(MethodTests):
    method trainer_options (line 27) | def trainer_options(cls, tmp_path_factory) -> TrainerConfig:
    method method (line 38) | def method(cls, config: Config, trainer_options: TrainerConfig) -> Bas...
    method validate_results (line 43) | def validate_results(
    method test_cartpole_state (line 57) | def test_cartpole_state(self, config: Config, trainer_options: Trainer...
    method test_incremental_cartpole_state (line 81) | def test_incremental_cartpole_state(self, config: Config, trainer_opti...
    method test_device_of_output_head_is_correct (line 101) | def test_device_of_output_head_is_correct(
  function test_weird_pl_bug (line 116) | def test_weird_pl_bug():

FILE: sequoia/methods/conftest.py
  function short_class_incremental_setting (line 15) | def short_class_incremental_setting(session_config: Config):
  function short_continual_sl_setting (line 54) | def short_continual_sl_setting(session_config: Config):
  function short_discrete_task_agnostic_sl_setting (line 92) | def short_discrete_task_agnostic_sl_setting(session_config: Config):
  function short_task_incremental_setting (line 130) | def short_task_incremental_setting(session_config: Config):
  function short_sl_track_setting (line 170) | def short_sl_track_setting(session_config: Config):

FILE: sequoia/methods/d3rlpy_methods/base.py
  class OfflineRLWrapper (line 19) | class OfflineRLWrapper(gym.Wrapper):
    method __init__ (line 20) | def __init__(self, env):
    method reset (line 24) | def reset(self):
    method step (line 28) | def step(self, action):
  class BaseOfflineRLMethod (line 33) | class BaseOfflineRLMethod(Method, target_setting=OfflineRLSetting):
    method __init__ (line 36) | def __init__(
    method configure (line 55) | def configure(self, setting: OfflineRLSetting) -> None:
    method fit (line 60) | def fit(
    method get_actions (line 82) | def get_actions(self, obs: Union[np.ndarray, Observations], action_spa...
  class DQNMethod (line 98) | class DQNMethod(BaseOfflineRLMethod):
  class DoubleDQNMethod (line 102) | class DoubleDQNMethod(BaseOfflineRLMethod):
  class DDPGMethod (line 106) | class DDPGMethod(BaseOfflineRLMethod):
  class TD3Method (line 110) | class TD3Method(BaseOfflineRLMethod):
  class SACMethod (line 114) | class SACMethod(BaseOfflineRLMethod):
  class DiscreteSACMethod (line 118) | class DiscreteSACMethod(BaseOfflineRLMethod):
  class CQLMethod (line 122) | class CQLMethod(BaseOfflineRLMethod):
  class DiscreteCQLMethod (line 126) | class DiscreteCQLMethod(BaseOfflineRLMethod):
  class BEARMethod (line 130) | class BEARMethod(BaseOfflineRLMethod):
  class AWRMethod (line 134) | class AWRMethod(BaseOfflineRLMethod):
  class DiscreteAWRMethod (line 138) | class DiscreteAWRMethod(BaseOfflineRLMethod):
  class BCMethod (line 142) | class BCMethod(BaseOfflineRLMethod):
  class DiscreteBCMethod (line 146) | class DiscreteBCMethod(BaseOfflineRLMethod):
  class BCQMethod (line 150) | class BCQMethod(BaseOfflineRLMethod):
  class DiscreteBCQMethod (line 154) | class DiscreteBCQMethod(BaseOfflineRLMethod):

FILE: sequoia/methods/d3rlpy_methods/base_test.py
  class BaseOfflineRLMethodTests (line 9) | class BaseOfflineRLMethodTests:
    method method (line 13) | def method(self):
    method test_offlinerl (line 17) | def test_offlinerl(self, method, dataset: str):
    method test_traditionalrl (line 39) | def test_traditionalrl(self, method, dataset):
  class TestDQNMethod (line 65) | class TestDQNMethod(BaseOfflineRLMethodTests):
  class TestDoubleDQNMethod (line 69) | class TestDoubleDQNMethod(BaseOfflineRLMethodTests):
  class TestDDPGMethod (line 73) | class TestDDPGMethod(BaseOfflineRLMethodTests):
  class TestTD3Method (line 77) | class TestTD3Method(BaseOfflineRLMethodTests):
  class TestSACMethod (line 81) | class TestSACMethod(BaseOfflineRLMethodTests):
  class TestDiscreteSACMethod (line 85) | class TestDiscreteSACMethod(BaseOfflineRLMethodTests):
  class TestCQLMethod (line 89) | class TestCQLMethod(BaseOfflineRLMethodTests):
  class TestDiscreteCQLMethod (line 93) | class TestDiscreteCQLMethod(BaseOfflineRLMethodTests):
  class TestBEARMethod (line 97) | class TestBEARMethod(BaseOfflineRLMethodTests):
  class TestAWRMethod (line 101) | class TestAWRMethod(BaseOfflineRLMethodTests):
  class TestDiscreteAWRMethod (line 105) | class TestDiscreteAWRMethod(BaseOfflineRLMethodTests):
  class TestBCMethod (line 109) | class TestBCMethod(BaseOfflineRLMethodTests):
  class TestDiscreteBCMethod (line 113) | class TestDiscreteBCMethod(BaseOfflineRLMethodTests):
  class TestBCQMethod (line 117) | class TestBCQMethod(BaseOfflineRLMethodTests):
  class TestDiscreteBCQMethod (line 121) | class TestDiscreteBCQMethod(BaseOfflineRLMethodTests):

FILE: sequoia/methods/ewc_method.py
  class EwcModel (line 25) | class EwcModel(BaseModel):
    class HParams (line 29) | class HParams(BaseModel.HParams):
    method __init__ (line 35) | def __init__(self, setting: Setting, hparams: "EwcModel.HParams", conf...
    method get_loss (line 40) | def get_loss(self, forward_pass, rewards=None, loss_name=""):
  class EwcMethod (line 46) | class EwcMethod(BaseMethod, target_setting=IncrementalSLSetting):
    method __init__ (line 55) | def __init__(
    method configure (line 64) | def configure(self, setting: IncrementalAssumption):
    method on_task_switch (line 86) | def on_task_switch(self, task_id: Optional[int]):
    method create_model (line 89) | def create_model(self, setting: Setting) -> EwcModel:
  function demo (line 108) | def demo():

FILE: sequoia/methods/ewc_method_test.py
  class TestEWCMethod (line 26) | class TestEWCMethod(BaseMethodTests):
    method method (line 31) | def method(cls, config: Config, trainer_options: TrainerConfig) -> Ewc...
    method test_task_incremental_mnist (line 38) | def test_task_incremental_mnist(self, monkeypatch):
    method test_raises_warning_when_applied_to_non_cl_setting (line 101) | def test_raises_warning_when_applied_to_non_cl_setting(self, non_cl_se...

FILE: sequoia/methods/experience_replay.py
  class ExperienceReplayMethod (line 34) | class ExperienceReplayMethod(Method, target_setting=ClassIncrementalSett...
    method __init__ (line 37) | def __init__(
    method configure (line 64) | def configure(self, setting: ClassIncrementalSetting):
    method fit (line 90) | def fit(self, train_env: Environment, valid_env: Environment):
    method get_actions (line 184) | def get_actions(self, observations: Observations, action_space: gym.Sp...
    method on_task_switch (line 201) | def on_task_switch(self, task_id: Optional[int]):
    method add_argparse_args (line 207) | def add_argparse_args(cls, parser: ArgumentParser) -> None:
    method from_argparse_args (line 222) | def from_argparse_args(cls, args: Namespace, dest: str = None):
    method get_search_space (line 247) | def get_search_space(self, setting: ClassIncrementalSetting) -> Dict:
    method adapt_to_new_hparams (line 255) | def adapt_to_new_hparams(self, new_hparams: Dict[str, Any]) -> None:
    method setup_wandb (line 280) | def setup_wandb(self, run: Run) -> None:
  class Buffer (line 305) | class Buffer(nn.Module):
    method __init__ (line 306) | def __init__(
    method x (line 338) | def x(self):
    method y (line 342) | def y(self):
    method add_reservoir (line 346) | def add_reservoir(self, batch: Dict[str, Tensor]) -> None:
    method sample (line 399) | def sample(self, n_samples: int, exclude_task: int = None) -> Dict[str...

FILE: sequoia/methods/experience_replay_test.py
  class TestExperienceReplay (line 13) | class TestExperienceReplay(MethodTests):
    method method (line 19) | def method(cls, config: Config) -> ExperienceReplayMethod:
    method validate_results (line 23) | def validate_results(
    method test_class_incremental_mnist (line 34) | def test_class_incremental_mnist(self, config: Config):

FILE: sequoia/methods/hat.py
  class Masks (line 37) | class Masks(NamedTuple):
  class HatNet (line 47) | class HatNet(torch.nn.Module):
    method __init__ (line 62) | def __init__(self, image_space: Image, n_classes_per_task: Dict[int, i...
    method forward (line 110) | def forward(self, observations: TaskIncrementalSLSetting.Observations)...
    method mask (line 147) | def mask(self, t: Tensor, s_hat: float) -> Masks:
    method shared_step (line 155) | def shared_step(
  function compute_conv_output_size (line 207) | def compute_conv_output_size(
  class HatMethod (line 214) | class HatMethod(Method, target_setting=TaskIncrementalSLSetting):
    class HParams (line 229) | class HParams(HyperParameters):
    method __init__ (line 241) | def __init__(self, hparams: HParams = None):
    method configure (line 248) | def configure(self, setting: TaskIncrementalSLSetting):
    method fit (line 272) | def fit(self, train_env: PassiveEnvironment, valid_env: PassiveEnviron...
    method get_actions (line 330) | def get_actions(self, observations: Observations, action_space: gym.Sp...
    method on_task_switch (line 338) | def on_task_switch(self, task_id: Optional[int]):
    method add_argparse_args (line 346) | def add_argparse_args(cls, parser: ArgumentParser) -> None:
    method from_argparse_args (line 352) | def from_argparse_args(cls, args: Namespace) -> "HatMethod":
    method get_search_space (line 358) | def get_search_space(self, setting: Setting) -> Mapping[str, Union[str...
    method adapt_to_new_hparams (line 374) | def adapt_to_new_hparams(self, new_hparams: Dict[str, Any]) -> None:
    method setup_wandb (line 392) | def setup_wandb(self, run: Run) -> None:

FILE: sequoia/methods/method_test.py
  function key_fn (line 14) | def key_fn(setting_class: Type[Setting]):
  function make_setting_type_fixture (line 20) | def make_setting_type_fixture(method_type: Type[Method]) -> pytest.fixture:
  class MethodTests (line 43) | class MethodTests(ABC):
    method __init_subclass__ (line 55) | def __init_subclass__(cls, method: Type[MethodType] = None):
    method method (line 71) | def method(cls, config: Config) -> MethodType:
    method validate_results (line 80) | def validate_results(
    method setting (line 98) | def setting(self, setting_type: Type[Setting], session_config: Config):
    method test_debug (line 165) | def test_debug(self, method: MethodType, setting: Setting, config: Con...
  class NewSetting (line 172) | class NewSetting(Setting):
  class NewMethod (line 177) | class NewMethod(Method, target_setting=NewSetting):
    method fit (line 178) | def fit(self, train_env, valid_env):
    method get_actions (line 181) | def get_actions(self, observations, action_space):
  function test_passing_arg_to_class_constructor_works (line 185) | def test_passing_arg_to_class_constructor_works():
  function test_cant_change_target_setting (line 191) | def test_cant_change_target_setting():
  function test_target_setting_is_inherited (line 198) | def test_target_setting_is_inherited():
  class SettingA (line 207) | class SettingA(Setting):
  class SettingA1 (line 212) | class SettingA1(SettingA):
  class SettingA2 (line 217) | class SettingA2(SettingA):
  class SettingB (line 222) | class SettingB(Setting):
  class MethodA (line 226) | class MethodA(Method, target_setting=SettingA):
    method fit (line 227) | def fit(self, train_env, valid_env):
    method get_actions (line 230) | def get_actions(self, observations, action_space):
  class MethodB (line 234) | class MethodB(Method, target_setting=SettingB):
    method fit (line 235) | def fit(self, train_env, valid_env):
    method get_actions (line 238) | def get_actions(self, observations, action_space):
  class CoolGeneralMethod (line 242) | class CoolGeneralMethod(Method, target_setting=Setting):
    method fit (line 243) | def fit(self, train_env, valid_env):
    method get_actions (line 246) | def get_actions(self, observations, action_space):
  function test_method_is_applicable_to_setting (line 250) | def test_method_is_applicable_to_setting():
  function test_is_applicable_also_works_on_instances (line 290) | def test_is_applicable_also_works_on_instances():

FILE: sequoia/methods/models/base_model/base_model.py
  class BaseModel (line 36) | class BaseModel(SemiSupervisedModel, MultiHeadModel, SelfSupervisedModel...
    class HParams (line 48) | class HParams(SemiSupervisedModel.HParams, SelfSupervisedModel.HParams...
    method __init__ (line 113) | def __init__(self, setting: SettingType, hparams: HParams, config: Con...
    method on_fit_start (line 145) | def on_fit_start(self):
    method forward (line 150) | def forward(self, observations: Setting.Observations) -> ForwardPass: ...
    method create_output_head (line 187) | def create_output_head(self, task_id: Optional[int]) -> OutputHead:
    method output_head_type (line 209) | def output_head_type(self, setting: SettingType) -> Type[OutputHead]:
    method automatic_optimization (line 215) | def automatic_optimization(self) -> bool:
    method training_step (line 218) | def training_step(
    method validation_step (line 234) | def validation_step(
    method test_step (line 248) | def test_step(
    method shared_step (line 262) | def shared_step(
    method on_task_switch (line 280) | def on_task_switch(self, task_id: Optional[int]) -> None:

FILE: sequoia/methods/models/base_model/model.py
  class Model (line 71) | class Model(LightningModule, Generic[SettingType]):
    class HParams (line 86) | class HParams(HyperParameters):
    method __init__ (line 148) | def __init__(self, setting: SettingType, hparams: HParams, config: Con...
    method make_encoder (line 218) | def make_encoder(self) -> Tuple[nn.Module, int]:
    method forward (line 241) | def forward(self, observations: IncrementalAssumption.Observations) ->...
    method encode (line 271) | def encode(self, observations: Observations) -> Tensor:
    method create_output_head (line 307) | def create_output_head(self, task_id: Optional[int]) -> OutputHead:
    method output_head_type (line 360) | def output_head_type(self, setting: SettingType) -> Type[OutputHead]:
    method training_step (line 386) | def training_step(
    method validation_step (line 403) | def validation_step(
    method test_step (line 418) | def test_step(
    method shared_step (line 433) | def shared_step(
    method training_step_end (line 469) | def training_step_end(self, step_outputs: Union[Loss, List[Loss]]) -> ...
    method validation_step_end (line 499) | def validation_step_end(self, step_outputs: Union[ForwardPass, List[Fo...
    method test_step_end (line 504) | def test_step_end(self, step_outputs: Union[ForwardPass, List[ForwardP...
    method shared_step_end (line 509) | def shared_step_end(
    method split_batch (line 563) | def split_batch(self, batch: Any) -> Tuple[Observations, Optional[Rewa...
    method get_loss (line 592) | def get_loss(
    method output_head_loss (line 640) | def output_head_loss(
    method preprocess_observations (line 652) | def preprocess_observations(self, observations: Observations) -> Obser...
    method preprocess_rewards (line 661) | def preprocess_rewards(self, reward: Rewards) -> Rewards:
    method configure_optimizers (line 664) | def configure_optimizers(self):
    method batch_size (line 677) | def batch_size(self) -> int:
    method batch_size (line 681) | def batch_size(self, value: int) -> None:
    method learning_rate (line 685) | def learning_rate(self) -> float:
    method learning_rate (line 689) | def learning_rate(self, value: float) -> None:
    method on_task_switch (line 692) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method shared_modules (line 699) | def shared_modules(self) -> Dict[str, nn.Module]:
    method _are_batched (line 723) | def _are_batched(self, observations: IncrementalAssumption.Observation...

FILE: sequoia/methods/models/base_model/multihead_model.py
  class MultiHeadModel (line 22) | class MultiHeadModel(Model[SettingType]):
    class HParams (line 28) | class HParams(Model.HParams):
    method __init__ (line 34) | def __init__(self, setting: SettingType, hparams: HParams, config: Con...
    method output_head_loss (line 58) | def output_head_loss(
    method on_before_zero_grad (line 214) | def on_before_zero_grad(self, optimizer):
    method shared_step (line 222) | def shared_step(
    method on_task_switch (line 250) | def on_task_switch(self, task_id: Optional[int]):
    method shared_modules (line 274) | def shared_modules(self) -> Dict[str, nn.Module]:
    method load_state_dict (line 296) | def load_state_dict(
    method get_or_create_output_head (line 331) | def get_or_create_output_head(self, task_id: int) -> nn.Module:
    method forward (line 349) | def forward(self, observations: IncrementalAssumption.Observations) ->...
    method setup_for_task (line 410) | def setup_for_task(self, task_id: int) -> None:
    method split_forward_pass (line 415) | def split_forward_pass(self, observations: Observations) -> ForwardPass:
    method task_inference_forward_pass (line 474) | def task_inference_forward_pass(self, observations: Observations) -> T...
  function get_task_indices (line 563) | def get_task_indices(
  function cleanup_task_labels (line 614) | def cleanup_task_labels(

FILE: sequoia/methods/models/base_model/multihead_model_test.py
  function mixed_samples (line 29) | def mixed_samples(config: Config):
  class MockOutputHead (line 42) | class MockOutputHead(OutputHead):
    method __init__ (line 43) | def __init__(self, *args, Actions: Type, task_id: int = -1, **kwargs):
    method forward (line 49) | def forward(self, observations, representations) -> Tensor:  # type: i...
    method get_loss (line 68) | def get_loss(self, forward_pass, actions, rewards):
  function test_multiple_tasks_within_same_batch (line 88) | def test_multiple_tasks_within_same_batch(
  function test_multitask_rl_bug_without_PL (line 135) | def test_multitask_rl_bug_without_PL(monkeypatch):
  function test_multitask_rl_bug_with_PL (line 236) | def test_multitask_rl_bug_with_PL(monkeypatch, config: Config):
  function test_get_task_indices (line 364) | def test_get_task_indices(input, expected):
  function test_task_inference_sl (line 378) | def test_task_inference_sl(
  function test_task_inference_rl_easy (line 428) | def test_task_inference_rl_easy(config: Config):
  function test_task_inference_rl_hard (line 449) | def test_task_inference_rl_hard(config: Config):
  function test_task_inference_multi_task_sl (line 471) | def test_task_inference_multi_task_sl(config: Config):

FILE: sequoia/methods/models/base_model/self_supervised_model.py
  class SelfSupervisedModel (line 29) | class SelfSupervisedModel(Model[SettingType]):
    class HParams (line 38) | class HParams(Model.HParams):
    method __init__ (line 44) | def __init__(self, setting: Setting, hparams: HParams, config: Config):
    method get_loss (line 50) | def get_loss(
    method add_auxiliary_task (line 75) | def add_auxiliary_task(
    method create_auxiliary_tasks (line 93) | def create_auxiliary_tasks(self) -> Dict[str, AuxiliaryTask]:
    method on_task_switch (line 119) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method shared_modules (line 130) | def shared_modules(self) -> Dict[str, nn.Module]:

FILE: sequoia/methods/models/base_model/self_supervised_model_test.py
  function test_get_applicable_settings (line 25) | def test_get_applicable_settings():
  function method_and_coefficients (line 42) | def method_and_coefficients(request, tmp_path_factory):
  function test_fast_dev_run (line 73) | def test_fast_dev_run(
  function validate_results (line 90) | def validate_results(results: Results, aux_task_coefficients: Dict[str, ...

FILE: sequoia/methods/models/base_model/semi_supervised_model.py
  class SemiSupervisedModel (line 22) | class SemiSupervisedModel(Model[SettingType]):
    class HParams (line 24) | class HParams(Model.HParams):
    method get_loss (line 33) | def get_loss(

FILE: sequoia/methods/models/fcnet.py
  class FCNet (line 10) | class FCNet(nn.Sequential):
    class HParams (line 14) | class HParams(HyperParameters):
      method __post_init__ (line 38) | def __post_init__(self):
    method __init__ (line 67) | def __init__(self, in_features: int, out_features: int, hparams: HPara...
    method __init__ (line 71) | def __init__(
    method __init__ (line 81) | def __init__(self, in_features: int, out_features: int, hparams: HPara...

FILE: sequoia/methods/models/forward_pass.py
  class ForwardPass (line 14) | class ForwardPass(Batch, FlattenedAccess):
    method h_x (line 32) | def h_x(self) -> Any:

FILE: sequoia/methods/models/output_heads/classification_head.py
  class ClassificationOutput (line 26) | class ClassificationOutput(Actions):
    method action (line 36) | def action(self) -> LongTensor:
    method y_pred_log_prob (line 40) | def y_pred_log_prob(self) -> Tensor:
    method y_pred_prob (line 45) | def y_pred_prob(self) -> Tensor:
    method probabilities (line 50) | def probabilities(self) -> Tensor:
  class ClassificationHead (line 57) | class ClassificationHead(OutputHead):
    class HParams (line 59) | class HParams(FCNet.HParams, OutputHead.HParams):
    method __init__ (line 86) | def __init__(
    method forward (line 115) | def forward(self, observations: Observations, representations: Tensor)...
    method get_loss (line 129) | def get_loss(

FILE: sequoia/methods/models/output_heads/output_head.py
  class OutputHead (line 25) | class OutputHead(nn.Module, ABC):
    class HParams (line 41) | class HParams(HyperParameters, Parseable):
    method __init__ (line 44) | def __init__(
    method make_dense_network (line 64) | def make_dense_network(
    method forward (line 82) | def forward(
    method get_loss (line 102) | def get_loss(self, forward_pass: ForwardPass, actions: Actions, reward...
    method clear_all_buffers (line 108) | def clear_all_buffers(self) -> None:
    method upgrade_hparams (line 114) | def upgrade_hparams(self):

FILE: sequoia/methods/models/output_heads/regression_head.py
  class RegressionHead (line 17) | class RegressionHead(OutputHead):
    class HParams (line 21) | class HParams(FCNet.HParams, OutputHead.HParams):
    method __init__ (line 24) | def __init__(
    method forward (line 63) | def forward(self, observations: Observations, representations: Tensor)...
    method get_loss (line 67) | def get_loss(self, forward_pass: ForwardPass, actions: Actions, reward...

FILE: sequoia/methods/models/output_heads/rl/actor_critic_head.py
  class ActorCriticHead (line 29) | class ActorCriticHead(ClassificationHead):
    class HParams (line 31) | class HParams(ClassificationHead.HParams):
    method __init__ (line 37) | def __init__(
    method forward (line 86) | def forward(
    method get_loss (line 119) | def get_loss(
  function concat_obs_and_action (line 173) | def concat_obs_and_action(observation_action: Tuple[Tensor, Tensor]) -> ...

FILE: sequoia/methods/models/output_heads/rl/episodic_a2c.py
  class A2CHeadOutput (line 27) | class A2CHeadOutput(PolicyHeadOutput):
  class EpisodicA2C (line 34) | class EpisodicA2C(PolicyHead):
    class HParams (line 45) | class HParams(PolicyHead.HParams):
    method __init__ (line 61) | def __init__(
    method actor (line 93) | def actor(self) -> nn.Module:
    method forward (line 96) | def forward(
    method num_stored_steps (line 113) | def num_stored_steps(self, env_index: int) -> Optional[int]:
    method get_episode_loss (line 123) | def get_episode_loss(self, env_index: int, done: bool) -> Optional[Loss]:
    method optimizer_step (line 200) | def optimizer_step(self):
  function compute_returns_and_advantage (line 211) | def compute_returns_and_advantage(self, last_values: Tensor, dones: np.n...

FILE: sequoia/methods/models/output_heads/rl/episodic_a2c_test.py
  class FakeEnvironment (line 24) | class FakeEnvironment(SyncVectorEnv):
    method __init__ (line 25) | def __init__(
    method step (line 44) | def step(self, actions):
  function test_with_controllable_episode_lengths (line 67) | def test_with_controllable_episode_lengths(batch_size: int, monkeypatch):
  function test_loss_is_nonzero_at_episode_end (line 200) | def test_loss_is_nonzero_at_episode_end(batch_size: int):
  function test_loss_is_nonzero_at_episode_end_iterate (line 282) | def test_loss_is_nonzero_at_episode_end_iterate(batch_size: int):
  function test_buffers_are_stacked_correctly (line 356) | def test_buffers_are_stacked_correctly(monkeypatch):

FILE: sequoia/methods/models/output_heads/rl/policy_head.py
  class PolicyHeadOutput (line 56) | class PolicyHeadOutput(ClassificationOutput):
    method y_pred_prob (line 65) | def y_pred_prob(self) -> Tensor:
    method y_pred_log_prob (line 70) | def y_pred_log_prob(self) -> Tensor:
    method action_log_prob (line 75) | def action_log_prob(self) -> Tensor:
    method action_prob (line 79) | def action_prob(self) -> Tensor:
  class PolicyHead (line 94) | class PolicyHead(ClassificationHead):
    class HParams (line 108) | class HParams(ClassificationHead.HParams):
    method __init__ (line 136) | def __init__(
    method create_buffers (line 191) | def create_buffers(self):
    method forward (line 203) | def forward(
    method to (line 238) | def to(self: T, device: Optional[Union[int, torch.device]] = None, **k...
    method get_loss (line 244) | def get_loss(
    method on_episode_end (line 373) | def on_episode_end(self, env_index: int) -> None:
    method get_episode_loss (line 378) | def get_episode_loss(self, env_index: int, done: bool) -> Optional[Loss]:
    method get_gradient_usage_metrics (line 434) | def get_gradient_usage_metrics(self, env_index: int) -> GradientUsageM...
    method get_returns (line 452) | def get_returns(rewards: Union[Tensor, List[Tensor]], gamma: float) ->...
    method policy_gradient (line 459) | def policy_gradient(
    method training (line 486) | def training(self) -> bool:
    method training (line 490) | def training(self, value: bool) -> None:
    method clear_all_buffers (line 503) | def clear_all_buffers(self) -> None:
    method clear_buffers (line 516) | def clear_buffers(self, env_index: int) -> None:
    method detach_all_buffers (line 522) | def detach_all_buffers(self):
    method detach_buffers (line 530) | def detach_buffers(self, env_index: int) -> None:
    method _detach_buffer (line 544) | def _detach_buffer(self, old_buffer: Sequence[Tensor]) -> deque:
    method _make_buffer (line 551) | def _make_buffer(self, elements: Sequence[T] = None) -> Deque[T]:
    method _make_buffers (line 557) | def _make_buffers(self) -> List[deque]:
    method stack_buffers (line 560) | def stack_buffers(self, env_index: int):
  function discounted_sum_of_future_rewards (line 586) | def discounted_sum_of_future_rewards(rewards: Union[Tensor, List[Tensor]...
  function vanilla_policy_gradient (line 605) | def vanilla_policy_gradient(
  function make_gamma_matrix (line 643) | def make_gamma_matrix(gamma: float, T: int, device=None) -> Tensor:
  function normalize (line 659) | def normalize(x: Tensor):
  function tuple_of_lists (line 666) | def tuple_of_lists(list_of_tuples: List[Tuple[T, ...]]) -> Tuple[List[T]...
  function list_of_tuples (line 670) | def list_of_tuples(tuple_of_lists: Tuple[List[T], ...]) -> List[Tuple[T,...

FILE: sequoia/methods/models/output_heads/rl/policy_head_test.py
  class FakeEnvironment (line 29) | class FakeEnvironment(SyncVectorEnv):
    method __init__ (line 30) | def __init__(
    method step (line 49) | def step(self, actions):
  function test_with_controllable_episode_lengths (line 71) | def test_with_controllable_episode_lengths(batch_size: int, monkeypatch):
  function test_loss_is_nonzero_at_episode_end (line 209) | def test_loss_is_nonzero_at_episode_end(batch_size: int):
  function test_done_is_sometimes_True_when_iterating_through_env (line 292) | def test_done_is_sometimes_True_when_iterating_through_env(batch_size: i...
  function test_loss_is_nonzero_at_episode_end_iterate (line 308) | def test_loss_is_nonzero_at_episode_end_iterate(batch_size: int):
  function test_buffers_are_stacked_correctly (line 381) | def test_buffers_are_stacked_correctly(monkeypatch):
  function test_sanity_check_cartpole_done_vector (line 519) | def test_sanity_check_cartpole_done_vector():

FILE: sequoia/methods/models/output_heads/rl/wasted_steps_calc.py
  function get_fraction_of_observations_with_grad (line 7) | def get_fraction_of_observations_with_grad(

FILE: sequoia/methods/models/simple_convnet.py
  class SimpleConvNet (line 4) | class SimpleConvNet(nn.Module):
    method __init__ (line 5) | def __init__(self, in_channels: int = 3, n_classes: int = 10):
    method forward (line 35) | def forward(self, x: Tensor) -> Tensor:

FILE: sequoia/methods/packnet_method.py
  class PackNet (line 19) | class PackNet(Callback, nn.Module):
    class HParams (line 26) | class HParams(HyperParameters):
    method __init__ (line 34) | def __init__(
    method filtered_parameter_iterator (line 81) | def filtered_parameter_iterator(self, module: nn.Module) -> Iterable[T...
    method prune (line 112) | def prune(self, model: nn.Module, prune_quantile: float) -> Dict[str, ...
    method fine_tune_mask (line 169) | def fine_tune_mask(self, model: nn.Module):
    method training_mask (line 178) | def training_mask(self, model: nn.Module):
    method fix_biases (line 201) | def fix_biases(self, model: nn.Module):
    method fix_batch_norm (line 214) | def fix_batch_norm(self, model: nn.Module):
    method set_params_dict (line 224) | def set_params_dict(self, model: nn.Module):
    method fix_all_layers (line 234) | def fix_all_layers(self, model: nn.Module):
    method apply_eval_mask (line 248) | def apply_eval_mask(self, model: nn.Module, task_idx: int):
    method mask_remaining_params (line 266) | def mask_remaining_params(self, model: nn.Module) -> Dict[str, Tensor]:
    method total_epochs (line 282) | def total_epochs(self) -> int:
    method config_instructions (line 285) | def config_instructions(self):
    method save_final_state (line 299) | def save_final_state(self, model, PATH="model_weights.pth"):
    method load_final_state (line 308) | def load_final_state(self, model):
    method on_init_end (line 316) | def on_init_end(self, trainer: Trainer):
    method on_after_backward (line 319) | def on_after_backward(self, trainer: Trainer, pl_module: LightningModu...
    method on_train_epoch_end (line 326) | def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningMod...
    method on_fit_end (line 340) | def on_fit_end(self, trainer: Trainer, pl_module: LightningModule):
  class PackNetMethod (line 353) | class PackNetMethod(BaseMethod, target_setting=IncrementalSLSetting):
    method __init__ (line 364) | def __init__(
    method configure (line 376) | def configure(self, setting: Setting):
    method fit (line 399) | def fit(self, train_env, valid_env):
    method on_task_switch (line 402) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method configure_callbacks (line 416) | def configure_callbacks(self, setting: TaskIncrementalSLSetting = None...
    method create_trainer (line 442) | def create_trainer(self, setting) -> Trainer:
    method adapt_to_new_hparams (line 453) | def adapt_to_new_hparams(self, new_hparams: Dict[str, Any]) -> None:
    method get_search_space (line 468) | def get_search_space(self, setting: Setting) -> Mapping[str, Union[str...

FILE: sequoia/methods/packnet_method_test.py
  class TestPackNetMethod (line 7) | class TestPackNetMethod(BaseMethodTests):
    method validate_results (line 10) | def validate_results(self, setting, method, results):

FILE: sequoia/methods/pl_dqn.py
  class DQN (line 72) | class DQN(nn.Module):
    method __init__ (line 75) | def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 1...
    method forward (line 89) | def forward(self, x: Tensor) -> Tensor:
  class Experience (line 98) | class Experience(Generic[T]):
  class ExperienceBatch (line 109) | class ExperienceBatch(Generic[T]):
    method __len__ (line 121) | def __len__(self) -> int:
    method __getitem__ (line 124) | def __getitem__(self, index: Union[int, slice]) -> Union[Experience[T]...
    method stack (line 142) | def stack(cls, items: Sequence["Experience[T]"]) -> "ExperienceBatch[T]":
    method _map (line 159) | def _map(self, fn: Callable[[T], V]) -> "ExperienceBatch[V]":
    method numpy (line 164) | def numpy(self) -> "ExperienceBatch[np.ndarray]":
    method to (line 170) | def to(self, device: torch.device = None, **kwargs) -> "ExperienceBatc...
  class ReplayBuffer (line 177) | class ReplayBuffer(Generic[T]):
    method __init__ (line 183) | def __init__(self, capacity: int) -> None:
    method __len__ (line 190) | def __len__(self) -> int:
    method append (line 193) | def append(self, experience: Experience[T]) -> None:
    method sample (line 201) | def sample(
  class RLDataset (line 210) | class RLDataset(IterableDataset[ExperienceBatch[T]]):
    method __init__ (line 217) | def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
    method __iter__ (line 226) | def __iter__(self) -> Iterator[Experience[T]]:
  class Agent (line 233) | class Agent:
    method __init__ (line 243) | def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
    method reset (line 254) | def reset(self) -> None:
    method get_action (line 258) | def get_action(self, state: Tensor, net: nn.Module, epsilon: float) ->...
    method play_step (line 280) | def play_step(
  class DQNLightning (line 319) | class DQNLightning(pl.LightningModule):
    class HParams (line 328) | class HParams(Serializable):
    method __init__ (line 359) | def __init__(self, env: Union[str, gym.Env[np.ndarray, int]], hp: HPar...
    method populate (line 394) | def populate(self, steps: int = 1000) -> None:
    method forward (line 408) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method dqn_mse_loss (line 420) | def dqn_mse_loss(self, batch: ExperienceBatch[Tensor]) -> torch.Tensor:
    method training_step (line 446) | def training_step(self, batch: ExperienceBatch[Tensor], batch_idx: int...
    method configure_optimizers (line 494) | def configure_optimizers(self) -> List[Optimizer]:
    method __dataloader (line 499) | def __dataloader(self) -> DataLoader:
    method train_dataloader (line 510) | def train_dataloader(self) -> DataLoader:
    method get_device (line 514) | def get_device(self, batch) -> str:
    method add_model_specific_args (line 519) | def add_model_specific_args(cls, parent_parser: ArgumentParser):  # pr...
  function get_max_episode_length (line 524) | def get_max_episode_length(env: Union[gym.Env, gym.Wrapper]) -> Optional...
  class PlDqnMethod (line 543) | class PlDqnMethod(Method, target_setting=RLSetting):
    method __init__ (line 544) | def __init__(self, hp: DQNLightning.HParams = None) -> None:
    method configure (line 549) | def configure(self, setting: RLSetting) -> None:
    method fit (line 553) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 580) | def get_actions(self, observations: Observations, action_space: Discre...
  function main (line 593) | def main() -> None:

FILE: sequoia/methods/pnn/layers.py
  class PNNConvLayer (line 9) | class PNNConvLayer(nn.Module):
    method __init__ (line 10) | def __init__(self, col, depth, n_in, n_out, kernel_size=3):
    method forward (line 21) | def forward(self, inputs):
  class PNNLinearBlock (line 31) | class PNNLinearBlock(nn.Module):
    method __init__ (line 32) | def __init__(self, col: int, depth: int, n_in: int, n_out: int):
    method forward (line 40) | def forward(self, inputs):

FILE: sequoia/methods/pnn/model_rl.py
  class PnnA2CAgent (line 11) | class PnnA2CAgent(nn.Module):
    method __init__ (line 21) | def __init__(self, arch="mlp", hidden_size=256):
    method forward (line 40) | def forward(self, observations):
    method new_task (line 93) | def new_task(self, device, num_inputs, num_actions=5):
    method unfreeze_columns (line 133) | def unfreeze_columns(self):
    method freeze_columns (line 145) | def freeze_columns(self, skip: List[int] = None):
    method parameters (line 166) | def parameters(self, task_id):
    method transfor_img (line 179) | def transfor_img(self, img):

FILE: sequoia/methods/pnn/model_sl.py
  class PnnClassifier (line 16) | class PnnClassifier(nn.Module):
    method __init__ (line 26) | def __init__(self, n_layers):
    method forward (line 36) | def forward(self, observations: Observations):
    method new_task (line 90) | def new_task(self, device, sizes: List[int]):
    method freeze_columns (line 111) | def freeze_columns(self, skip: List[int] = None):
    method shared_step (line 126) | def shared_step(
    method parameters (line 177) | def parameters(self, task_id):

FILE: sequoia/methods/pnn/pnn_method.py
  class PnnMethod (line 48) | class PnnMethod(Method, target_setting=IncrementalAssumption):
    class HParams (line 57) | class HParams(HyperParameters):
    method __init__ (line 73) | def __init__(self, hparams: HParams = None):
    method configure (line 81) | def configure(self, setting: Setting):
    method on_task_switch (line 160) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method set_optimizer (line 184) | def set_optimizer(self):
    method get_actions (line 190) | def get_actions(self, observations: Observations, action_space: spaces...
    method fit (line 209) | def fit(self, train_env: Environment, valid_env: Environment):
    method fit_rl (line 221) | def fit_rl(self, train_env: gym.Env, valid_env: gym.Env):
    method fit_sl (line 297) | def fit_sl(self, train_env: PassiveEnvironment, valid_env: PassiveEnvi...
    method add_argparse_args (line 343) | def add_argparse_args(cls, parser: ArgumentParser) -> None:
    method from_argparse_args (line 347) | def from_argparse_args(cls, args: Namespace) -> "PnnMethod":
    method get_search_space (line 352) | def get_search_space(self, setting: Setting) -> Mapping[str, Union[str...
    method adapt_to_new_hparams (line 368) | def adapt_to_new_hparams(self, new_hparams: Dict[str, Any]) -> None:
    method setup_wandb (line 386) | def setup_wandb(self, run: Run) -> None:
  function main_rl (line 403) | def main_rl():
  function main_sl (line 436) | def main_sl():

FILE: sequoia/methods/random_baseline.py
  class RandomBaselineMethod (line 25) | class RandomBaselineMethod(Method, target_setting=Setting):
    method __init__ (line 32) | def __init__(self):
    method configure (line 35) | def configure(self, setting: Setting):
    method fit (line 45) | def fit(
    method get_actions (line 83) | def get_actions(self, observations: Observations, action_space: gym.Sp...
    method get_search_space (line 86) | def get_search_space(self, setting: Setting) -> Mapping[str, Union[str...
    method adapt_to_new_hparams (line 109) | def adapt_to_new_hparams(self, new_hparams: Dict[str, Any]) -> None:
    method add_argparse_args (line 125) | def add_argparse_args(cls, parser: ArgumentParser):
    method from_argparse_args (line 129) | def from_argparse_args(cls, args: Namespace):

FILE: sequoia/methods/random_baseline_test.py
  function test_is_applicable_to_all_settings (line 23) | def test_is_applicable_to_all_settings():

FILE: sequoia/methods/stable_baselines3_methods/a2c.py
  class A2CModel (line 24) | class A2CModel(A2C, OnPolicyModel):
    class HParams (line 37) | class HParams(OnPolicyModel.HParams):
  class A2CMethod (line 126) | class A2CMethod(OnPolicyMethod):
    method configure (line 137) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 153) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> A2CM...
    method fit (line 156) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 159) | def get_actions(
    method on_task_switch (line 167) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method get_search_space (line 178) | def get_search_space(self, setting: ContinualRLSetting) -> Mapping[str...

FILE: sequoia/methods/stable_baselines3_methods/a2c_test.py
  class TestA2C (line 8) | class TestA2C(DiscreteActionSpaceMethodTests):

FILE: sequoia/methods/stable_baselines3_methods/base.py
  class RemoveInfoWrapper (line 78) | class RemoveInfoWrapper(gym.Wrapper):
    method step (line 83) | def step(self, action):
  class SB3BaseHParams (line 90) | class SB3BaseHParams(HyperParameters):
  class StableBaselines3Method (line 140) | class StableBaselines3Method(Method, ABC, target_setting=ContinualRLSett...
    method __post_init__ (line 179) | def __post_init__(self):
    method configure (line 205) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 253) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> Base...
    method fit (line 259) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 321) | def get_actions(
    method get_search_space (line 330) | def get_search_space(self, setting: Setting) -> Mapping[str, Union[str...
    method adapt_to_new_hparams (line 348) | def adapt_to_new_hparams(self, new_hparams: Dict[str, Any]) -> None:
    method setup_wandb (line 366) | def setup_wandb(self, run: Run) -> None:
    method on_task_switch (line 382) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method clear_buffers (line 394) | def clear_buffers(self):

FILE: sequoia/methods/stable_baselines3_methods/base_test.py
  class StableBaselines3MethodTests (line 31) | class StableBaselines3MethodTests(MethodTests):
    method test_clear_buffers_between_tasks (line 38) | def test_clear_buffers_between_tasks(self, clear_buffers: bool, config...
    method test_hparams_have_same_defaults_as_in_sb3 (line 68) | def test_hparams_have_same_defaults_as_in_sb3(
    method method (line 108) | def method(cls, config: Config) -> StableBaselines3Method:
    method validate_results (line 112) | def validate_results(
    method test_debug (line 123) | def test_debug(self, method: StableBaselines3Method, setting: RLSettin...
  class DiscreteActionSpaceMethodTests (line 130) | class DiscreteActionSpaceMethodTests(StableBaselines3MethodTests):
    method test_monsterkong (line 137) | def test_monsterkong(self):
  function get_current_length_of_replay_buffer (line 155) | def get_current_length_of_replay_buffer(algo: BaseAlgorithm) -> int:
  function _ (line 161) | def _(algo: OffPolicyAlgorithm):
  function _ (line 166) | def _(algo: OnPolicyAlgorithm):
  class ContinuousActionSpaceMethodTests (line 171) | class ContinuousActionSpaceMethodTests(StableBaselines3MethodTests):

FILE: sequoia/methods/stable_baselines3_methods/ddpg.py
  class DDPGModel (line 23) | class DDPGModel(DDPG, OffPolicyModel):
    class HParams (line 27) | class HParams(OffPolicyModel.HParams):
  class DDPGMethod (line 52) | class DDPGMethod(OffPolicyMethod):
    method configure (line 63) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 66) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> DDPG...
    method fit (line 69) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 72) | def get_actions(
    method on_task_switch (line 80) | def on_task_switch(self, task_id: Optional[int]) -> None:

FILE: sequoia/methods/stable_baselines3_methods/ddpg_test.py
  class TestDDPG (line 11) | class TestDDPG(ContinuousActionSpaceMethodTests):

FILE: sequoia/methods/stable_baselines3_methods/dqn.py
  class DQNModel (line 24) | class DQNModel(DQN, OffPolicyModel):
    class HParams (line 28) | class HParams(OffPolicyModel.HParams):
    method train (line 75) | def train(self, gradient_steps: int, batch_size: int = 100) -> None:
  class DQNMethod (line 81) | class DQNMethod(OffPolicyMethod):
    method configure (line 92) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 107) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> DQNM...
    method fit (line 110) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 113) | def get_actions(
    method on_task_switch (line 125) | def on_task_switch(self, task_id: Optional[int]) -> None:

FILE: sequoia/methods/stable_baselines3_methods/dqn_test.py
  class TestDQN (line 17) | class TestDQN(DiscreteActionSpaceMethodTests, OffPolicyMethodTests):
    method test_classic_control_state (line 25) | def test_classic_control_state(self, config: Config):
    method test_incremental_classic_control_state (line 29) | def test_incremental_classic_control_state(self, config: Config):
    method test_dqn_monsterkong_adds_channel_first_transform (line 32) | def test_dqn_monsterkong_adds_channel_first_transform(self):

FILE: sequoia/methods/stable_baselines3_methods/off_policy_method.py
  function decode_trainfreq (line 25) | def decode_trainfreq(v: Any):
  class OffPolicyModel (line 34) | class OffPolicyModel(OffPolicyAlgorithm, ABC):
    class HParams (line 38) | class HParams(SB3BaseHParams):
  class OffPolicyMethod (line 89) | class OffPolicyMethod(StableBaselines3Method, ABC):
    method __post_init__ (line 99) | def __post_init__(self):
    method configure (line 103) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 218) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> OffP...
    method fit (line 221) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 224) | def get_actions(
    method on_task_switch (line 232) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method clear_buffers (line 243) | def clear_buffers(self):

FILE: sequoia/methods/stable_baselines3_methods/off_policy_method_test.py
  class OffPolicyMethodTests (line 6) | class OffPolicyMethodTests:

FILE: sequoia/methods/stable_baselines3_methods/on_policy_method.py
  class OnPolicyModel (line 24) | class OnPolicyModel(OnPolicyAlgorithm, ABC):
    class HParams (line 28) | class HParams(SB3BaseHParams):
  class OnPolicyMethod (line 98) | class OnPolicyMethod(StableBaselines3Method, ABC):
    method configure (line 106) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 131) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> OnPo...
    method fit (line 135) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 138) | def get_actions(
    method on_task_switch (line 146) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method clear_buffers (line 157) | def clear_buffers(self):
    method get_search_space (line 166) | def get_search_space(self, setting: ContinualRLSetting) -> Mapping[str...

FILE: sequoia/methods/stable_baselines3_methods/policy_wrapper.py
  class PolicyWrapper (line 22) | class PolicyWrapper(BasePolicy, ABC, Generic[Policy]):
    method __init__ (line 36) | def __init__(self, *args, _already_initialized: bool = False, **kwargs):
    method get_loss (line 43) | def get_loss(self: Policy) -> Union[float, Tensor]:
    method before_optimizer_step (line 50) | def before_optimizer_step(self: Policy):
    method after_zero_grad (line 55) | def after_zero_grad(self: Policy):
    method wrap_policy (line 67) | def wrap_policy(
    method wrap_policy_class (line 120) | def wrap_policy_class(
    method wrap_algorithm (line 141) | def wrap_algorithm(cls: Type[Wrapper], algo: SB3Algo, **wrapper_kwargs...
    method wrap_algorithm_class (line 164) | def wrap_algorithm_class(
  class A2CWithEWC (line 215) | class A2CWithEWC(A2C):
    method __init__ (line 216) | def __init__(self, *args, ewc_coefficient: float = 1.0, ewc_p_norm: in...
    method _setup_model (line 222) | def _setup_model(self):
    method on_task_switch (line 231) | def on_task_switch(self, task_id: Optional[int]) -> None:

FILE: sequoia/methods/stable_baselines3_methods/ppo.py
  class PPOModel (line 23) | class PPOModel(PPO, OnPolicyModel):
    class HParams (line 36) | class HParams(OnPolicyModel.HParams):
  class PPOMethod (line 134) | class PPOMethod(OnPolicyMethod):
    method configure (line 141) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 144) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> PPOM...
    method fit (line 148) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 151) | def get_actions(
    method on_task_switch (line 159) | def on_task_switch(self, task_id: Optional[int]) -> None:
    method get_search_space (line 170) | def get_search_space(self, setting: ContinualRLSetting) -> Mapping[str...

FILE: sequoia/methods/stable_baselines3_methods/ppo_test.py
  class TestPPO (line 8) | class TestPPO(DiscreteActionSpaceMethodTests):

FILE: sequoia/methods/stable_baselines3_methods/sac.py
  class SACModel (line 22) | class SACModel(SAC, OffPolicyModel):
    class HParams (line 26) | class HParams(OffPolicyModel.HParams):
  class SACMethod (line 50) | class SACMethod(OffPolicyMethod):
    method configure (line 61) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 64) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> SACM...
    method fit (line 67) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 70) | def get_actions(
    method on_task_switch (line 78) | def on_task_switch(self, task_id: Optional[int]) -> None:

FILE: sequoia/methods/stable_baselines3_methods/sac_test.py
  class TestSAC (line 17) | class TestSAC(ContinuousActionSpaceMethodTests):
    method test_continuous_mountaincar (line 29) | def test_continuous_mountaincar(self, Setting: Type[Setting], observe_...

FILE: sequoia/methods/stable_baselines3_methods/td3.py
  class TD3Model (line 21) | class TD3Model(TD3, OffPolicyModel):
    class HParams (line 23) | class HParams(OffPolicyModel.HParams):
  class TD3Method (line 48) | class TD3Method(OffPolicyMethod):
    method configure (line 57) | def configure(self, setting: ContinualRLSetting):
    method create_model (line 60) | def create_model(self, train_env: gym.Env, valid_env: gym.Env) -> TD3M...
    method fit (line 63) | def fit(self, train_env: gym.Env, valid_env: gym.Env):
    method get_actions (line 66) | def get_actions(
    method on_task_switch (line 74) | def on_task_switch(self, task_id: Optional[int]) -> None:

FILE: sequoia/methods/stable_baselines3_methods/td3_test.py
  class TestTD3 (line 8) | class TestTD3(ContinuousActionSpaceMethodTests):

FILE: sequoia/methods/trainer.py
  class TrainerConfig (line 36) | class TrainerConfig(HyperParameters, Parseable):
    method make_trainer (line 79) | def make_trainer(
  class Trainer (line 120) | class Trainer(_Trainer):
    method __init__ (line 121) | def __init__(self, **kwargs):
    method fit (line 124) | def fit(self, model, train_dataloader=None, val_dataloaders=None, data...
  function _apply_to_batch (line 160) | def _apply_to_batch(
  class ProfiledEnvironment (line 178) | class ProfiledEnvironment(IterableWrapper, DataLoader):
    method __iter__ (line 179) | def __iter__(self):
  class PatchedDataConnector (line 193) | class PatchedDataConnector(DataConnector):
    method get_profiled_train_dataloader (line 194) | def get_profiled_train_dataloader(self, train_dataloader: DataLoader):

FILE: sequoia/settings/assumptions/base.py
  class AssumptionBase (line 17) | class AssumptionBase(SettingABC):

FILE: sequoia/settings/assumptions/context_discreteness.py
  class ContinuousContextAssumption (line 9) | class ContinuousContextAssumption(AssumptionBase):
  class DiscreteContextAssumption (line 16) | class DiscreteContextAssumption(ContinuousContextAssumption):

FILE: sequoia/settings/assumptions/context_visibility.py
  class HiddenContextAssumption (line 9) | class HiddenContextAssumption(AssumptionBase):
  class PartiallyObservableContextAssumption (line 23) | class PartiallyObservableContextAssumption(HiddenContextAssumption):
  class FullyObservableContextAssumption (line 33) | class FullyObservableContextAssumption(PartiallyObservableContextAssumpt...

FILE: sequoia/settings/assumptions/continual.py
  class ContinualResults (line 34) | class ContinualResults(TaskResults[MetricsType]):
    method online_performance (line 39) | def online_performance(self) -> Dict[int, MetricsType]:
    method online_performance_metrics (line 55) | def online_performance_metrics(self) -> MetricsType:
    method to_log_dict (line 58) | def to_log_dict(self, verbose: bool = False) -> Dict:
    method summary (line 67) | def summary(self, verbose: bool = False) -> str:
  class ContinualAssumption (line 75) | class ContinualAssumption(AssumptionBase):
    class Observations (line 97) | class Observations(AssumptionBase.Observations):
    class Actions (line 101) | class Actions(AssumptionBase.Actions):
    class Rewards (line 105) | class Rewards(AssumptionBase.Rewards):
    method main_loop (line 120) | def main_loop(self, method: Method) -> ContinualResults:
    method test_loop (line 159) | def test_loop(self, method: Method) -> "IncrementalAssumption.Results":
    method setup_wandb (line 248) | def setup_wandb(self, method: Method) -> Run:
    method log_results (line 291) | def log_results(self, method: Method, results: Results, prefix: str = ...
    method phases (line 329) | def phases(self) -> int:
  class TestEnvironment (line 345) | class TestEnvironment(gym.wrappers.Monitor, IterableWrapper[EnvType], ABC):
    method __init__ (line 350) | def __init__(
    method step (line 371) | def step(self, action):
    method reset (line 379) | def reset(self, **kwargs):
    method get_results (line 386) | def get_results(self) -> Results:
    method step (line 404) | def step(self, action):

FILE: sequoia/settings/assumptions/discrete_results.py
  class TaskSequenceResults (line 16) | class TaskSequenceResults(Results, Generic[MetricType]):
    method __post_init__ (line 25) | def __post_init__(self):
    method objective_name (line 33) | def objective_name(self) -> str:
    method num_tasks (line 37) | def num_tasks(self) -> int:
    method average_metrics (line 48) | def average_metrics(self) -> MetricType:
    method average_metrics_per_task (line 52) | def average_metrics_per_task(self) -> List[MetricType]:
    method objective (line 56) | def objective(self) -> float:
    method to_log_dict (line 59) | def to_log_dict(self, verbose: bool = False) -> Dict:
    method summary (line 66) | def summary(self, verbose: bool = False):
    method make_plots (line 72) | def make_plots(self) -> Dict[str, plt.Figure]:

FILE: sequoia/settings/assumptions/iid.py
  class TraditionalSetting (line 14) | class TraditionalSetting(TaskIncrementalAssumption):
    method phases (line 22) | def phases(self) -> int:

FILE: sequoia/settings/assumptions/iid_results.py
  class TaskResults (line 14) | class TaskResults(Results, Generic[MetricType]):
    method __post_init__ (line 27) | def __post_init__(self):
    method __str__ (line 33) | def __str__(self) -> str:
    method __repr__ (line 36) | def __repr__(self) -> str:
    method average_metrics (line 40) | def average_metrics(self) -> MetricType:
    method objective (line 45) | def objective(self) -> float:
    method objective_name (line 59) | def objective_name(self) -> str:
    method __str__ (line 63) | def __str__(self):
    method to_log_dict (line 66) | def to_log_dict(self, verbose: bool = False) -> Dict:
    method summary (line 84) | def summary(self) -> str:
    method make_plots (line 87) | def make_plots(self) -> Dict[str, plt.Figure]:

FILE: sequoia/settings/assumptions/incremental.py
  class IncrementalAssumption (line 26) | class IncrementalAssumption(ContinualAssumption):
    class Observations (line 44) | class Observations(Setting.Observations):
    method __post_init__ (line 78) | def __post_init__(self):
    method phases (line 92) | def phases(self) -> int:
    method current_task_id (line 102) | def current_task_id(self) -> Optional[int]:
    method current_task_id (line 115) | def current_task_id(self, value: int) -> None:
    method task_boundary_reached (line 119) | def task_boundary_reached(self, method: Method, task_id: int, training...
    method main_loop (line 151) | def main_loop(self, method: Method) -> IncrementalResults:
    method test_loop (line 220) | def test_loop(self, method: Method) -> "IncrementalAssumption.Results":
    method train_dataloader (line 298) | def train_dataloader(
    method val_dataloader (line 305) | def val_dataloader(
    method test_dataloader (line 314) | def test_dataloader(
    method _get_objective_scaling_factor (line 320) | def _get_objective_scaling_factor(self) -> float:

FILE: sequoia/settings/assumptions/incremental_results.py
  class IncrementalResults (line 23) | class IncrementalResults(Results, Generic[MetricType]):
    method __post_init__ (line 38) | def __post_init__(self):
    method runtime_minutes (line 45) | def runtime_minutes(self) -> Optional[float]:
    method runtime_hours (line 49) | def runtime_hours(self) -> Optional[float]:
    method transfer_matrix (line 53) | def transfer_matrix(self) -> List[List[TaskResults]]:
    method metrics_matrix (line 59) | def metrics_matrix(self) -> List[List[MetricType]]:
    method objective_matrix (line 77) | def objective_matrix(self) -> List[List[float]]:
    method cl_score (line 94) | def cl_score(self) -> float:
    method _runtime_score (line 115) | def _runtime_score(self) -> float:
    method _online_performance_score (line 144) | def _online_performance_score(self) -> float:
    method _final_performance_score (line 154) | def _final_performance_score(self) -> float:
    method objective (line 165) | def objective(self) -> float:
    method num_tasks (line 170) | def num_tasks(self) -> int:
    method online_performance (line 174) | def online_performance(self) -> List[Dict[int, MetricType]]:
    method online_performance_metrics (line 193) | def online_performance_metrics(self) -> List[MetricType]:
    method final_performance (line 200) | def final_performance(self) -> List[TaskResults[MetricType]]:
    method final_performance_metrics (line 204) | def final_performance_metrics(self) -> List[MetricType]:
    method average_online_performance (line 208) | def average_online_performance(self) -> MetricType:
    method average_final_performance (line 212) | def average_final_performance(self) -> MetricType:
    method to_log_dict (line 215) | def to_log_dict(self, verbose: bool = False) -> Dict:
    method summary (line 238) | def summary(self, verbose: bool = False):
    method make_plots (line 246) | def make_plots(self) -> Dict[str, Union[plt.Figure, Dict]]:
    method __str__ (line 273) | def __str__(self) -> str:

FILE: sequoia/settings/assumptions/incremental_test.py
  class DummyMethod (line 14) | class DummyMethod(Method, target_setting=IncrementalAssumption):
    method __init__ (line 19) | def __init__(self):
    method fit (line 27) | def fit(self, train_env: gym.Env = None, valid_env: gym.Env = None):
    method test (line 39) | def test(self, test_env: TestEnvironment):
    method get_actions (line 47) | def get_actions(
    method on_task_switch (line 52) | def on_task_switch(self, task_id: int = None):
  class OtherDummyMethod (line 58) | class OtherDummyMethod(Method, target_setting=IncrementalAssumption):
    method __init__ (line 59) | def __init__(self):
    method fit (line 62) | def fit(self, train_env: Environment, valid_env: Environment):
    method get_actions (line 98) | def get_actions(self, observations: Observations, action_space: Space)...

FILE: sequoia/settings/assumptions/task_incremental.py
  class TaskIncrementalAssumption (line 10) | class TaskIncrementalAssumption(FullyObservableContextAssumption, Increm...

FILE: sequoia/settings/assumptions/task_type.py
  class ClassificationActions (line 10) | class ClassificationActions(Actions):
    method action (line 20) | def action(self) -> LongTensor:
    method y_pred_log_prob (line 24) | def y_pred_log_prob(self) -> Tensor:
    method y_pred_prob (line 29) | def y_pred_prob(self) -> Tensor:
    method probabilities (line 34) | def probabilities(self) -> Tensor:

FILE: sequoia/settings/base/bases.py
  class SettingABC (line 52) | class SettingABC:
    method apply (line 78) | def apply(self, method: "Method", config: "Config" = None) -> "Setting...
    method prepare_data (line 120) | def prepare_data(self, *args, **kwargs):
    method setup (line 124) | def setup(self, stage: Optional[str] = None):
    method train_dataloader (line 128) | def train_dataloader(self, *args, **kwargs) -> Environment[Observation...
    method val_dataloader (line 132) | def val_dataloader(self, *args, **kwargs) -> Environment[Observations,...
    method test_dataloader (line 136) | def test_dataloader(self, *args, **kwargs) -> Environment[Observations...
    method get_available_datasets (line 141) | def get_available_datasets(cls) -> Iterable[str]:
    method __init_subclass__ (line 154) | def __init_subclass__(cls, **kwargs):
    method get_applicable_methods (line 168) | def get_applicable_methods(cls) -> List[Type["Method"]]:
    method register_method (line 179) | def register_method(cls, method: Type["Method"]):
    method get_name (line 184) | def get_name(cls) -> str:
    method immediate_children (line 193) | def immediate_children(cls) -> Iterable[Type["SettingABC"]]:
    method get_immediate_children (line 200) | def get_immediate_children(cls) -> List[Type["SettingABC"]]:
    method children (line 205) | def children(cls) -> Iterable[Type["SettingABC"]]:
    method get_children (line 214) | def get_children(cls) -> List[Type["SettingABC"]]:
    method immediate_parents (line 218) | def immediate_parents(cls) -> List[Type["SettingABC"]]:
    method get_immediate_parents (line 225) | def get_immediate_parents(cls) -> List[Type["SettingABC"]]:
    method parents (line 232) | def parents(cls) -> Iterable[Type["SettingABC"]]:
    method get_parents (line 244) | def get_parents(cls) -> List[Type["SettingABC"]]:
    method get_path_to_source_file (line 248) | def get_path_to_source_file(cls: Type) -> Path:
    method get_tree_string (line 254) | def get_tree_string(
  class Method (line 283) | class Method(Generic[SettingType], Parseable, ABC):
    method configure (line 293) | def configure(self, setting: SettingType) -> None:
    method get_actions (line 301) | def get_actions(
    method fit (line 309) | def fit(
    method test (line 319) | def test(self, test_env: Environment[Observations, Actions, Optional[R...
    method receive_results (line 351) | def receive_results(self, setting: SettingType, results: Results) -> N...
    method setup_wandb (line 431) | def setup_wandb(self, run: Run) -> None:
    method set_training (line 446) | def set_training(self) -> None:
    method set_testing (line 463) | def set_testing(self) -> None:
    method training (line 481) | def training(self) -> bool:
    method testing (line 492) | def testing(self) -> bool:
    method main (line 508) | def main(cls, argv: Optional[Union[str, List[str]]] = None) -> Results:
    method is_applicable (line 532) | def is_applicable(cls, setting: Union[SettingType, Type[SettingType]])...
    method get_applicable_settings (line 567) | def get_applicable_settings(cls) -> List[Type[SettingType]]:
    method all_evaluation_settings (line 578) | def all_evaluation_settings(cls, **kwargs) -> Iterable[SettingType]:
    method get_name (line 591) | def get_name(cls) -> str:
    method get_family (line 600) | def get_family(cls) -> Optional[str]:
    method get_full_name (line 609) | def get_full_name(cls) -> str:
    method __init_subclass__ (line 619) | def __init_subclass__(cls, target_setting: Type[SettingType] = None, *...
    method get_path_to_source_file (line 642) | def get_path_to_source_file(cls) -> Path:
    method get_experiment_name (line 645) | def get_experiment_name(self, setting: SettingABC, experiment_id: str ...
    method get_search_space (line 675) | def get_search_space(self, setting: SettingABC) -> Mapping[str, Union[...
    method adapt_to_new_hparams (line 694) | def adapt_to_new_hparams(self, new_hparams: Dict[str, Any]) -> None:
    method hparam_sweep (line 716) | def hparam_sweep(

FILE: sequoia/settings/base/environment.py
  class Environment (line 21) | class Environment(
    method is_closed (line 34) | def is_closed(self) -> bool:

FILE: sequoia/settings/base/objects.py
  class Observations (line 11) | class Observations(Batch):
    method state (line 17) | def state(self) -> Tensor:
    method __len__ (line 20) | def __len__(self) -> int:
  class Actions (line 25) | class Actions(Batch):
    method actions (line 36) | def actions(self) -> Tensor:
    method actions_np (line 40) | def actions_np(self) -> np.ndarray:
    method predictions (line 47) | def predictions(self) -> Tensor:
  class Rewards (line 55) | class Rewards(Batch, Generic[T]):
    method labels (line 69) | def labels(self) -> T:
    method reward (line 73) | def reward(self) -> T:

FILE: sequoia/settings/base/results.py
  class Results (line 37) | class Results(Serializable, ABC):
    method objective (line 53) | def objective(self) -> float:
    method summary (line 62) | def summary(self) -> str:
    method make_plots (line 70) | def make_plots(self) -> Dict[str, plt.Figure]:
    method to_log_dict (line 79) | def to_log_dict(self, verbose: bool = False) -> Dict[str, Any]:
    method save (line 83) | def save(self, path: Union[str, Path], dump_fn=None, **kwargs) -> None:
    method save_to_dir (line 88) | def save_to_dir(self, save_dir: Union[str, Path], filename: str = "res...
    method __eq__ (line 112) | def __eq__(self, other: Any) -> bool:
    method __gt__ (line 119) | def __gt__(self, other: Any) -> bool:

FILE: sequoia/settings/base/setting.py
  class Setting (line 60) | class Setting(
    method __post_init__ (line 165) | def __post_init__(
    method apply (line 247) | def apply(self, method: Method, config: Config = None) -> "Setting.Res...
    method get_metrics (line 281) | def get_metrics(self, actions: Actions, rewards: Rewards) -> Union[flo...
    method image_space (line 323) | def image_space(self) -> Optional[gym.Space]:
    method observation_space (line 338) | def observation_space(self) -> gym.Space:
    method observation_space (line 342) | def observation_space(self, value: gym.Space) -> None:
    method action_space (line 365) | def action_space(self) -> gym.Space:
    method action_space (line 369) | def action_space(self, value: gym.Space) -> None:
    method reward_space (line 373) | def reward_space(self) -> gym.Space:
    method reward_space (line 377) | def reward_space(self, value: gym.Space) -> None:
    method get_available_datasets (line 381) | def get_available_datasets(cls) -> Iterable[str]:
    method _setup_config (line 385) | def _setup_config(self, method: Method) -> Config:
    method main (line 403) | def main(cls, argv: Optional[Union[str, List[str]]] = None) -> Results:
    method apply_all (line 417) | def apply_all(self, argv: Union[str, List[str]] = None) -> Dict[Type["...
    method _check_environments (line 436) | def _check_environments(self):
    method _check_observations (line 525) | def _check_observations(self, env: Environment, observations: Any):
    method _check_actions (line 549) | def _check_actions(self, env: Environment, actions: Any):
    method _check_rewards (line 559) | def _check_rewards(self, env: Environment, rewards: Any):
    method __new__ (line 573) | def __new__(cls, *args, **kwargs):
    method load_benchmark (line 577) | def load_benchmark(cls: Type[SettingType], benchmark: Union[str, Path]...

FILE: sequoia/settings/base/setting_meta.py
  class SettingMeta (line 13) | class SettingMeta(Type["Setting"]):
    method __call__ (line 32) | def __call__(cls, *args, **kwargs):
    method __instancecheck__ (line 71) | def __instancecheck__(self, instance):

FILE: sequoia/settings/base/setting_test.py
  class Setting1 (line 15) | class Setting1(Setting):
    method __post_init__ (line 19) | def __post_init__(self):
  class Setting2 (line 25) | class Setting2(Setting1):
    method __post_init__ (line 28) | def __post_init__(self):
  function test_settings_override_with_constant_take_init (line 34) | def test_settings_override_with_constant_take_init():
  function test_loading_benchmark_doesnt_overwrite_constant (line 48) | def test_loading_benchmark_doesnt_overwrite_constant():
  function test_init_still_works (line 58) | def test_init_still_works():
  function test_passing_unexpected_arg_raises_typeerror (line 63) | def test_passing_unexpected_arg_raises_typeerror():
  class SettingA (line 69) | class SettingA(Setting):
  class SettingA1 (line 74) | class SettingA1(SettingA):
  class SettingA2 (line 79) | class SettingA2(SettingA):
  class SettingB (line 84) | class SettingB(Setting):
  class MethodA (line 88) | class MethodA(Method, target_setting=SettingA):
  class MethodB (line 92) | class MethodB(Method, target_setting=SettingB):
  class CoolGeneralMethod (line 96) | class CoolGeneralMethod(Method, target_setting=Setting):
  function test_that_transforms_can_be_set_through_command_line (line 100) | def test_that_transforms_can_be_set_through_command_line():
  class SettingTests (line 123) | class SettingTests:
    method __init_subclass__ (line 182) | def __init_subclass__(cls, setting: Type[Setting] = None):
    method assert_chance_level (line 196) | def assert_chance_level(self, setting: Setting, results: Setting.Resul...
    method test_random_baseline (line 209) | def test_random_baseline(self, config: Config):
  function make_dataset_fixture (line 237) | def make_dataset_fixture(setting_type: Union[Type[Setting], functools.pa...

FILE: sequoia/settings/offline_rl/setting.py
  class OfflineRLResults (line 21) | class OfflineRLResults(Results):
    method summary (line 24) | def summary(self) -> str:
    method make_plots (line 27) | def make_plots(self) -> Dict[str, plt.Figure]:
    method to_log_dict (line 30) | def to_log_dict(self, verbose: bool = False) -> Dict[str, Any]:
    method objective (line 41) | def objective(self):
  class OfflineRLSetting (line 62) | class OfflineRLSetting(Setting):
    method __post_init__ (line 79) | def __post_init__(self):
    method train_dataloader (line 96) | def train_dataloader(self, batch_size: int = None) -> DataLoader:
    method val_dataloader (line 99) | def val_dataloader(self, batch_size: int = None) -> DataLoader:
    method test (line 102) | def test(self, method, test_env: gym.Env):
    method apply (line 119) | def apply(self, method) -> OfflineRLResults:

FILE: sequoia/settings/rl/continual/environment.py
  class GymDataLoader (line 57) | class GymDataLoader(
    method __init__ (line 102) | def __init__(
    method num_workers (line 189) | def num_workers(self) -> Optional[int]:
    method num_workers (line 193) | def num_workers(self, value: Any) -> Optional[int]:
    method batch_size (line 202) | def batch_size(self) -> Optional[int]:
    method batch_size (line 206) | def batch_size(self, value: Any) -> Optional[int]:
    method __next__ (line 214) | def __next__(self) -> ObservationType:
    method _obs_have_done_signal (line 224) | def _obs_have_done_signal(self) -> bool:
    method __iter__ (line 233) | def __iter__(self) -> Iterator:
    method step (line 316) | def step(self, action: Union[ActionType, Any]) -> StepResult:
    method send (line 320) | def send(self, action: Union[ActionType, Any]) -> RewardType:

FILE: sequoia/settings/rl/continual/environment_test.py
  class TestGymDataLoader (line 22) | class TestGymDataLoader:
    method test_spaces (line 31) | def test_spaces(self, env_name: str, batch_size: int):
    method test_max_steps_is_respected (line 63) | def test_max_steps_is_respected(self, env_name: str, batch_size: int):
    method test_multiple_epochs_works (line 86) | def test_multiple_epochs_works(self, batch_size: Optional[int], seed: ...
    method test_reward_isnt_always_one (line 157) | def test_reward_isnt_always_one(self, env_name: str, batch_size: int):
    method test_batched_state (line 177) | def test_batched_state(self, env_name: str, batch_size: int):
    method test_batched_pixels (line 213) | def test_batched_pixels(self, env_name: str, batch_size: int):

FILE: sequoia/settings/rl/continual/make_env.py
  function make_batched_env (line 21) | def make_batched_env(
  function wrap (line 128) | def wrap(env: gym.Env, wrappers: Iterable[Union[Type[Wrapper], WrapperAn...
  function _make_wrapper_fns (line 138) | def _make_wrapper_fns(

FILE: sequoia/settings/rl/continual/make_env_test.py
  function test_make_batched_env (line 19) | def test_make_batched_env(env_name: str, batch_size: int):
  function test_make_batched_env_envs_have_distinct_ids (line 40) | def test_make_batched_env_envs_have_distinct_ids(env_name: str, batch_si...
  function get_unwrapped_id (line 66) | def get_unwrapped_id(env):
  function test_make_env_with_wrapper (line 73) | def test_make_env_with_wrapper(env_name: str, batch_size: int):
  function test_make_env_with_wrapper_and_kwargs (line 99) | def test_make_env_with_wrapper_and_kwargs(env_name: str, batch_size: int):

FILE: sequoia/settings/rl/continual/objects.py
  class Observations (line 11) | class Observations(RLSetting.Observations, ContinualAssumption.Observati...
  class Actions (line 23) | class Actions(RLSetting.Actions, ContinualAssumption.Actions):
  class Rewards (line 30) | class Rewards(RLSetting.Rewards, ContinualAssumption.Rewards):

FILE: sequoia/settings/rl/continual/results.py
  class ContinualRLResults (line 10) | class ContinualRLResults(ContinualResults, Generic[MetricType]):
    method mean_reward_plot (line 24) | def mean_reward_plot(self):

FILE: sequoia/settings/rl/continual/setting.py
  class SB3AtariWrapper (line 25) | class SB3AtariWrapper:
  class ContinualRLSetting (line 91) | class ContinualRLSetting(RLSetting, ContinualAssumption):
    method __post_init__ (line 274) | def __post_init__(self):
    method create_train_task_schedule (line 599) | def create_train_task_schedule(self) -> TaskSchedule:
    method create_val_task_schedule (line 623) | def create_val_task_schedule(self) -> TaskSchedule:
    method create_test_task_schedule (line 627) | def create_test_task_schedule(self) -> TaskSchedule[ContinuousTask]:
    method create_task_schedule (line 647) | def create_task_schedule(
    method observation_space (line 684) | def observation_space(self) -> TypedDictSpace:
    method task_label_space (line 729) | def task_label_space(self) -> gym.Space:
    method action_space (line 746) | def action_space(self) -> gym.Space:
    method reward_space (line 755) | def reward_space(self) -> gym.Space:
    method apply (line 763) | def apply(self, method: Method, config: Config = None) -> "ContinualRL...
    method setup (line 816) | def setup(self, stage: str = None) -> None:
    method prepare_data (line 827) | def prepare_data(self, *args, **kwargs) -> None:
    method train_dataloader (line 833) | def train_dataloader(
    method val_dataloader (line 895) | def val_dataloader(self, batch_size: int = None, num_workers: int = No...
    method test_dataloader (line 950) | def test_dataloader(self, batch_size: int = None, num_workers: int = N...
    method phases (line 1051) | def phases(self) -> int:
    method steps_per_phase (line 1063) | def steps_per_phase(self) -> Optional[int]:
    method _make_env (line 1078) | def _make_env(
    method _make_env_dataloader (line 1101) | def _make_env_dataloader(
    method create_train_wrappers (line 1167) | def create_train_wrappers(self) -> List[Callable[[gym.Env], gym.Env]]:
    method create_valid_wrappers (line 1194) | def create_valid_wrappers(self) -> List[Callable[[gym.Env], gym.Env]]:
    method create_test_wrappers (line 1220) | def create_test_wrappers(self) -> List[Callable[[gym.Env], gym.Env]]:
    method _make_wrappers (line 1242) | def _make_wrappers(
    method _get_objective_scaling_factor (line 1337) | def _get_objective_scaling_factor(self) -> float:
    method _get_simple_name (line 1365) | def _get_simple_name(self, env_name_or_id: str) -> Optional[str]:
  function _load_task_schedule (line 1382) | def _load_task_schedule(file_path: Path) -> Dict[int, Dict]:
  function find_matching_dataset (line 1393) | def find_matching_dataset(

FILE: sequoia/settings/rl/continual/setting_test.py
  function test_passing_unsupported_dataset_raises_error (line 42) | def test_passing_unsupported_dataset_raises_error(dataset: Any):
  function test_acrobot_attributes_change_over_time (line 47) | def test_acrobot_attributes_change_over_time():
  function wrap (line 177) | def wrap(
  function wrap_reversed (line 203) | def wrap_reversed(
  function _equal (line 210) | def _equal(a: Any, b: Any) -> bool:
  function _partials_equal (line 230) | def _partials_equal(a: partial, b: partial) -> bool:
  function _lists_equal (line 243) | def _lists_equal(a: List, b: List) -> bool:
  function _dicts_equal (line 248) | def _dicts_equal(a: Dict, b: Dict) -> bool:
  function all_different_from_next (line 260) | def all_different_from_next(sequence: Sequence) -> bool:
  class TestContinualRLSetting (line 265) | class TestContinualRLSetting(SettingTests):
    method setting_kwargs (line 270) | def setting_kwargs(self, dataset: str, config: Config):
    method test_passing_supported_dataset (line 274) | def test_passing_supported_dataset(self, setting_kwargs: Dict):
    method test_task_schedule_is_reproducible (line 285) | def test_task_schedule_is_reproducible(self, dataset: str, seed: Optio...
    method test_using_deprecated_fields (line 295) | def test_using_deprecated_fields(self):
    method test_tasks_are_different (line 316) | def test_tasks_are_different(self, setting_kwargs: Dict[str, Any], con...
    method test_settings_attributes_are_the_same_for_given_seed (line 327) | def test_settings_attributes_are_the_same_for_given_seed(
    method test_tasks_are_different_when_seed_is_different (line 348) | def test_tasks_are_different_when_seed_is_different(
    method test_env_attributes_change (line 390) | def test_env_attributes_change(self, setting_kwargs: Dict[str, Any], c...
    method validate_env_value_changes (line 448) | def validate_env_value_changes(
    method validate_results (line 496) | def validate_results(
    method test_check_iterate_and_step (line 514) | def test_check_iterate_and_step(
    method test_show_distributions (line 701) | def test_show_distributions(self, config: Config):
  function test_passing_task_schedule_sets_other_attributes_correctly (line 780) | def test_passing_task_schedule_sets_other_attributes_correctly():
  function test_fit_and_on_task_switch_calls (line 824) | def test_fit_and_on_task_switch_calls():
  function test_mujoco_env_name_maps_to_continual_variant (line 874) | def test_mujoco_env_name_maps_to_continual_variant(

FILE: sequoia/settings/rl/continual/tasks.py
  class TaskSchedule (line 42) | class TaskSchedule(Dict[int, TaskType]):
  class EnvironmentNotSupportedError (line 46) | class EnvironmentNotSupportedError(gym.error.UnregisteredEnv):
  function names_match (line 50) | def names_match(name_a: str, name_b: str) -> bool:
  function _is_supported (line 59) | def _is_supported(
  function task_sampling_function (line 123) | def task_sampling_function(
  function make_continuous_task (line 220) | def make_continuous_task(
  function make_task_for_classic_control_env (line 310) | def make_task_for_classic_control_env(
  function make_task_for_modified_gravity_env (line 364) | def make_task_for_modified_gravity_env(

FILE: sequoia/settings/rl/continual/tasks_test.py
  function test_mujoco_tasks (line 29) | def test_mujoco_tasks(env_type: Type[MujocoEnv]):

FILE: sequoia/settings/rl/continual/test_environment.py
  class ContinualRLTestEnvironment (line 12) | class ContinualRLTestEnvironment(TestEnvironment):
    method __init__ (line 13) | def __init__(self, *args, task_schedule: Dict, **kwargs):
    method __len__ (line 18) | def __len__(self):
    method get_results (line 21) | def get_results(self) -> ContinualResults[EpisodeMetrics]:
    method render (line 49) | def render(self, mode="human", **kwargs):
    method _after_reset (line 56) | def _after_reset(self, observation):

FILE: sequoia/settings/rl/discrete/multienv_wrappers.py
  function instantiate_env (line 22) | def instantiate_env(env: Union[str, gym.Env, Callable[[], gym.Env]]) -> ...
  class MultiEnvWrapper (line 31) | class MultiEnvWrapper(IterableWrapper, ABC):
    method __init__ (line 38) | def __init__(self, envs: List[gym.Env], add_task_ids: bool = False):
    method _instantiate_env (line 54) | def _instantiate_env(self, index: int) -> None:
    method set_task (line 57) | def set_task(self, task_id: int) -> None:
    method next_task (line 75) | def next_task(self) -> int:
    method reset (line 78) | def reset(self):
    method step (line 87) | def step(self, action):
    method is_closed (line 92) | def is_closed(self, env_index: int = None):
    method close (line 116) | def close(self, env_index: int = None) -> None:
    method seed (line 136) | def seed(self, seed: Optional[int] = None) -> List[int]:
    method observation (line 164) | def observation(self, observation):
  class ConcatEnvsWrapper (line 170) | class ConcatEnvsWrapper(MultiEnvWrapper):
    method __init__ (line 173) | def __init__(
    method set_task (line 182) | def set_task(self, task_id: int) -> None:
    method reset (line 187) | def reset(self):
    method next_task (line 195) | def next_task(self) -> int:
    method __iter__ (line 202) | def __iter__(self):
    method send (line 205) | def send(self, action):
  function _concatenate_gym_envs (line 213) | def _concatenate_gym_envs(first_env: gym.Env, *other_envs: gym.Env) -> C...
  class RoundRobinWrapper (line 217) | class RoundRobinWrapper(MultiEnvWrapper):
    method __init__ (line 222) | def __init__(self, envs, add_task_ids=False):
    method next_task (line 226) | def next_task(self) -> int:
  class RandomMultiEnvWrapper (line 235) | class RandomMultiEnvWrapper(MultiEnvWrapper):
    method next_task (line 236) | def next_task(self) -> int:
  class CustomMultiEnvWrapper (line 242) | class CustomMultiEnvWrapper(
Condensed preview — 460 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (2,907K chars).
[
  {
    "path": ".dockerignore",
    "chars": 40,
    "preview": "data\nlightning_logs\ncheckpoints\nresults\n"
  },
  {
    "path": ".gitattributes",
    "chars": 33,
    "preview": "sequoia/_version.py export-subst\n"
  },
  {
    "path": ".gitignore",
    "chars": 295,
    "preview": "**/__pycache__/\n.vscode\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\nexamples/results/*\nresults/*\n!results/**/*.csv\ndata"
  },
  {
    "path": ".gitmodules",
    "chars": 412,
    "preview": "[submodule \"sequoia/methods/cn_dpm\"]\n\tpath = sequoia/methods/cn_dpm\n\turl = https://github.com/ryanlindeborg/CN-DPM.git\n["
  },
  {
    "path": ".travis.yml",
    "chars": 152,
    "preview": "language: python\npython:\n  - \"3.7\"\ninstall:\n  - pip install gym[atari]\n  - pip install -r requirements.txt\nscript:\n  - p"
  },
  {
    "path": "LICENSE",
    "chars": 35149,
    "preview": "                    GNU GENERAL PUBLIC LICENSE\n                       Version 3, 29 June 2007\n\n Copyright (C) 2007 Free "
  },
  {
    "path": "MANIFEST.in",
    "chars": 50,
    "preview": "include versioneer.py\ninclude sequoia/_version.py\n"
  },
  {
    "path": "README.md",
    "chars": 12052,
    "preview": "# Sequoia - The Research Tree \n\nA Playground for research at the intersection of Continual, Reinforcement, and Self-Supe"
  },
  {
    "path": "dockers/.gitignore",
    "chars": 34,
    "preview": "# Hiding the 'eai' dockerfile\neai\n"
  },
  {
    "path": "dockers/base/Dockerfile",
    "chars": 4085,
    "preview": "# syntax=docker/dockerfile:1\nFROM pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime\nUSER root\nEXPOSE 2222\nEXPOSE 6000\nEXPOSE"
  },
  {
    "path": "dockers/base/build.sh",
    "chars": 698,
    "preview": "#!/bin/bash\nset -o errexit    # Used to exit upon error, avoiding cascading errors\nset -o errtrace    # Show error trace"
  },
  {
    "path": "dockers/branch/Dockerfile",
    "chars": 356,
    "preview": "# syntax=docker/dockerfile:1\nFROM lebrice/sequoia:base\nUSER root\nSHELL [ \"conda\", \"run\", \"-n\", \"base\", \"/bin/bash\", \"-c\""
  },
  {
    "path": "dockers/branch/build.sh",
    "chars": 903,
    "preview": "#!/bin/bash\nset -o errexit    # Used to exit upon error, avoiding cascading errors\nset -o errtrace    # Show error trace"
  },
  {
    "path": "docs/diagrams/src/gym.puml",
    "chars": 1337,
    "preview": "@startuml gym\n\npackage gym {\n    package spaces as gym.spaces {\n        abstract class Space<T> {\n            + contains"
  },
  {
    "path": "docs/diagrams/src/pytorch_lightning.puml",
    "chars": 473,
    "preview": "@startuml pytorch_lightning\npackage pytorch_lightning {\n    abstract class LightningDataModule {\n        {abstract} + pr"
  },
  {
    "path": "docs/diagrams/src/seq_diagram.puml",
    "chars": 4874,
    "preview": "@startuml ContinualRLSetting\nheader Page Header\nfooter Page %page% of %lastpage%\ntitle Overall Evaluation loop - Sequoia"
  },
  {
    "path": "examples/README.md",
    "chars": 3196,
    "preview": "# Examples\n\nHere's a brief description of the examples in this folder:\n\n## Prerequisites:\n- [Intro to dataclasses & simp"
  },
  {
    "path": "examples/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "examples/advanced/RL_and_SL_demo.py",
    "chars": 12043,
    "preview": "\"\"\" Demo where we add the same regularization loss from the other examples, but\nthis time as an `AuxiliaryTask` on top o"
  },
  {
    "path": "examples/advanced/continual_rl_demo.py",
    "chars": 1158,
    "preview": "import sys\n\n# This \"hack\" is required so we can run `python examples/continual_rl_demo.py`\nsys.path.extend([\".\", \"..\"])\n"
  },
  {
    "path": "examples/advanced/ewc_in_rl.py",
    "chars": 14551,
    "preview": "\"\"\" Example of how to add a simplified regularization method to algos from\nstable-baseline-3.\n\"\"\"\nfrom collections impor"
  },
  {
    "path": "examples/advanced/hat_demo.py",
    "chars": 14735,
    "preview": "import sys\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import Dict, NamedTuple, Optiona"
  },
  {
    "path": "examples/advanced/hparam_tuning.py",
    "chars": 3752,
    "preview": "\"\"\"Runs a hyper-parameter tuning sweep, using Orion for HPO and wandb for visualization. \n\n# PREREQUISITES:\n\n\n1.  (Optio"
  },
  {
    "path": "examples/advanced/pnn/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "examples/advanced/pnn/layers.py",
    "chars": 1498,
    "preview": "import torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision import transforms\n\n\"\"\"\nBased on https://github.co"
  },
  {
    "path": "examples/advanced/pnn/model_rl.py",
    "chars": 6681,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision import transforms"
  },
  {
    "path": "examples/advanced/pnn/model_sl.py",
    "chars": 5494,
    "preview": "from typing import Dict, List, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.fu"
  },
  {
    "path": "examples/advanced/pnn/pnn_method.py",
    "chars": 16314,
    "preview": "import sys\nfrom argparse import Namespace\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Tupl"
  },
  {
    "path": "examples/advanced/procgen_example.py",
    "chars": 12399,
    "preview": "\"\"\" Example of how to create an incremental RL Setting with custom environments for each task.\n\nIn this example, we crea"
  },
  {
    "path": "examples/basic/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "examples/basic/base_method_demo.py",
    "chars": 1800,
    "preview": "\"\"\" Example showing how the BaseMethod can be applied to get results in both\nRL and SL settings.\n\"\"\"\n\nfrom simple_parsin"
  },
  {
    "path": "examples/basic/pl_example.py",
    "chars": 13842,
    "preview": "\"\"\"A simple example for creating a Method using PyTorch-Lightning.\n\nRun this as:\n\n```console\n$> python examples/basic/pl"
  },
  {
    "path": "examples/basic/pl_example_packnet.py",
    "chars": 3980,
    "preview": "from dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom simple_parsing import mutable_field\n\nfr"
  },
  {
    "path": "examples/basic/pl_example_test.py",
    "chars": 2946,
    "preview": "\"\"\" Unit-tests for the PyTorch-Lightning Example.\n\nCan be run like so:\n```console\n$ pytest examples/basic/pl_example_tes"
  },
  {
    "path": "examples/basic/quick_demo.ipynb",
    "chars": 210184,
    "preview": "{\n \"metadata\": {\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_ext"
  },
  {
    "path": "examples/basic/quick_demo.py",
    "chars": 11413,
    "preview": "\"\"\" Demo: Creates a simple new method and applies it to a single CL setting.\n\"\"\"\nimport sys\nfrom argparse import Namespa"
  },
  {
    "path": "examples/basic/quick_demo_ewc.py",
    "chars": 6700,
    "preview": "\"\"\" Example script: Defines a new Method based on the DemoMethod from the\nquick_demo.py script, adding an EWC-like loss "
  },
  {
    "path": "examples/basic/quick_demo_packnet.py",
    "chars": 282,
    "preview": "from sequoia.methods.packnet_method import PackNetMethod\nfrom sequoia.settings.sl import TaskIncrementalSLSetting\n\nif __"
  },
  {
    "path": "examples/basic/quick_demo_test.py",
    "chars": 2047,
    "preview": "\"\"\" TODO: Write tests that check that the examples are working correctly.\n\"\"\"\nimport contextlib\nimport sys\n\nimport pytes"
  },
  {
    "path": "examples/clcomp21/README.md",
    "chars": 1389,
    "preview": "## Example Submissions for CLVision Workshop\n\nExamples in this folder are aimed at solving the supervised learning track"
  },
  {
    "path": "examples/clcomp21/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "examples/clcomp21/a2c_example.py",
    "chars": 13868,
    "preview": "from argparse import Namespace\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Dict, List,"
  },
  {
    "path": "examples/clcomp21/a2c_example_test.py",
    "chars": 2286,
    "preview": "import pytest\n\nfrom sequoia.client.setting_proxy import SettingProxy\nfrom sequoia.conftest import slow\nfrom sequoia.sett"
  },
  {
    "path": "examples/clcomp21/classifier.py",
    "chars": 13263,
    "preview": "\"\"\" Example Method for the SL track: Uses a simple classifier, without any CL mechanism.\n\nAs you'd expect, this Method e"
  },
  {
    "path": "examples/clcomp21/classifier_test.py",
    "chars": 1398,
    "preview": "import pytest\n\nfrom sequoia.client.setting_proxy import SettingProxy\nfrom sequoia.conftest import slow\nfrom sequoia.sett"
  },
  {
    "path": "examples/clcomp21/conftest.py",
    "chars": 2550,
    "preview": "import pytest\n\nfrom sequoia.client.setting_proxy import SettingProxy\nfrom sequoia.settings.rl import IncrementalRLSettin"
  },
  {
    "path": "examples/clcomp21/dummy_method.py",
    "chars": 3853,
    "preview": "from typing import Optional\n\nimport gym\nimport numpy as np\nimport tqdm\nfrom torch import Tensor\n\nfrom sequoia.methods im"
  },
  {
    "path": "examples/clcomp21/dummy_method_test.py",
    "chars": 2103,
    "preview": "import pytest\n\nfrom sequoia.client.setting_proxy import SettingProxy\nfrom sequoia.conftest import slow\nfrom sequoia.sett"
  },
  {
    "path": "examples/clcomp21/multihead_classifier.py",
    "chars": 14974,
    "preview": "\"\"\" Example Method for the SL track: Multi-Head Classifier with simple task inference.\n\nYou can use this model and metho"
  },
  {
    "path": "examples/clcomp21/multihead_classifier_test.py",
    "chars": 2210,
    "preview": "import pytest\n\nfrom sequoia.client.setting_proxy import SettingProxy\nfrom sequoia.conftest import slow\nfrom sequoia.sett"
  },
  {
    "path": "examples/clcomp21/regularization_example.py",
    "chars": 6683,
    "preview": "\"\"\" Example: Defines a new Method based on the ExampleMethod, adding an EWC-like loss to\nhelp prevent the weights from c"
  },
  {
    "path": "examples/clcomp21/regularization_example_test.py",
    "chars": 1474,
    "preview": "import pytest\n\nfrom sequoia.client.setting_proxy import SettingProxy\nfrom sequoia.conftest import slow\nfrom sequoia.sett"
  },
  {
    "path": "examples/clcomp21/sb3_example.py",
    "chars": 2857,
    "preview": "\"\"\" Example where we start from a Method from stable-baselines3 to solve the rl track.\n\"\"\"\nfrom dataclasses import datac"
  },
  {
    "path": "examples/clcomp21/sb3_example_test.py",
    "chars": 2052,
    "preview": "import pytest\n\nfrom sequoia.client.setting_proxy import SettingProxy\nfrom sequoia.conftest import slow\nfrom sequoia.sett"
  },
  {
    "path": "examples/demo_utils.py",
    "chars": 7534,
    "preview": "from collections import defaultdict\nfrom pathlib import Path\nfrom typing import Dict, List, Type\n\nimport pandas as pd\nfr"
  },
  {
    "path": "examples/prerequisites/dataclasses_example.py",
    "chars": 2085,
    "preview": "\"\"\" Example describing dataclasses and how simple-parsing can be used to create\ncommand-line arguments from them.\n\"\"\"\n\nf"
  },
  {
    "path": "mypy.ini",
    "chars": 120,
    "preview": "# Global options:\n\n[mypy]\npython_version = 3.7\nwarn_return_any = True\nwarn_unused_configs = True\nfollow_imports = normal"
  },
  {
    "path": "pytest.ini",
    "chars": 196,
    "preview": "[pytest]\ntimeout = 30\ntestpaths =\n    sequoia\n    examples\naddopts =\n    --doctest-modules\nnorecursedirs =\n    methods/d"
  },
  {
    "path": "requirements.txt",
    "chars": 1107,
    "preview": "# Fork of gym with more flexible utility functions.\ngym @ git+https://www.github.com/openai/gym@8819d561132082f6130d4a23"
  },
  {
    "path": "scripts/eai/cancel_all_queuing.sh",
    "chars": 115,
    "preview": "all_ids=$(eai job ls --state queuing -c \"$1\" --fields id --no-header)\nfor id in $all_ids\ndo\n  eai job kill $id\ndone"
  },
  {
    "path": "scripts/eai/cancel_all_running.sh",
    "chars": 116,
    "preview": "all_ids=$(eai job ls --state running  -c \"$1\" --fields id --no-header)\nfor id in $all_ids\ndo\n  eai job kill $id\ndone"
  },
  {
    "path": "scripts/eai/job.sh",
    "chars": 1815,
    "preview": "#!/bin/bash\nset -o errexit    # Used to exit upon error, avoiding cascading errors\nset -o errtrace    # Show error trace"
  },
  {
    "path": "scripts/eai/rl_sweep.sh",
    "chars": 1907,
    "preview": "#!/bin/bash\nset -o errexit  # Used to exit upon error, avoiding cascading errors\nset -o errtrace # Show error trace\nset "
  },
  {
    "path": "scripts/eai/shell_job.sh",
    "chars": 1607,
    "preview": "#!/bin/bash\nset -o errexit    # Used to exit upon error, avoiding cascading errors\nset -o errtrace    # Show error trace"
  },
  {
    "path": "scripts/eai/sl_sweep.sh",
    "chars": 2065,
    "preview": "#!/bin/bash\nset -o errexit  # Used to exit upon error, avoiding cascading errors\nset -o errtrace # Show error trace\nset "
  },
  {
    "path": "scripts/slurm/launch_many_sweeps.sh",
    "chars": 1421,
    "preview": "#!/bin/bash\nset -o errexit  # Used to exit upon error, avoiding cascading errors\nset -o errtrace # Show error trace\nset "
  },
  {
    "path": "scripts/slurm/run.sh",
    "chars": 268,
    "preview": "#!/bin/bash\n#SBATCH --array=0-3%2\n#SBATCH --cpus-per-task=2\n#SBATCH --gres=gpu:1\n#SBATCH --mem=10GB\n#SBATCH --time=11:59"
  },
  {
    "path": "scripts/slurm/sweep.sh",
    "chars": 725,
    "preview": "#!/bin/bash\n#SBATCH --array=0-10%2\n#SBATCH --cpus-per-task=2\n#SBATCH --gres=gpu:1\n#SBATCH --mem=10GB\n#SBATCH --time=11:5"
  },
  {
    "path": "sequoia/README.md",
    "chars": 695,
    "preview": "# sequoia\n\n## Packages:\n- [settings](settings): definitions for the settings (machine learning problems).\n- [methods](me"
  },
  {
    "path": "sequoia/__init__.py",
    "chars": 219,
    "preview": "\"\"\" Sequoia - The Research Tree \"\"\"\nfrom ._version import get_versions\nfrom .settings import Environment, Method, Settin"
  },
  {
    "path": "sequoia/_version.py",
    "chars": 18535,
    "preview": "# This file helps to compute a version number in source trees obtained from\n# git-archive tarball (such as those provide"
  },
  {
    "path": "sequoia/client/README.md",
    "chars": 236,
    "preview": "# (WIP) Sequoia Client\n\nThis is only currently used for the competition. The idea is that the setting (and its environme"
  },
  {
    "path": "sequoia/client/__init__.py",
    "chars": 80,
    "preview": "from .env_proxy import EnvironmentProxy\nfrom .setting_proxy import SettingProxy\n"
  },
  {
    "path": "sequoia/client/__main__.py",
    "chars": 503,
    "preview": "\"\"\" TODO: launch the 'sequoia gRPC server' at a given address / port. \"\"\"\nimport argparse\n\nfrom .server import server\n\ni"
  },
  {
    "path": "sequoia/client/env.proto",
    "chars": 767,
    "preview": "syntax = \"proto3\";\n// Adapted from https://github.com/AppliedDeepLearning/gymx/blob/master/gymx/env.proto\n\nenum SettingT"
  },
  {
    "path": "sequoia/client/env_proxy.py",
    "chars": 4646,
    "preview": "\"\"\"TODO: Create an 'environment proxy' that relays observations / actions etc from a remote environment via gRPC.\n\nFor n"
  },
  {
    "path": "sequoia/client/env_proxy_test.py",
    "chars": 7141,
    "preview": "import platform\nfrom functools import partial\nfrom typing import ClassVar, Iterable, Tuple, Type, TypeVar\n\nimport gym\nim"
  },
  {
    "path": "sequoia/client/server.py",
    "chars": 83,
    "preview": "def server(grpc_host: str, grpc_port: int):\n    raise NotImplementedError(f\"TODO\")\n"
  },
  {
    "path": "sequoia/client/setting_proxy.py",
    "chars": 20459,
    "preview": "import time\nimport warnings\nfrom functools import partial\nfrom logging import getLogger\nfrom pathlib import Path\nfrom ty"
  },
  {
    "path": "sequoia/client/setting_proxy_test.py",
    "chars": 7854,
    "preview": "\"\"\"TODO: Tests for the SettingProxy.\n\n\"\"\"\nfrom functools import partial\nfrom typing import ClassVar, Type\n\nimport numpy "
  },
  {
    "path": "sequoia/common/__init__.py",
    "chars": 186,
    "preview": "from .batch import Batch\nfrom .config import Config\nfrom .loss import Loss\nfrom .metrics import ClassificationMetrics, M"
  },
  {
    "path": "sequoia/common/batch.py",
    "chars": 25250,
    "preview": "\"\"\" WIP (@lebrice): Playing around with the idea of using a typed object to\nrepresent the different forms of \"batches\" t"
  },
  {
    "path": "sequoia/common/batch_test.py",
    "chars": 16816,
    "preview": "\"\"\" Tests for the `Batch` class.\n\"\"\"\n\n\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, T"
  },
  {
    "path": "sequoia/common/callbacks/__init__.py",
    "chars": 230,
    "preview": "\"\"\"\nTODO: Migrate the addons to Pytorch-Lightning, maybe in the form of callbacks\nor as optional extensions to be added "
  },
  {
    "path": "sequoia/common/callbacks/knn_callback.py",
    "chars": 15694,
    "preview": "\"\"\" Callback that evaluates representations with a KNN after each epoch.\n\nTODO: The code here is split into too many fun"
  },
  {
    "path": "sequoia/common/callbacks/vae_callback.py",
    "chars": 4251,
    "preview": "from dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom pytorch_lightning import Callback, Trai"
  },
  {
    "path": "sequoia/common/config/__init__.py",
    "chars": 65,
    "preview": "from .config import Config\nfrom .wandb_config import WandbConfig\n"
  },
  {
    "path": "sequoia/common/config/config.py",
    "chars": 3074,
    "preview": "\"\"\" Config dataclasses for use with pytorch lightning.\n\n@author Fabrice Normandin (@lebrice)\n\"\"\"\nimport os\nfrom dataclas"
  },
  {
    "path": "sequoia/common/config/wandb_config.py",
    "chars": 8044,
    "preview": "\"\"\"TODO: Re-enable the wandb stuff (disabled for now).\n\"\"\"\nimport os\nimport re\nfrom dataclasses import dataclass\nfrom pa"
  },
  {
    "path": "sequoia/common/gym_wrappers/__init__.py",
    "chars": 663,
    "preview": "\"\"\" Contains some potentially useful gym wrappers. \"\"\"\nfrom .add_done import AddDoneToObservation\nfrom .add_info import "
  },
  {
    "path": "sequoia/common/gym_wrappers/action_limit.py",
    "chars": 2214,
    "preview": "\"\"\" IDEA: same as ObservationLimit, for for the number of total actions (steps).\n\"\"\"\nimport gym\nfrom gym.error import Cl"
  },
  {
    "path": "sequoia/common/gym_wrappers/action_limit_test.py",
    "chars": 3778,
    "preview": "from typing import List\n\nimport gym\nimport pytest\nfrom gym.wrappers import TimeLimit\n\nfrom sequoia.common.gym_wrappers.e"
  },
  {
    "path": "sequoia/common/gym_wrappers/add_done.py",
    "chars": 4504,
    "preview": "\"\"\" Wrapper that adds 'done' as part of the environment's observations.\n\"\"\"\nfrom dataclasses import is_dataclass, replac"
  },
  {
    "path": "sequoia/common/gym_wrappers/add_info.py",
    "chars": 4072,
    "preview": "\"\"\" Wrapper that adds the 'info' as a part of the environment's observations.\n\"\"\"\nfrom dataclasses import is_dataclass, "
  },
  {
    "path": "sequoia/common/gym_wrappers/convert_tensors.py",
    "chars": 8632,
    "preview": "from dataclasses import is_dataclass, replace\nimport dataclasses\nfrom functools import singledispatch, wraps\nfrom typing"
  },
  {
    "path": "sequoia/common/gym_wrappers/convert_tensors_test.py",
    "chars": 2105,
    "preview": "from typing import Union\n\nimport gym\nimport pytest\nimport torch\nfrom gym import spaces\nfrom torch import Tensor\n\nfrom se"
  },
  {
    "path": "sequoia/common/gym_wrappers/env_dataset.py",
    "chars": 13445,
    "preview": "\"\"\" Creates an IterableDataset from a Gym Environment.\n\"\"\"\nimport warnings\nfrom typing import Dict, Generic, Iterable, I"
  },
  {
    "path": "sequoia/common/gym_wrappers/env_dataset_test.py",
    "chars": 11257,
    "preview": "from functools import partial\nfrom typing import ClassVar, Type\n\nimport gym\nimport numpy as np\nimport pytest\nfrom gym im"
  },
  {
    "path": "sequoia/common/gym_wrappers/episode_limit.py",
    "chars": 5581,
    "preview": "# IDEA: Limit the total number of episodes, even in vectorized\n# environments!\nimport warnings\nfrom typing import Sequen"
  },
  {
    "path": "sequoia/common/gym_wrappers/episode_limit_test.py",
    "chars": 7178,
    "preview": "from functools import partial\n\nimport gym\nimport numpy as np\nimport pytest\nfrom gym.vector import SyncVectorEnv\nfrom gym"
  },
  {
    "path": "sequoia/common/gym_wrappers/measure_performance.py",
    "chars": 1398,
    "preview": "\"\"\" Abstract base class for a Wrapper that gets applied onto the environment in order to\nmeasure the online training per"
  },
  {
    "path": "sequoia/common/gym_wrappers/multi_task_environment.py",
    "chars": 22713,
    "preview": "import bisect\nimport dataclasses\nfrom functools import singledispatch\nfrom typing import Any, Callable, Dict, List, Opti"
  },
  {
    "path": "sequoia/common/gym_wrappers/multi_task_environment_test.py",
    "chars": 18372,
    "preview": "from typing import Dict, List, Tuple\n\nimport gym\nimport matplotlib.pyplot as plt\nimport pytest\nfrom gym import spaces\nfr"
  },
  {
    "path": "sequoia/common/gym_wrappers/observation_limit.py",
    "chars": 2329,
    "preview": "\"\"\" IDEA: same as EpisodeLimit, for for the number of total observations.\n\"\"\"\n\nimport gym\nfrom gym.error import ClosedEn"
  },
  {
    "path": "sequoia/common/gym_wrappers/observation_limit_test.py",
    "chars": 4183,
    "preview": "from functools import partial\n\nimport gym\nimport pytest\nfrom gym.vector import SyncVectorEnv\n\nfrom sequoia.conftest impo"
  },
  {
    "path": "sequoia/common/gym_wrappers/pixel_observation.py",
    "chars": 2757,
    "preview": "\"\"\" Fixes some of the annoying things about the PixelObservationWrapper. \"\"\"\nfrom typing import Union\n\nimport gym\nimport"
  },
  {
    "path": "sequoia/common/gym_wrappers/pixel_observation_test.py",
    "chars": 2495,
    "preview": "import gym\nimport numpy as np\nimport pytest\n\nfrom .pixel_observation import PixelObservationWrapper\n\npyglet = pytest.imp"
  },
  {
    "path": "sequoia/common/gym_wrappers/policy_env.py",
    "chars": 9233,
    "preview": "\"\"\"TODO: Idea: create a wrapper that accepts a 'policy' which will decide an\naction to take whenever the `action` argume"
  },
  {
    "path": "sequoia/common/gym_wrappers/policy_env_test.py",
    "chars": 1633,
    "preview": "from typing import List\n\nfrom sequoia.conftest import DummyEnvironment\n\nfrom .policy_env import PolicyEnv, StateTransiti"
  },
  {
    "path": "sequoia/common/gym_wrappers/smooth_environment.py",
    "chars": 7235,
    "preview": "\"\"\"TODO: A Wrapper that creates smooth transitions between tasks.\nCould be based on the MultiTaskEnvironment, but with a"
  },
  {
    "path": "sequoia/common/gym_wrappers/smooth_environment_test.py",
    "chars": 4063,
    "preview": "from typing import Dict\n\nimport gym\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom .smooth_environment import "
  },
  {
    "path": "sequoia/common/gym_wrappers/step_callback_wrapper.py",
    "chars": 3273,
    "preview": "\"\"\"TODO: Make a wrapper that calls a given function/callback when a given step is reached.\n\"\"\"\nfrom abc import ABC, abst"
  },
  {
    "path": "sequoia/common/gym_wrappers/step_callback_wrapper_test.py",
    "chars": 1624,
    "preview": "from typing import Tuple\n\nimport gym\n\nfrom .step_callback_wrapper import PeriodicCallback, StepCallback, StepCallbackWra"
  },
  {
    "path": "sequoia/common/gym_wrappers/transform_wrappers.py",
    "chars": 3504,
    "preview": "from typing import Callable, Union\nimport typing\n\nimport gym\nfrom gym import Space, spaces\nfrom gym.wrappers import Tran"
  },
  {
    "path": "sequoia/common/gym_wrappers/transform_wrappers_test.py",
    "chars": 1994,
    "preview": "import gym\nimport numpy as np\n\nfrom sequoia.common.spaces import Image\nfrom sequoia.common.transforms import Compose, Tr"
  },
  {
    "path": "sequoia/common/gym_wrappers/utils.py",
    "chars": 21930,
    "preview": "import inspect\nfrom abc import ABC\nfrom functools import partial\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n "
  },
  {
    "path": "sequoia/common/gym_wrappers/utils_test.py",
    "chars": 960,
    "preview": "import gym\nimport pytest\nfrom gym.wrappers import ClipAction\nfrom gym.wrappers.pixel_observation import PixelObservation"
  },
  {
    "path": "sequoia/common/hparams/__init__.py",
    "chars": 248,
    "preview": "\"\"\" Utilities for creating hyper-parameter dataclasses and their fields. \"\"\"\nfrom simple_parsing.helpers.hparams import "
  },
  {
    "path": "sequoia/common/layers.py",
    "chars": 7911,
    "preview": "import math\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom gym import s"
  },
  {
    "path": "sequoia/common/loss.py",
    "chars": 16170,
    "preview": "\"\"\" Module that defines a `Loss` class that holds losses and associated metrics.\n\nThis Loss object is used to bundle tog"
  },
  {
    "path": "sequoia/common/loss_test.py",
    "chars": 1568,
    "preview": "\"\"\"\nTODO: Write some tests that also help illustrate how the Loss class works.\n\"\"\"\nfrom .loss import Loss\n\n\ndef test_dem"
  },
  {
    "path": "sequoia/common/metrics/__init__.py",
    "chars": 325,
    "preview": "from .classification import ClassificationMetrics\nfrom .get_metrics import get_metrics\nfrom .metrics import Metrics, Met"
  },
  {
    "path": "sequoia/common/metrics/classification.py",
    "chars": 6535,
    "preview": "\"\"\" Metrics class for classification.\n\nGives the accuracy, the class accuracy, and the confusion matrix for a given set\n"
  },
  {
    "path": "sequoia/common/metrics/classification_test.py",
    "chars": 1519,
    "preview": "import numpy as np\nimport torch\n\nfrom .classification import ClassificationMetrics\nfrom .get_metrics import get_metrics\n"
  },
  {
    "path": "sequoia/common/metrics/get_metrics.py",
    "chars": 1604,
    "preview": "\"\"\" Defines the get_metrics function with gives back appropriate metrics\nfor the given tensors.\n\nTODO: Add more metrics!"
  },
  {
    "path": "sequoia/common/metrics/metrics.py",
    "chars": 5022,
    "preview": "\"\"\" Cute little dataclass that is used to describe a given type of Metrics.\n\nThis is a bit like the Metrics from pytorch"
  },
  {
    "path": "sequoia/common/metrics/metrics_utils.py",
    "chars": 3312,
    "preview": "\"\"\" Utility functions for calculating metrics. \"\"\"\nfrom typing import Union\n\nimport numpy as np\nimport torch\nfrom torch "
  },
  {
    "path": "sequoia/common/metrics/metrics_utils_test.py",
    "chars": 2604,
    "preview": "import numpy as np\nimport torch\n\nfrom .metrics_utils import accuracy, class_accuracy, get_confusion_matrix\n\n\ndef test_ac"
  },
  {
    "path": "sequoia/common/metrics/regression.py",
    "chars": 4045,
    "preview": "\"\"\" Metrics class for regression.\n\nGives the mean squared error between a prediction Tensor `y_pred` and the\ntarget tens"
  },
  {
    "path": "sequoia/common/metrics/rl_metrics.py",
    "chars": 5342,
    "preview": "from dataclasses import dataclass, field\nfrom typing import Any, Dict, Union\n\nfrom .metrics import Metrics\n\n\n@dataclass\n"
  },
  {
    "path": "sequoia/common/replay.py",
    "chars": 8982,
    "preview": "\"\"\" Labeled, Unlabeled and Semi-supervised Replay buffer objects.\n\nTODO: Unused for now, but could be used in a Lightnin"
  },
  {
    "path": "sequoia/common/spaces/__init__.py",
    "chars": 336,
    "preview": "\"\"\" Custom `gym.spaces.Space` subclasses used by Sequoia. \"\"\"\nfrom .image import Image, ImageTensorSpace\nfrom .named_tup"
  },
  {
    "path": "sequoia/common/spaces/image.py",
    "chars": 5427,
    "preview": "\"\"\" IDEA: Create a subclass of spaces.Box for images.\n\"\"\"\nfrom typing import Optional, Tuple, Union\n\nimport numpy as np\n"
  },
  {
    "path": "sequoia/common/spaces/named_tuple.py",
    "chars": 5220,
    "preview": "\"\"\" IDEA: Subclass of `gym.spaces.Tuple` that yields namedtuples,\nas a bit of a hybrid between `gym.spaces.Dict` and `gy"
  },
  {
    "path": "sequoia/common/spaces/named_tuple_test.py",
    "chars": 4209,
    "preview": "import numpy as np\nimport pytest\nfrom gym import spaces\nfrom gym.spaces import Box, Discrete\nfrom gym.vector.utils impor"
  },
  {
    "path": "sequoia/common/spaces/space.py",
    "chars": 447,
    "preview": "\"\"\" Small typing improvements to the `gym.spaces.Space` class. \"\"\"\nfrom typing import Any, Generic, TypeVar, Union\n\nfrom"
  },
  {
    "path": "sequoia/common/spaces/sparse.py",
    "chars": 14090,
    "preview": "\"\"\" 'wrapper' around a gym.Space that adds has a probability of sampling `None`\ninstead of a sample from the 'base' spac"
  },
  {
    "path": "sequoia/common/spaces/sparse_test.py",
    "chars": 7584,
    "preview": "from typing import Iterable\n\nimport gym\nimport numpy as np\nimport pytest\nfrom gym import spaces\n\nfrom .sparse import Spa"
  },
  {
    "path": "sequoia/common/spaces/tensor_spaces.py",
    "chars": 7323,
    "preview": "\"\"\" TODO: Maybe create a typed version of 'add_tensor_support' of gym_wrappers.convert_tensors\n\"\"\"\nfrom typing import Op"
  },
  {
    "path": "sequoia/common/spaces/tensor_spaces_test.py",
    "chars": 593,
    "preview": "import numpy as np\nimport pytest\nfrom gym import spaces\nfrom torch import Tensor\n\nfrom .tensor_spaces import TensorBox, "
  },
  {
    "path": "sequoia/common/spaces/typed_dict.py",
    "chars": 12853,
    "preview": "\"\"\" Subclass of `spaces.Dict` that allows custom dtypes and uses type annotations.\n\"\"\"\nimport dataclasses\nfrom collectio"
  },
  {
    "path": "sequoia/common/spaces/typed_dict_test.py",
    "chars": 9942,
    "preview": "from dataclasses import Field, dataclass, fields\nfrom typing import Dict, Iterable, Mapping, Tuple, TypeVar\n\nimport gym\n"
  },
  {
    "path": "sequoia/common/task.py",
    "chars": 853,
    "preview": "\"\"\" NOTE: Unused at the moment.\n\nThis defines a `Task` object that is just used to represent the information\nabout a 'Ta"
  },
  {
    "path": "sequoia/common/transforms/__init__.py",
    "chars": 334,
    "preview": "from .channels import (\n    ChannelsFirst,\n    ChannelsFirstIfNeeded,\n    ChannelsLast,\n    ChannelsLastIfNeeded,\n    Th"
  },
  {
    "path": "sequoia/common/transforms/channels.py",
    "chars": 11811,
    "preview": "# from torchvision.transforms import Lambda\nfrom collections.abc import Mapping\nfrom dataclasses import dataclass\nfrom f"
  },
  {
    "path": "sequoia/common/transforms/compose.py",
    "chars": 2879,
    "preview": "from typing import Callable, List, TypeVar\n\nfrom gym import spaces\nfrom torchvision.transforms import Compose as Compose"
  },
  {
    "path": "sequoia/common/transforms/resize.py",
    "chars": 5704,
    "preview": "from collections.abc import Mapping\nfrom functools import singledispatch\nfrom typing import Dict, List, Tuple\n\nimport nu"
  },
  {
    "path": "sequoia/common/transforms/split_batch.py",
    "chars": 9302,
    "preview": "import dataclasses\nfrom typing import Any, Callable, Optional, Tuple, Type, TypeVar\n\nimport numpy as np\nfrom torch impor"
  },
  {
    "path": "sequoia/common/transforms/to_tensor.py",
    "chars": 8448,
    "preview": "\"\"\" Slight modification of the ToTensor transform from TorchVision.\n\n@lebrice: I wrote this because I would often get we"
  },
  {
    "path": "sequoia/common/transforms/transform.py",
    "chars": 888,
    "preview": "\"\"\" Defines a 'smarter' Transform class. \"\"\"\nfrom abc import abstractmethod\nfrom typing import Generic, Tuple, TypeVar, "
  },
  {
    "path": "sequoia/common/transforms/transform_enum.py",
    "chars": 4969,
    "preview": "\"\"\" Transforms and such. Trying to make it possible to parse such from the\ncommand-line.\n\nAlso, playing around with the "
  },
  {
    "path": "sequoia/common/transforms/transforms_test.py",
    "chars": 6962,
    "preview": "from dataclasses import dataclass, field\nfrom typing import List, Tuple\n\nimport gym\nimport numpy as np\nimport pytest\nimp"
  },
  {
    "path": "sequoia/common/transforms/utils.py",
    "chars": 528,
    "preview": "from typing import Any\n\nimport numpy as np\nfrom gym import spaces\nfrom PIL import Image\nfrom torch import Tensor\n\nfrom s"
  },
  {
    "path": "sequoia/common.puml",
    "chars": 676,
    "preview": "@startuml common\n\n!include gym.puml\n\n' class List\n\npackage common {\n    abstract class Batch {}\n\n    package transforms "
  },
  {
    "path": "sequoia/conftest.py",
    "chars": 11379,
    "preview": "import json\nimport logging\nimport sys\nfrom pathlib import Path\nfrom typing import Any, Iterable, List, Optional, Type, g"
  },
  {
    "path": "sequoia/experiments/__init__.py",
    "chars": 121,
    "preview": "\"\"\" Package that defines a list of \"Experiments\".\n\"\"\"\nfrom .experiment import Experiment\nfrom .hpo_sweep import HPOSweep"
  },
  {
    "path": "sequoia/experiments/experiment.py",
    "chars": 19572,
    "preview": "\"\"\" Module used for launching an Experiment (applying a Method to one or more\nSettings).\n\"\"\"\nimport os\nimport shlex\nimpo"
  },
  {
    "path": "sequoia/experiments/experiment_test.py",
    "chars": 4944,
    "preview": "import shlex\nimport sys\nfrom pathlib import Path\nfrom typing import Optional, Type\n\nimport pytest\n\nfrom sequoia.common.c"
  },
  {
    "path": "sequoia/experiments/hpo_sweep.py",
    "chars": 4946,
    "preview": "import json\nimport shlex\nimport sys\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Dict, "
  },
  {
    "path": "sequoia/experiments/hpo_sweep_test.py",
    "chars": 3237,
    "preview": "import random\nimport shlex\nimport sys\nfrom pathlib import Path\nfrom typing import Optional, Type\n\nimport pytest\n\nfrom se"
  },
  {
    "path": "sequoia/main.py",
    "chars": 13972,
    "preview": "\"\"\"Sequoia - The Research Tree \n\nUsed to run experiments, which consist in applying a Method to a Setting.\n\"\"\"\nfrom argp"
  },
  {
    "path": "sequoia/methods/README.md",
    "chars": 11780,
    "preview": "# Sequoia - Methods\n\n### Adding a new Method:\n\n#### Prerequisites:\n**- First, please take a look at the [examples](examp"
  },
  {
    "path": "sequoia/methods/__init__.py",
    "chars": 7185,
    "preview": "\"\"\" Methods: solutions to research problems (Settings).\n\nMethods contain the logic related to the training of the algori"
  },
  {
    "path": "sequoia/methods/aux_tasks/__init__.py",
    "chars": 285,
    "preview": "from .auxiliary_task import AuxiliaryTask\nfrom .ewc import EWCTask\nfrom .reconstruction import AEReconstructionTask, VAE"
  },
  {
    "path": "sequoia/methods/aux_tasks/auxiliary_task.py",
    "chars": 6334,
    "preview": "import typing\nfrom abc import abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Callable, ClassVar, Di"
  },
  {
    "path": "sequoia/methods/aux_tasks/ewc.py",
    "chars": 13866,
    "preview": "\"\"\"Elastic Weight Consolidation as an Auxiliary Task.\n\nThis is a simplified version of EWC, that only currently uses the"
  },
  {
    "path": "sequoia/methods/aux_tasks/reconstruction/__init__.py",
    "chars": 339,
    "preview": "\"\"\" Auxiliary tasks based on reconstructing an input given a hidden vector.\n\nTODO: Add some denoising autoencoders maybe"
  },
  {
    "path": "sequoia/methods/aux_tasks/reconstruction/ae.py",
    "chars": 3027,
    "preview": "\"\"\" Defines an Auto-Encoder-based Auxiliary task.\n\"\"\"\nfrom typing import ClassVar, Dict, Optional, Tuple, Union\n\nimport "
  },
  {
    "path": "sequoia/methods/aux_tasks/reconstruction/decoder_for_dataset.py",
    "chars": 1014,
    "preview": "from typing import Dict, Tuple, Type, Union\n\nfrom torch import nn\n\nfrom .decoders import CifarDecoder, ImageNetDecoder, "
  },
  {
    "path": "sequoia/methods/aux_tasks/reconstruction/decoders.py",
    "chars": 2290,
    "preview": "from abc import ABC\nfrom typing import Tuple\n\nfrom torch import nn\n\nfrom sequoia.common.layers import DeConvBlock, Resha"
  },
  {
    "path": "sequoia/methods/aux_tasks/reconstruction/vae.py",
    "chars": 3192,
    "preview": "from dataclasses import dataclass\nfrom typing import ClassVar, Dict\n\nimport torch\nfrom torch import Tensor, nn\n\nfrom seq"
  },
  {
    "path": "sequoia/methods/aux_tasks/transformation_based/__init__.py",
    "chars": 133,
    "preview": "from .bases import ClassifyTransformationTask, RegressTransformationTask, TransformationBasedTask\nfrom .rotation import "
  },
  {
    "path": "sequoia/methods/aux_tasks/transformation_based/bases.py",
    "chars": 9764,
    "preview": "from dataclasses import dataclass\nfrom functools import wraps\nfrom typing import Any, Callable, List, Tuple\n\nimport torc"
  },
  {
    "path": "sequoia/methods/aux_tasks/transformation_based/rotation.py",
    "chars": 2210,
    "preview": "from dataclasses import dataclass\n\nfrom torch import Tensor\n\nfrom .bases import ClassifyTransformationTask\n\n\ndef rotate("
  },
  {
    "path": "sequoia/methods/avalanche_methods/__init__.py",
    "chars": 618,
    "preview": "\"\"\" Adapters for Avalanche Strategies, so they can be used as Methods in Sequoia.\n\nSee the Avalanche repo for more info:"
  },
  {
    "path": "sequoia/methods/avalanche_methods/agem.py",
    "chars": 1611,
    "preview": "\"\"\" Method based on AGEM from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.plugins.ag"
  },
  {
    "path": "sequoia/methods/avalanche_methods/agem_test.py",
    "chars": 355,
    "preview": "\"\"\" WIP: Tests for the AGEM Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom typin"
  },
  {
    "path": "sequoia/methods/avalanche_methods/ar1.py",
    "chars": 2998,
    "preview": "\"\"\" Method based on AR1 from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.strategies."
  },
  {
    "path": "sequoia/methods/avalanche_methods/ar1_test.py",
    "chars": 1662,
    "preview": "\"\"\" WIP: Tests for the AR1 Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom typing"
  },
  {
    "path": "sequoia/methods/avalanche_methods/base.py",
    "chars": 21755,
    "preview": "\"\"\" Adapter for the `BaseStrategy` from Avalanche, wrapping it up into a Sequoia Method.\n\nSee the Avalanche repo for mor"
  },
  {
    "path": "sequoia/methods/avalanche_methods/base_test.py",
    "chars": 7605,
    "preview": "import inspect\nfrom inspect import Signature, _empty, getsourcefile\nfrom typing import ClassVar, List, Optional, Type\n\ni"
  },
  {
    "path": "sequoia/methods/avalanche_methods/conftest.py",
    "chars": 1805,
    "preview": "from pathlib import Path\n\nimport pytest\nimport torch\nfrom sklearn.datasets import make_classification\nfrom sklearn.model"
  },
  {
    "path": "sequoia/methods/avalanche_methods/cwr_star.py",
    "chars": 1477,
    "preview": "\"\"\" Method based on CWRStar from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.plugins"
  },
  {
    "path": "sequoia/methods/avalanche_methods/cwr_star_test.py",
    "chars": 371,
    "preview": "\"\"\" WIP: Tests for the CWRStar Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom ty"
  },
  {
    "path": "sequoia/methods/avalanche_methods/ewc.py",
    "chars": 2971,
    "preview": "\"\"\" Method based on EWC from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.plugins.ewc"
  },
  {
    "path": "sequoia/methods/avalanche_methods/ewc_test.py",
    "chars": 6275,
    "preview": "\"\"\" WIP: Tests for the EWC Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom typing"
  },
  {
    "path": "sequoia/methods/avalanche_methods/experience.py",
    "chars": 6337,
    "preview": "\"\"\" 'Wrapper' around a PassiveEnvironment from Sequoia, disguising it as an 'Experience'\nfrom Avalanche.\n\"\"\"\nfrom typing"
  },
  {
    "path": "sequoia/methods/avalanche_methods/gdumb.py",
    "chars": 7111,
    "preview": "\"\"\" Method based on GDumb from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.plugins.g"
  },
  {
    "path": "sequoia/methods/avalanche_methods/gdumb_test.py",
    "chars": 360,
    "preview": "\"\"\" WIP: Tests for the GDumb Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom typi"
  },
  {
    "path": "sequoia/methods/avalanche_methods/gem.py",
    "chars": 1627,
    "preview": "\"\"\" Method based on GEM from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.plugins.gem"
  },
  {
    "path": "sequoia/methods/avalanche_methods/gem_test.py",
    "chars": 350,
    "preview": "\"\"\" WIP: Tests for the GEM Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom typing"
  },
  {
    "path": "sequoia/methods/avalanche_methods/lwf.py",
    "chars": 3742,
    "preview": "\"\"\" Method based on LwF from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.plugins.lwf"
  },
  {
    "path": "sequoia/methods/avalanche_methods/lwf_test.py",
    "chars": 350,
    "preview": "\"\"\" WIP: Tests for the LwF Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom typing"
  },
  {
    "path": "sequoia/methods/avalanche_methods/naive.py",
    "chars": 1197,
    "preview": "\"\"\" 'Naive' method from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.strategies.Naive"
  },
  {
    "path": "sequoia/methods/avalanche_methods/naive_test.py",
    "chars": 360,
    "preview": "\"\"\" WIP: Tests for the Naive Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom typi"
  },
  {
    "path": "sequoia/methods/avalanche_methods/patched_models.py",
    "chars": 10424,
    "preview": "\"\"\" Patch for the multi-task models in Avalanche, so that we can evaluate on future\ntasks, by selecting random predictio"
  },
  {
    "path": "sequoia/methods/avalanche_methods/plugins.py",
    "chars": 4120,
    "preview": "\"\"\" WIP: @lebrice: Plugins that I was using while trying to get the BaseStrategy and\nplugins from Avalanche to work dire"
  },
  {
    "path": "sequoia/methods/avalanche_methods/replay.py",
    "chars": 4674,
    "preview": "\"\"\" Method based on Replay from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.training.plugins."
  },
  {
    "path": "sequoia/methods/avalanche_methods/replay_test.py",
    "chars": 365,
    "preview": "\"\"\" WIP: Tests for the Replay Method.\n\nFor now this only inherits the tests from the AvalancheMethod class.\n\"\"\"\nfrom typ"
  },
  {
    "path": "sequoia/methods/avalanche_methods/synaptic_intelligence.py",
    "chars": 12840,
    "preview": "\"\"\" Method based on SynapticIntelligence from [Avalanche](https://github.com/ContinualAI/avalanche).\n\nSee `avalanche.tra"
  }
]

// ... and 260 more files (download for full content)

About this extraction

This page contains the full source code of the lebrice/Sequoia GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 460 files (2.6 MB), approximately 715.6k tokens, and a symbol index with 3006 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!