SYMBOL INDEX (1721 symbols across 145 files) FILE: dopamine/colab/utils.py function load_baselines (line 99) | def load_baselines(base_dir, verbose=False): function load_statistics (line 149) | def load_statistics(log_path, iteration_number=None, verbose=True): function get_latest_file (line 180) | def get_latest_file(path): function get_latest_iteration (line 196) | def get_latest_iteration(path): function summarize_data (line 221) | def summarize_data(data, summary_keys): function read_experiment (line 256) | def read_experiment( FILE: dopamine/continuous_domains/run_experiment.py function create_continuous_agent (line 43) | def create_continuous_agent( function create_continuous_runner (line 112) | def create_continuous_runner(base_dir, schedule='continuous_train_and_ev... class ContinuousRunner (line 137) | class ContinuousRunner(base_run_experiment.Runner): method __init__ (line 144) | def __init__( method _use_legacy_logger (line 216) | def _use_legacy_logger(self): method _has_collector_dispatcher (line 222) | def _has_collector_dispatcher(self): method _fine_grained_print_to_console (line 228) | def _fine_grained_print_to_console(self): method _save_tensorboard_summaries (line 233) | def _save_tensorboard_summaries( class ContinuousTrainRunner (line 265) | class ContinuousTrainRunner(ContinuousRunner): method __init__ (line 272) | def __init__( method _run_one_iteration (line 291) | def _run_one_iteration(self, iteration): method _save_tensorboard_summaries (line 330) | def _save_tensorboard_summaries( FILE: dopamine/continuous_domains/train.py function main (line 43) | def main(unused_argv): FILE: dopamine/discrete_domains/atari_lib.py function create_atari_environment (line 71) | def create_atari_environment( function maybe_transform_variable_names (line 152) | def maybe_transform_variable_names(variables, legacy_checkpoint_load=Fal... class NatureDQNNetwork (line 179) | class NatureDQNNetwork(tf.keras.Model): method __init__ (line 182) | def __init__(self, num_actions, name=None): method call (line 226) | def call(self, state): class RainbowNetwork (line 252) | class RainbowNetwork(tf.keras.Model): method __init__ (line 255) | def __init__(self, num_actions, num_atoms, support, name=None): method call (line 316) | def call(self, state): class ImplicitQuantileNetwork (line 343) | class ImplicitQuantileNetwork(tf.keras.Model): method __init__ (line 346) | def __init__(self, num_actions, quantile_embedding_dim, name=None): method call (line 407) | def call(self, state, num_quantiles): class AtariPreprocessing (line 459) | class AtariPreprocessing(object): method __init__ (line 475) | def __init__( method observation_space (line 523) | def observation_space(self): method action_space (line 534) | def action_space(self): method reward_range (line 538) | def reward_range(self): method metadata (line 542) | def metadata(self): method close (line 545) | def close(self): method reset (line 548) | def reset(self): method render (line 561) | def render(self, mode): method step (line 577) | def step(self, action): method _fetch_grayscale_observation (line 631) | def _fetch_grayscale_observation(self, output): method _pool_and_resize (line 645) | def _pool_and_resize(self): class GameOverWrapper (line 671) | class GameOverWrapper(object): method __init__ (line 674) | def __init__(self, environment): method observation_space (line 679) | def observation_space(self): method action_space (line 683) | def action_space(self): method reset (line 686) | def reset(self): method step (line 689) | def step(self, action): FILE: dopamine/discrete_domains/checkpointer.py function get_latest_checkpoint_number (line 59) | def get_latest_checkpoint_number( class Checkpointer (line 95) | class Checkpointer(object): method __init__ (line 98) | def __init__( method _generate_filename (line 139) | def _generate_filename(self, file_prefix, iteration_number): method _save_data_to_file (line 144) | def _save_data_to_file(self, data, filename): method save_checkpoint (line 149) | def save_checkpoint(self, iteration_number, data): method _clean_up_old_checkpoints (line 172) | def _clean_up_old_checkpoints(self, iteration_number): method _load_data_from_file (line 202) | def _load_data_from_file(self, filename): method load_checkpoint (line 208) | def load_checkpoint(self, iteration_number): FILE: dopamine/discrete_domains/gym_lib.py function create_gym_environment (line 60) | def create_gym_environment( class BasicDiscreteDomainNetwork (line 105) | class BasicDiscreteDomainNetwork(tf.keras.layers.Layer): method __init__ (line 122) | def __init__( method call (line 153) | def call(self, state): class CartpoleDQNNetwork (line 168) | class CartpoleDQNNetwork(tf.keras.Model): method __init__ (line 171) | def __init__(self, num_actions, name=None): method call (line 185) | def call(self, state): class FourierBasis (line 191) | class FourierBasis(object): method __init__ (line 205) | def __init__(self, nvars, min_vals=0, max_vals=None, order=3): method scale (line 216) | def scale(self, values): method compute_features (line 223) | def compute_features(self, features): class FourierDQNNetwork (line 230) | class FourierDQNNetwork(tf.keras.Model): method __init__ (line 233) | def __init__( method call (line 264) | def call(self, state): class CartpoleFourierDQNNetwork (line 283) | class CartpoleFourierDQNNetwork(FourierDQNNetwork): method __init__ (line 286) | def __init__(self, num_actions, name=None): class CartpoleRainbowNetwork (line 301) | class CartpoleRainbowNetwork(tf.keras.Model): method __init__ (line 304) | def __init__(self, num_actions, num_atoms, support, name=None): method call (line 323) | def call(self, state): class AcrobotDQNNetwork (line 332) | class AcrobotDQNNetwork(tf.keras.Model): method __init__ (line 335) | def __init__(self, num_actions, name=None): method call (line 349) | def call(self, state): class AcrobotFourierDQNNetwork (line 355) | class AcrobotFourierDQNNetwork(FourierDQNNetwork): method __init__ (line 358) | def __init__(self, num_actions, name=None): class AcrobotRainbowNetwork (line 374) | class AcrobotRainbowNetwork(tf.keras.Model): method __init__ (line 377) | def __init__(self, num_actions, num_atoms, support, name=None): method call (line 396) | def call(self, state): class LunarLanderDQNNetwork (line 405) | class LunarLanderDQNNetwork(tf.keras.Model): method __init__ (line 408) | def __init__(self, num_actions, name=None): method call (line 418) | def call(self, state): class MountainCarDQNNetwork (line 425) | class MountainCarDQNNetwork(tf.keras.Model): method __init__ (line 428) | def __init__(self, num_actions, name=None): method call (line 440) | def call(self, state): class GymPreprocessing (line 447) | class GymPreprocessing(object): method __init__ (line 450) | def __init__(self, environment, use_legacy_gym=False): method observation_space (line 456) | def observation_space(self): method action_space (line 460) | def action_space(self): method reward_range (line 464) | def reward_range(self): method metadata (line 468) | def metadata(self): method reset (line 471) | def reset(self): method step (line 478) | def step(self, action): FILE: dopamine/discrete_domains/iteration_statistics.py class IterationStatistics (line 22) | class IterationStatistics(object): method __init__ (line 36) | def __init__(self): method append (line 39) | def append(self, data_pairs): FILE: dopamine/discrete_domains/legacy_networks.py function nature_dqn_network (line 25) | def nature_dqn_network(num_actions, network_type, state): function rainbow_network (line 41) | def rainbow_network(num_actions, num_atoms, support, network_type, state): function implicit_quantile_network (line 59) | def implicit_quantile_network( function _basic_discrete_domain_network (line 85) | def _basic_discrete_domain_network( function cartpole_dqn_network (line 108) | def cartpole_dqn_network(num_actions, network_type, state): function fourier_dqn_network (line 127) | def fourier_dqn_network( function cartpole_fourier_dqn_network (line 150) | def cartpole_fourier_dqn_network(num_actions, network_type, state): function cartpole_rainbow_network (line 169) | def cartpole_rainbow_network( function acrobot_dqn_network (line 190) | def acrobot_dqn_network(num_actions, network_type, state): function acrobot_fourier_dqn_network (line 209) | def acrobot_fourier_dqn_network(num_actions, network_type, state): function acrobot_rainbow_network (line 228) | def acrobot_rainbow_network( FILE: dopamine/discrete_domains/logger.py class Logger (line 30) | class Logger(object): method __init__ (line 33) | def __init__(self, logging_dir, logs_duration=4): method __setitem__ (line 64) | def __setitem__(self, key, value): method _generate_filename (line 76) | def _generate_filename(self, filename_prefix, iteration_number): method log_to_file (line 80) | def log_to_file(self, filename_prefix, iteration_number): method is_logging_enabled (line 107) | def is_logging_enabled(self): FILE: dopamine/discrete_domains/run_experiment.py function load_gin_configs (line 47) | def load_gin_configs(gin_files, gin_bindings): function create_agent (line 62) | def create_agent( function create_runner (line 136) | def create_runner(base_dir, schedule='continuous_train_and_eval'): class Runner (line 162) | class Runner(object): method __init__ (line 181) | def __init__( method _use_legacy_logger (line 280) | def _use_legacy_logger(self): method _has_collector_dispatcher (line 286) | def _has_collector_dispatcher(self): method _fine_grained_print_to_console (line 292) | def _fine_grained_print_to_console(self): method _create_directories (line 297) | def _create_directories(self): method _initialize_checkpointer_and_maybe_resume (line 307) | def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_pr... method _initialize_episode (line 355) | def _initialize_episode(self): method _run_one_step (line 364) | def _run_one_step(self, action): method _end_episode (line 377) | def _end_episode(self, reward, terminal=True): method _run_one_episode (line 390) | def _run_one_episode(self): method _run_continued_episode (line 430) | def _run_continued_episode(self, start_step_count, max_step_count): method _run_one_phase (line 470) | def _run_one_phase(self, min_steps, statistics, run_mode_str): method _run_train_phase (line 515) | def _run_train_phase(self, statistics): method _run_eval_phase (line 548) | def _run_eval_phase(self, statistics): method _run_one_iteration (line 572) | def _run_one_iteration(self, iteration): method _save_tensorboard_summaries (line 622) | def _save_tensorboard_summaries( method _log_experiment (line 685) | def _log_experiment(self, iteration, statistics): method _checkpoint_experiment (line 699) | def _checkpoint_experiment(self, iteration): method run_experiment (line 714) | def run_experiment(self): class TrainRunner (line 743) | class TrainRunner(Runner): method __init__ (line 751) | def __init__( method _run_one_iteration (line 772) | def _run_one_iteration(self, iteration): method _save_tensorboard_summaries (line 811) | def _save_tensorboard_summaries( FILE: dopamine/discrete_domains/train.py function main (line 50) | def main(unused_argv): FILE: dopamine/jax/agents/dqn/dqn_agent.py function identity_epsilon (line 46) | def identity_epsilon( function create_optimizer (line 53) | def create_optimizer( function train (line 111) | def train( function target_q (line 151) | def target_q(target_network, next_states, rewards, terminals, cumulative... function linearly_decaying_epsilon (line 170) | def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon): function select_action (line 195) | def select_action( class JaxDQNAgent (line 255) | class JaxDQNAgent(object): method __init__ (line 258) | def __init__( method _build_networks_and_optimizer (line 400) | def _build_networks_and_optimizer(self): method _build_replay_buffer (line 408) | def _build_replay_buffer(self): method _sample_from_replay_buffer (line 423) | def _sample_from_replay_buffer(self): method _sync_weights (line 436) | def _sync_weights(self): method _reset_state (line 440) | def _reset_state(self): method _record_observation (line 444) | def _record_observation(self, observation): method begin_episode (line 460) | def begin_episode(self, observation): method step (line 492) | def step(self, reward, observation): method end_episode (line 529) | def end_episode(self, reward, terminal=True): method _train_step (line 551) | def _train_step(self): method _store_transition (line 604) | def _store_transition( method bundle_and_checkpoint (line 661) | def bundle_and_checkpoint(self, checkpoint_dir, iteration_number): method unbundle (line 690) | def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary): method set_collector_dispatcher (line 733) | def set_collector_dispatcher(self, collector_dispatcher): FILE: dopamine/jax/agents/full_rainbow/full_rainbow_agent.py function zero_epsilon (line 51) | def zero_epsilon( function select_action (line 58) | def select_action( function get_logits (line 98) | def get_logits(model, states, rng): function get_q_values (line 103) | def get_q_values(model, states, rng): function train (line 118) | def train( function target_output (line 201) | def target_output( class JaxFullRainbowAgent (line 247) | class JaxFullRainbowAgent(dqn_agent.JaxDQNAgent): method __init__ (line 250) | def __init__( method _build_networks_and_optimizer (line 340) | def _build_networks_and_optimizer(self): method _build_replay_buffer (line 350) | def _build_replay_buffer(self): method _training_step_update (line 368) | def _training_step_update(self): method _store_transition (line 443) | def _store_transition( method _train_step (line 481) | def _train_step(self): method begin_episode (line 501) | def begin_episode(self, observation): method step (line 528) | def step(self, reward, observation): FILE: dopamine/jax/agents/implicit_quantile/implicit_quantile_agent.py function target_quantile_values (line 39) | def target_quantile_values( function train (line 106) | def train( function select_action (line 209) | def select_action( class JaxImplicitQuantileAgent (line 279) | class JaxImplicitQuantileAgent(dqn_agent.JaxDQNAgent): method __init__ (line 282) | def __init__( method _build_networks_and_optimizer (line 394) | def _build_networks_and_optimizer(self): method begin_episode (line 403) | def begin_episode(self, observation): method step (line 436) | def step(self, reward, observation): method _train_step (line 474) | def _train_step(self): FILE: dopamine/jax/agents/ppo/ppo_agent.py function train (line 45) | def train( function calculate_advantages_and_returns (line 192) | def calculate_advantages_and_returns( function create_minibatches_and_shuffle (line 222) | def create_minibatches_and_shuffle( function train_minibatch (line 279) | def train_minibatch( function select_action (line 391) | def select_action(network_def, params, state, rng): class PPOAgent (line 416) | class PPOAgent(dqn_agent.JaxDQNAgent): method __init__ (line 419) | def __init__( method _build_networks_and_optimizer (line 568) | def _build_networks_and_optimizer(self): method _build_replay_buffer (line 582) | def _build_replay_buffer(self): method begin_episode (line 597) | def begin_episode(self, observation): method step (line 621) | def step(self, reward, observation): method _train_step (line 650) | def _train_step(self): method bundle_and_checkpoint (line 707) | def bundle_and_checkpoint(self, checkpoint_dir, iteration_number): method unbundle (line 735) | def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary): FILE: dopamine/jax/agents/quantile/quantile_agent.py function target_distribution (line 38) | def target_distribution( function train (line 65) | def train( class JaxQuantileAgent (line 131) | class JaxQuantileAgent(dqn_agent.JaxDQNAgent): method __init__ (line 134) | def __init__( method _build_networks_and_optimizer (line 225) | def _build_networks_and_optimizer(self): method _build_replay_buffer (line 232) | def _build_replay_buffer(self): method _train_step (line 250) | def _train_step(self): FILE: dopamine/jax/agents/rainbow/rainbow_agent.py function train (line 54) | def train( function target_distribution (line 103) | def target_distribution( function select_action (line 146) | def select_action( class JaxRainbowAgent (line 208) | class JaxRainbowAgent(dqn_agent.JaxDQNAgent): method __init__ (line 211) | def __init__( method _build_networks_and_optimizer (line 308) | def _build_networks_and_optimizer(self): method _build_replay_buffer (line 318) | def _build_replay_buffer(self): method begin_episode (line 339) | def begin_episode(self, observation): method step (line 373) | def step(self, reward, observation): method _train_step (line 411) | def _train_step(self): function project_distribution (line 492) | def project_distribution(supports, weights, target_support): FILE: dopamine/jax/agents/sac/sac_agent.py function train (line 72) | def train( function select_action (line 262) | def select_action(network_def, params, state, rng, eval_mode=False): class SACAgent (line 287) | class SACAgent(dqn_agent.JaxDQNAgent): method __init__ (line 290) | def __init__( method _build_networks_and_optimizer (line 438) | def _build_networks_and_optimizer(self): method _build_replay_buffer (line 455) | def _build_replay_buffer(self): method _maybe_sync_weights (line 470) | def _maybe_sync_weights(self): method begin_episode (line 488) | def begin_episode(self, observation): method step (line 523) | def step(self, reward, observation): method _train_step (line 563) | def _train_step(self): method bundle_and_checkpoint (line 631) | def bundle_and_checkpoint(self, checkpoint_dir, iteration_number): method unbundle (line 662) | def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary): FILE: dopamine/jax/checkpointers.py class Checkpointable (line 71) | class Checkpointable(Protocol): method to_state_dict (line 74) | def to_state_dict(self) -> Dict[str, Any]: method from_state_dict (line 77) | def from_state_dict(self, state_dict: Dict[str, Any]) -> None: class CheckpointHandler (line 84) | class CheckpointHandler(checkpoint.CheckpointHandler, Generic[Checkpoint... method __init__ (line 87) | def __init__(self, filename: str = 'checkpoint.msgpack') -> None: method save (line 90) | def save(self, directory: epath.Path, item: CheckpointableT) -> None: method restore (line 106) | def restore( method restore (line 112) | def restore(self, directory: epath.Path, item: None = None) -> Dict[st... method restore (line 115) | def restore( method structure (line 133) | def structure(self, directory: epath.Path) -> None: FILE: dopamine/jax/continuous_networks.py class ActorOutput (line 32) | class ActorOutput(NamedTuple): class CriticOutput (line 40) | class CriticOutput(NamedTuple): class ActorCriticOutput (line 47) | class ActorCriticOutput(NamedTuple): class PPOActorOutput (line 54) | class PPOActorOutput(NamedTuple): class PPOCriticOutput (line 62) | class PPOCriticOutput(NamedTuple): class PPOActorCriticOutput (line 68) | class PPOActorCriticOutput(NamedTuple): class _Tanh (line 75) | class _Tanh(tfb.Tanh): method _inverse (line 77) | def _inverse(self, y): function _transform_distribution (line 85) | def _transform_distribution(dist, mean, magnitude): function _shifted_uniform (line 104) | def _shifted_uniform(minval=0.0, maxval=1.0, dtype=jnp.float32): function create_activation (line 114) | def create_activation( class ActorNetwork (line 126) | class ActorNetwork(nn.Module): method __call__ (line 139) | def __call__( class CriticNetwork (line 187) | class CriticNetwork(nn.Module): method __call__ (line 198) | def __call__(self, state: jnp.ndarray, action: jnp.ndarray) -> jnp.nda... class ActorCriticNetwork (line 216) | class ActorCriticNetwork(nn.Module): method setup (line 228) | def setup(self): method __call__ (line 250) | def __call__( method actor (line 279) | def actor(self, state: jnp.ndarray, key: jnp.ndarray) -> ActorOutput: method critic (line 295) | def critic(self, state: jnp.ndarray, action: jnp.ndarray) -> CriticOut... class PPOActorNetwork (line 314) | class PPOActorNetwork(nn.Module): method __call__ (line 331) | def __call__( class PPOCriticNetwork (line 383) | class PPOCriticNetwork(nn.Module): method __call__ (line 397) | def __call__(self, state: jnp.ndarray) -> jnp.ndarray: class PPOActorCriticNetwork (line 412) | class PPOActorCriticNetwork(nn.Module): method setup (line 428) | def setup(self): method __call__ (line 440) | def __call__( method actor (line 462) | def actor( method critic (line 484) | def critic(self, state: jnp.ndarray) -> PPOCriticOutput: FILE: dopamine/jax/losses.py function huber_loss (line 20) | def huber_loss( function mse_loss (line 41) | def mse_loss(targets: jnp.ndarray, predictions: jnp.ndarray) -> jnp.ndar... function softmax_cross_entropy_loss_with_logits (line 46) | def softmax_cross_entropy_loss_with_logits( FILE: dopamine/jax/networks.py function preprocess_atari_inputs (line 53) | def preprocess_atari_inputs(x): class Stack (line 62) | class Stack(nn.Module): method __call__ (line 70) | def __call__(self, x): class ImpalaEncoder (line 100) | class ImpalaEncoder(nn.Module): method setup (line 107) | def setup(self): method __call__ (line 119) | def __call__(self, x): class ImpalaDQNNetwork (line 127) | class ImpalaDQNNetwork(nn.Module): method setup (line 134) | def setup(self): method __call__ (line 138) | def __call__(self, x): class NatureDQNNetwork (line 152) | class NatureDQNNetwork(nn.Module): method __call__ (line 159) | def __call__(self, x): class ClassicControlDQNNetwork (line 183) | class ClassicControlDQNNetwork(nn.Module): method setup (line 193) | def setup(self): method __call__ (line 207) | def __call__(self, x): class FourierBasis (line 222) | class FourierBasis(object): method __init__ (line 236) | def __init__( method scale (line 257) | def scale(self, values): method compute_features (line 264) | def compute_features(self, features): class JaxFourierDQNNetwork (line 271) | class JaxFourierDQNNetwork(nn.Module): method __call__ (line 280) | def __call__(self, x): class RainbowNetwork (line 298) | class RainbowNetwork(nn.Module): method __call__ (line 306) | def __call__(self, x, support): class ClassicControlRainbowNetwork (line 337) | class ClassicControlRainbowNetwork(nn.Module): method setup (line 348) | def setup(self): method __call__ (line 361) | def __call__(self, x, support): class ImplicitQuantileNetwork (line 380) | class ImplicitQuantileNetwork(nn.Module): method __call__ (line 388) | def __call__(self, x, num_quantiles, rng): class QuantileNetwork (line 433) | class QuantileNetwork(nn.Module): method __call__ (line 441) | def __call__(self, x): class NoisyNetwork (line 473) | class NoisyNetwork(nn.Module): method sample_noise (line 485) | def sample_noise(key, shape): method f (line 489) | def f(x): method __call__ (line 494) | def __call__(self, x, features, bias=True, kernel_init=None): function feature_layer (line 532) | def feature_layer(key, noisy, eval_mode=False): class FullRainbowNetwork (line 545) | class FullRainbowNetwork(nn.Module): method __call__ (line 564) | def __call__(self, x, support, eval_mode=False, key=None): class PPOSharedNetwork (line 610) | class PPOSharedNetwork(nn.Module): method __call__ (line 616) | def __call__(self, x) -> jnp.ndarray: class PPOActorNetwork (line 651) | class PPOActorNetwork(nn.Module): method __call__ (line 657) | def __call__(self, x) -> jnp.ndarray: class PPOCriticNetwork (line 662) | class PPOCriticNetwork(nn.Module): method __call__ (line 666) | def __call__(self, x) -> jnp.ndarray: class PPODiscreteActorCriticNetwork (line 672) | class PPODiscreteActorCriticNetwork(nn.Module): method setup (line 678) | def setup(self): method __call__ (line 684) | def __call__( method actor (line 691) | def actor( method critic (line 707) | def critic(self, state: jnp.ndarray) -> continuous_networks.PPOCriticO... FILE: dopamine/jax/replay_memory/accumulator.py class Accumulator (line 32) | class Accumulator(checkpointers.Checkpointable, Protocol[ReplayElementT]): method accumulate (line 34) | def accumulate( method clear (line 39) | def clear(self) -> None: class TransitionAccumulator (line 43) | class TransitionAccumulator(Accumulator[elements.ReplayElement]): method __init__ (line 53) | def __init__( method _make_replay_element (line 70) | def _make_replay_element(self) -> 'elements.ReplayElement | None': method accumulate (line 152) | def accumulate( method clear (line 188) | def clear(self) -> None: method to_state_dict (line 192) | def to_state_dict(self) -> dict[str, Any]: method from_state_dict (line 202) | def from_state_dict(self, state_dict: dict[str, Any]) -> None: FILE: dopamine/jax/replay_memory/elements.py class TransitionElement (line 28) | class TransitionElement(typing.NamedTuple): class ReplayElementProtocol (line 37) | class ReplayElementProtocol(Protocol): method pack (line 39) | def pack(self) -> 'ReplayElementProtocol': method unpack (line 42) | def unpack(self) -> 'ReplayElementProtocol': method is_compressed (line 46) | def is_compressed(self) -> bool: function compress (line 50) | def compress(buffer: npt.NDArray) -> npt.NDArray: function uncompress (line 77) | def uncompress(compressed: npt.NDArray) -> npt.NDArray: class ReplayElement (line 97) | class ReplayElement(ReplayElementProtocol, struct.PyTreeNode): method pack (line 107) | def pack(self) -> 'ReplayElement': method unpack (line 117) | def unpack(self) -> 'ReplayElement': method is_compressed (line 128) | def is_compressed(self) -> bool: FILE: dopamine/jax/replay_memory/replay_buffer.py class ReplayBuffer (line 38) | class ReplayBuffer(checkpointers.Checkpointable, Generic[ReplayElementT]): method __init__ (line 65) | def __init__( method add (line 90) | def add(self, transition: elements.TransitionElement, **kwargs: Any) -... method sample (line 109) | def sample( method sample (line 118) | def sample( method sample (line 126) | def sample( method update (line 155) | def update(self, keys: npt.NDArray[ReplayItemID], **kwargs: Any) -> None: method update (line 159) | def update(self, keys: ReplayItemID, **kwargs: Any) -> None: method update (line 162) | def update( method clear (line 169) | def clear(self) -> None: method to_state_dict (line 176) | def to_state_dict(self) -> dict[str, Any]: method from_state_dict (line 198) | def from_state_dict(self, state_dict: dict[str, Any]) -> None: method _make_checkpoint_manager (line 228) | def _make_checkpoint_manager( method save (line 245) | def save(self, checkpoint_dir: str, iteration_number: int): method load (line 256) | def load(self, checkpoint_dir: str, iteration_number: int): FILE: dopamine/jax/replay_memory/samplers.py class SampleMetadata (line 33) | class SampleMetadata: class SamplingDistribution (line 38) | class SamplingDistribution(checkpointers.Checkpointable, Protocol): method add (line 41) | def add(self, key: ReplayItemID, **kwargs: Any) -> None: method update (line 45) | def update(self, keys: npt.NDArray[ReplayItemID], **kwargs: Any) -> None: method update (line 49) | def update(self, key: ReplayItemID, **kwargs: Any) -> None: method update (line 52) | def update( method remove (line 57) | def remove(self, key: ReplayItemID) -> None: method sample (line 60) | def sample(self, size: int) -> SampleMetadata: method clear (line 63) | def clear(self) -> None: class UniformSamplingDistribution (line 67) | class UniformSamplingDistribution(SamplingDistribution): method __init__ (line 70) | def __init__( method size (line 99) | def size(self) -> int: method add (line 102) | def add(self, key: ReplayItemID, **kwargs: Any) -> None: method update (line 115) | def update( method update (line 121) | def update(self, key: ReplayItemID, *args: Any, **kwargs: Any) -> None: method update (line 124) | def update( method remove (line 134) | def remove(self, key: ReplayItemID) -> None: method sample (line 150) | def sample(self, size: int) -> SampleMetadata: method clear (line 164) | def clear(self) -> None: method to_state_dict (line 168) | def to_state_dict(self) -> dict[str, Any]: method from_state_dict (line 175) | def from_state_dict(self, state_dict: dict[str, Any]) -> None: class PrioritizedSampleMetadata (line 184) | class PrioritizedSampleMetadata(SampleMetadata): class PrioritizedSamplingDistribution (line 189) | class PrioritizedSamplingDistribution(UniformSamplingDistribution): method __init__ (line 192) | def __init__( method add (line 206) | def add(self, key: ReplayItemID, *, priority: float) -> None: method update (line 216) | def update( method update (line 225) | def update(self, keys: ReplayItemID, *, priorities: float) -> None: method update (line 228) | def update( method remove (line 245) | def remove(self, key: ReplayItemID) -> None: method sample (line 261) | def sample(self, size: int) -> PrioritizedSampleMetadata: method clear (line 280) | def clear(self) -> None: method to_state_dict (line 284) | def to_state_dict(self) -> dict[str, Any]: method from_state_dict (line 290) | def from_state_dict(self, state_dict: dict[str, Any]): class SequentialSamplingDistribution (line 295) | class SequentialSamplingDistribution(UniformSamplingDistribution): method __init__ (line 298) | def __init__( method sample (line 306) | def sample(self, size: int) -> SampleMetadata: FILE: dopamine/jax/replay_memory/sum_tree.py class SumTree (line 25) | class SumTree(checkpointers.Checkpointable): method __init__ (line 28) | def __init__(self, capacity: int) -> None: method set (line 38) | def set(self, index: int, value: float) -> None: method set (line 42) | def set( method set (line 47) | def set( method get (line 81) | def get(self, index: int) -> float: method get (line 85) | def get(self, index: npt.NDArray[np.int_]) -> npt.NDArray[np.float64]: method get (line 88) | def get( method root (line 95) | def root(self) -> float: method query (line 100) | def query(self, target: float) -> int: method query (line 104) | def query(self, target: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]: method query (line 107) | def query( method clear (line 155) | def clear(self) -> None: method to_state_dict (line 158) | def to_state_dict(self) -> dict[str, Any]: method from_state_dict (line 161) | def from_state_dict(self, state_dict: dict[str, Any]): FILE: dopamine/jax/serialization.py class NumpyEncoding (line 35) | class NumpyEncoding(TypedDict, total=False): class LongIntegerEncoding (line 44) | class LongIntegerEncoding(TypedDict): function encode (line 51) | def encode(obj: Any, chain: Optional[Callable[[Any], Any]] = None) -> Any: function _ (line 59) | def _( function _ (line 84) | def _( function decode (line 98) | def decode( function decode (line 105) | def decode( function decode (line 111) | def decode(obj: Any, chain: Optional[Callable[[Any], Any]] = None) -> Any: FILE: dopamine/labs/atari_100k/atari_100k_rainbow_agent.py function select_action (line 33) | def select_action( function _crop_with_indices (line 81) | def _crop_with_indices(img, x, y, cropped_shape): function _per_image_random_crop (line 86) | def _per_image_random_crop(key, img, cropped_shape): function _intensity_aug (line 99) | def _intensity_aug(key, x, scale=0.05): function drq_image_augmentation (line 107) | def drq_image_augmentation(key, obs, img_pad=4): function preprocess_inputs_with_augmentation (line 122) | def preprocess_inputs_with_augmentation(x, data_augmentation=False, rng=... class Atari100kRainbowAgent (line 133) | class Atari100kRainbowAgent(full_rainbow_agent.JaxFullRainbowAgent): method __init__ (line 136) | def __init__( method _training_step_update (line 177) | def _training_step_update(self): method step (line 231) | def step(self, reward=None, observation=None): method _reset_state (line 276) | def _reset_state(self, n_envs=None): method _record_observation (line 283) | def _record_observation(self, observation): method reset_all (line 305) | def reset_all(self, new_obs): method reset_one (line 311) | def reset_one(self, env_id): method delete_one (line 314) | def delete_one(self, env_id): method cache_train_state (line 319) | def cache_train_state(self): method restore_train_state (line 326) | def restore_train_state(self): method log_transition (line 331) | def log_transition(self, observation, action, reward, terminal, episod... method _store_transition (line 344) | def _store_transition( FILE: dopamine/labs/atari_100k/atari_100k_runner.py function create_env_wrapper (line 35) | def create_env_wrapper(create_env_fn): class DataEfficientAtariRunner (line 46) | class DataEfficientAtariRunner(run_experiment.Runner): method __init__ (line 53) | def __init__( method _run_one_phase (line 116) | def _run_one_phase( method _initialize_episode (line 192) | def _initialize_episode(self, envs): method _run_parallel (line 218) | def _run_parallel( method _run_train_phase (line 359) | def _run_train_phase(self, statistics): method _run_eval_phase (line 425) | def _run_eval_phase(self, statistics): method _run_one_iteration (line 470) | def _run_one_iteration(self, iteration): method _maybe_save_single_summary (line 512) | def _maybe_save_single_summary( method _save_tensorboard_summaries (line 530) | def _save_tensorboard_summaries( method run_experiment (line 575) | def run_experiment(self): class LoggedDataEfficientAtariRunner (line 594) | class LoggedDataEfficientAtariRunner(DataEfficientAtariRunner): method __init__ (line 597) | def __init__( method run_experiment (line 606) | def run_experiment(self): function delete_ind_from_array (line 616) | def delete_ind_from_array(array, ind, axis=0): FILE: dopamine/labs/atari_100k/eval_run_experiment.py class MaxEpisodeEvalRunner (line 26) | class MaxEpisodeEvalRunner(run_experiment.Runner): method __init__ (line 29) | def __init__( method _initialize_episode (line 39) | def _initialize_episode(self): method _run_no_ops (line 50) | def _run_no_ops(self): method _run_one_phase_fix_episodes (line 65) | def _run_one_phase_fix_episodes(self, max_episodes, statistics, run_mo... method _run_eval_phase (line 105) | def _run_eval_phase(self, statistics): FILE: dopamine/labs/atari_100k/normalization_utils.py function normalize_score (line 136) | def normalize_score(ret, game): FILE: dopamine/labs/atari_100k/replay_memory/deterministic_sum_tree.py function step (line 26) | def step(i, args): # pylint: disable=unused-argument function parallel_stratified_sample (line 41) | def parallel_stratified_sample(rng, nodes, i, n, depth): class DeterministicSumTree (line 65) | class DeterministicSumTree(sum_tree.SumTree): method __init__ (line 73) | def __init__(self, capacity): method _total_priority (line 99) | def _total_priority(self): method sample (line 107) | def sample(self, rng, query_value=None): method stratified_sample (line 124) | def stratified_sample(self, batch_size, rng): method get (line 138) | def get(self, node_index): method reset_priorities (line 149) | def reset_priorities(self): method set (line 153) | def set(self, node_index, value): FILE: dopamine/labs/atari_100k/replay_memory/subsequence_replay_buffer.py function modulo_range (line 49) | def modulo_range(start, length, modulo): function invalid_range (line 54) | def invalid_range(cursor, replay_capacity, stack_size, update_horizon): class JaxSubsequenceParallelEnvReplayBuffer (line 82) | class JaxSubsequenceParallelEnvReplayBuffer(object): method __init__ (line 101) | def __init__( method _create_storage (line 214) | def _create_storage(self): method get_add_args_signature (line 225) | def get_add_args_signature(self): method get_storage_signature (line 235) | def get_storage_signature(self): method _add_zero_transition (line 255) | def _add_zero_transition(self): method add (line 265) | def add( method _add (line 316) | def _add(self, *args): method _add_transition (line 328) | def _add_transition(self, transition): method _check_args_length (line 348) | def _check_args_length(self, *args): method _check_add_types (line 364) | def _check_add_types(self, *args): method is_empty (line 395) | def is_empty(self): method is_full (line 399) | def is_full(self): method ravel_indices (line 403) | def ravel_indices(self, indices_t, indices_b): method unravel_indices (line 408) | def unravel_indices(self, indices): method get_from_store (line 411) | def get_from_store(self, element_name, indices_t, indices_b): method cursor (line 415) | def cursor(self): method parallel_get_stack (line 419) | def parallel_get_stack(self, element_name, indices_t, indices_b, first... method get_terminal_stack (line 433) | def get_terminal_stack(self, index_t, index_b): method is_valid_transition (line 436) | def is_valid_transition(self, index_t, index_b): method _create_batch_arrays (line 485) | def _create_batch_arrays(self, batch_size): method num_elements (line 505) | def num_elements(self): method sample_index_batch (line 511) | def sample_index_batch(self, batch_size): method restore_leading_dims (line 579) | def restore_leading_dims(self, batch_size, subseq_len, tensor): method sample (line 582) | def sample(self, *args, **kwargs): method sample_transition_batch (line 585) | def sample_transition_batch( method get_transition_elements (line 736) | def get_transition_elements(self, batch_size=None, subseq_len=None): method _generate_filename (line 810) | def _generate_filename(self, checkpoint_dir, name, suffix): method _return_checkpointable_elements (line 813) | def _return_checkpointable_elements(self): method save (line 829) | def save(self, checkpoint_dir, iteration_number): method load (line 873) | def load(self, checkpoint_dir, suffix): method reset_priorities (line 907) | def reset_priorities(self): class PrioritizedJaxSubsequenceParallelEnvReplayBuffer (line 912) | class PrioritizedJaxSubsequenceParallelEnvReplayBuffer( method __init__ (line 917) | def __init__( method get_add_args_signature (line 957) | def get_add_args_signature(self): method _add (line 965) | def _add(self, *args): method sample_index_batch (line 988) | def sample_index_batch(self, batch_size): method sample_transition_batch (line 1028) | def sample_transition_batch( method set_priority (line 1049) | def set_priority(self, indices, priorities): method get_priority (line 1059) | def get_priority(self, indices): method get_transition_elements (line 1068) | def get_transition_elements(self, batch_size=None): method reset_priorities (line 1076) | def reset_priorities(self): FILE: dopamine/labs/atari_100k/spr_agent.py function get_logits (line 57) | def get_logits(model, states, actions, do_rollout, rng): function get_q_values (line 65) | def get_q_values(model, states, actions, do_rollout, rng): function get_spr_targets (line 71) | def get_spr_targets(model, states, key): function train (line 87) | def train( function target_output (line 243) | def target_output( class SPRAgent (line 288) | class SPRAgent(atari_100k_rainbow_agent.Atari100kRainbowAgent): method __init__ (line 291) | def __init__( method _build_networks_and_optimizer (line 334) | def _build_networks_and_optimizer(self): method _build_replay_buffer (line 347) | def _build_replay_buffer(self): method _sample_from_replay_buffer (line 370) | def _sample_from_replay_buffer(self): method _training_step_update (line 378) | def _training_step_update(self): FILE: dopamine/labs/atari_100k/spr_networks.py function _absolute_dims (line 43) | def _absolute_dims(rank, dims): class NoisyNetwork (line 52) | class NoisyNetwork(nn.Module): method sample_noise (line 58) | def sample_noise(key, shape): method f (line 62) | def f(x): method __call__ (line 67) | def __call__(self, x, rng_key, bias=True, kernel_init=None, eval_mode=... class NoStatsBatchNorm (line 124) | class NoStatsBatchNorm(nn.Module): method __call__ (line 159) | def __call__(self, x, use_running_average: Optional[bool] = None): function feature_layer (line 216) | def feature_layer(noisy, features): function renormalize (line 232) | def renormalize(tensor, has_batch=False): class ConvTMCell (line 242) | class ConvTMCell(nn.Module): method setup (line 255) | def setup(self): method __call__ (line 259) | def __call__(self, x, action, eval_mode=False, key=None): class RainbowCNN (line 289) | class RainbowCNN(nn.Module): method __call__ (line 302) | def __call__(self, x): class TransitionModel (line 319) | class TransitionModel(nn.Module): method __call__ (line 333) | def __call__(self, x, action): class SPRNetwork (line 349) | class SPRNetwork(nn.Module): method setup (line 370) | def setup(self): method encode (line 380) | def encode(self, x): method project (line 386) | def project(self, x, key, eval_mode): method spr_predict (line 393) | def spr_predict(self, x, key, eval_mode): method spr_rollout (line 400) | def spr_rollout(self, latent, actions, key): method __call__ (line 408) | def __call__( FILE: dopamine/labs/atari_100k/train.py function create_agent (line 61) | def create_agent( function set_random_seed (line 84) | def set_random_seed(seed): function main (line 92) | def main(unused_argv): FILE: dopamine/labs/cale/networks.py class NatureDQNEncoder (line 30) | class NatureDQNEncoder(nn.Module): method __call__ (line 36) | def __call__(self, x): class SACImpalaEncoder (line 71) | class SACImpalaEncoder(nn.Module): method setup (line 79) | def setup(self): method __call__ (line 93) | def __call__(self, x): class SACCALEConvNetwork (line 115) | class SACCALEConvNetwork(nn.Module): method setup (line 125) | def setup(self): method __call__ (line 147) | def __call__( method actor (line 177) | def actor( method critic (line 196) | def critic( class PPOCALEConvNetwork (line 216) | class PPOCALEConvNetwork(nn.Module): method setup (line 234) | def setup(self): method __call__ (line 255) | def __call__( method actor (line 278) | def actor( method critic (line 301) | def critic(self, state: jnp.ndarray) -> continuous_networks.PPOCriticO... FILE: dopamine/labs/cale/ppo_cale.py function select_action_eps_greedy (line 41) | def select_action_eps_greedy( class PPOCALEAgent (line 75) | class PPOCALEAgent(ppo_agent.PPOAgent): method __init__ (line 78) | def __init__( method _select_action (line 109) | def _select_action(self): method _maybe_log_action_distribution (line 126) | def _maybe_log_action_distribution(self): method begin_episode (line 133) | def begin_episode(self, observation): method step (line 147) | def step(self, reward, observation): method bundle_and_checkpoint (line 165) | def bundle_and_checkpoint(self, checkpoint_dir, iteration_number): FILE: dopamine/labs/cale/sac_cale.py function select_action_eps_greedy (line 42) | def select_action_eps_greedy( class SACCALEAgent (line 82) | class SACCALEAgent(sac_agent.SACAgent): method __init__ (line 85) | def __init__( method _select_action (line 120) | def _select_action(self): method _maybe_log_action_distribution (line 139) | def _maybe_log_action_distribution(self): method begin_episode (line 146) | def begin_episode(self, observation): method step (line 160) | def step(self, reward, observation): method bundle_and_checkpoint (line 178) | def bundle_and_checkpoint(self, checkpoint_dir, iteration_number): FILE: dopamine/labs/cale/utils.py function _polar_to_cartesian (line 68) | def _polar_to_cartesian(r, theta): function _polar_to_discrete_action (line 72) | def _polar_to_discrete_action(r, theta, fire, threshold=0.5): function get_action_number (line 91) | def get_action_number(r, theta, fire, threshold=0.5): FILE: dopamine/labs/environments/brax/brax_lib.py class BraxEnv (line 30) | class BraxEnv(object): method __init__ (line 33) | def __init__(self, env_name: str, seed: Optional[int] = None): method observation_space (line 42) | def observation_space(self) -> onp.ndarray: method action_space (line 46) | def action_space(self) -> int: method reward_range (line 50) | def reward_range(self): method metadata (line 54) | def metadata(self): method reset (line 57) | def reset(self) -> onp.ndarray: method step (line 63) | def step( function create_brax_environment (line 77) | def create_brax_environment(env_name, seed=None) -> BraxEnv: function create_brax_agent (line 83) | def create_brax_agent( function create_brax_runner (line 107) | def create_brax_runner( FILE: dopamine/labs/environments/brax/train.py function main (line 42) | def main(unused_argv): FILE: dopamine/labs/environments/minatar/minatar_env.py class MinAtarEnv (line 34) | class MinAtarEnv(object): method __init__ (line 37) | def __init__(self, game_name): method observation_space (line 43) | def observation_space(self): method action_space (line 47) | def action_space(self): method reward_range (line 51) | def reward_range(self): method metadata (line 55) | def metadata(self): method reset (line 58) | def reset(self): method step (line 63) | def step(self, action): class StickyMinAtarEnv (line 69) | class StickyMinAtarEnv(MinAtarEnv): method __init__ (line 72) | def __init__(self, game_name, action_repeat_probability, seed=None): method reset (line 79) | def reset(self): method step (line 83) | def step(self, action): function create_minatar_env (line 94) | def create_minatar_env( class MinatarDQNNetwork (line 108) | class MinatarDQNNetwork(nn.Module): method __call__ (line 115) | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: class MinatarRainbowNetwork (line 128) | class MinatarRainbowNetwork(nn.Module): method __call__ (line 136) | def __call__(self, x: jnp.ndarray, support: jnp.ndarray) -> jnp.ndarray: class MinatarQuantileNetwork (line 154) | class MinatarQuantileNetwork(nn.Module): method __call__ (line 162) | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: FILE: dopamine/labs/moes/agents/dqn_moe_agent.py function train (line 44) | def train( function target_q (line 212) | def target_q( function select_action (line 248) | def select_action( class DQNMoEAgent (line 284) | class DQNMoEAgent(dqn_agent.JaxDQNAgent): method __init__ (line 287) | def __init__( method _build_networks_and_optimizer (line 308) | def _build_networks_and_optimizer(self): method begin_episode (line 318) | def begin_episode(self, observation): method step (line 345) | def step(self, reward, observation): method _train_step (line 371) | def _train_step(self): method _training_step_update (line 383) | def _training_step_update(self): FILE: dopamine/labs/moes/agents/full_rainbow_moe_agent.py function train (line 47) | def train( class JaxFullRainbowMoEAgent (line 270) | class JaxFullRainbowMoEAgent(full_rainbow_agent.JaxFullRainbowAgent): method __init__ (line 273) | def __init__( method _train_step (line 290) | def _train_step(self): method _training_step_update (line 301) | def _training_step_update(self): FILE: dopamine/labs/moes/agents/losses.py function entropy (line 29) | def entropy(x): function naive_entropy (line 37) | def naive_entropy( function importance_loss (line 85) | def importance_loss( function load_loss (line 135) | def load_loss(loss_parameters: types.MoELossParameters) -> types.MoELoss... function aux_loss (line 214) | def aux_loss( FILE: dopamine/labs/moes/agents/rainbow_100k_moe_agent.py function train (line 46) | def train( class Atari100kRainbowMoEAgent (line 207) | class Atari100kRainbowMoEAgent(base_100k_rainbow.Atari100kRainbowAgent): method __init__ (line 210) | def __init__( method _train_step (line 225) | def _train_step(self): method _training_step_update (line 236) | def _training_step_update(self): FILE: dopamine/labs/moes/agents/types.py class MoELossParameters (line 45) | class MoELossParameters: function loss_params_flatten (line 58) | def loss_params_flatten(v): function loss_params_unflatten (line 72) | def loss_params_unflatten(aux_data, children): class MoELossStatistic (line 84) | class MoELossStatistic: function loss_stat_flatten (line 90) | def loss_stat_flatten(v): function loss_stat_unflatten (line 97) | def loss_stat_unflatten(aux_data, children): class MoELossReturn (line 109) | class MoELossReturn: function loss_return_flatten (line 114) | def loss_return_flatten(v): function loss_return_unflatten (line 121) | def loss_return_unflatten(aux_data, children): FILE: dopamine/labs/moes/architectures/moe.py class MoE (line 30) | class MoE(nn.Module): method setup (line 42) | def setup(self): method __call__ (line 46) | def __call__(self, x: jax.Array, *, key: jax.Array) -> types.MoEModule... FILE: dopamine/labs/moes/architectures/networks.py class ExpertModel (line 34) | class ExpertModel(nn.Module): method setup (line 44) | def setup(self): method __call__ (line 57) | def __call__(self, x): class BigExpertModel (line 74) | class BigExpertModel(nn.Module): method setup (line 87) | def setup(self): method __call__ (line 106) | def __call__(self, x): class RoutingType (line 130) | class RoutingType(enum.Enum): class MoEType (line 139) | class MoEType(enum.Enum): function _maybe_create_moe_module (line 148) | def _maybe_create_moe_module( class ImpalaMoE (line 195) | class ImpalaMoE(nn.Module): method setup (line 215) | def setup(self): method create_moe (line 228) | def create_moe(self, rng_key, noisy): method __call__ (line 242) | def __call__(self, x: jax.Array, *, key: jax.Array) -> types.NetworkRe... class NatureDQNMoE (line 324) | class NatureDQNMoE(nn.Module): method setup (line 344) | def setup(self): method create_moe (line 356) | def create_moe(self, rng_key, noisy): method __call__ (line 370) | def __call__(self, x: jax.Array, *, key: jax.Array) -> types.NetworkRe... class RainbowCNNet (line 461) | class RainbowCNNet(nn.Module): method __call__ (line 468) | def __call__(self, x): class FullRainbowMoENetwork (line 493) | class FullRainbowMoENetwork(nn.Module): method setup (line 517) | def setup(self): method create_moe (line 539) | def create_moe(self, rng_key, noisy): method __call__ (line 554) | def __call__(self, x, support, eval_mode=False, key=None): FILE: dopamine/labs/moes/architectures/routers.py class RandomRouter (line 26) | class RandomRouter(nn.Module): method setup (line 32) | def setup(self): method __call__ (line 36) | def __call__( class TopKRouter (line 64) | class TopKRouter(nn.Module): method setup (line 72) | def setup(self): method __call__ (line 76) | def __call__( FILE: dopamine/labs/moes/architectures/softmoe.py function l2_normalize (line 26) | def l2_normalize(x, axis, eps=1e-6): class SoftMoE (line 32) | class SoftMoE(nn.Module): method __call__ (line 44) | def __call__(self, x: jax.Array, *, key: jax.Array) -> types.MoEModule... FILE: dopamine/labs/moes/architectures/types.py class RouterReturn (line 22) | class RouterReturn: function router_flatten (line 29) | def router_flatten(v): function router_unflatten (line 36) | def router_unflatten(aux_data, children): class MoEModuleReturn (line 48) | class MoEModuleReturn: function module_flatten (line 54) | def module_flatten(v): function module_unflatten (line 61) | def module_unflatten(aux_data, children): class MoENetworkReturn (line 73) | class MoENetworkReturn: function network_flatten (line 81) | def network_flatten(v): function network_unflatten (line 88) | def network_unflatten(aux_data, children): class BaselineNetworkReturn (line 100) | class BaselineNetworkReturn: function baseline_network_flatten (line 107) | def baseline_network_flatten(v): function baseline_network_unflatten (line 114) | def baseline_network_unflatten(aux_data, children): FILE: dopamine/labs/moes/atari_100k_train.py function create_agent (line 56) | def create_agent( function set_random_seed (line 73) | def set_random_seed(seed): function main (line 81) | def main(unused_argv): FILE: dopamine/labs/offline_rl/fixed_replay.py class JaxFixedReplayBuffer (line 26) | class JaxFixedReplayBuffer(fixed_replay_buffer.FixedReplayBuffer): method __init__ (line 29) | def __init__( method _load_buffer (line 98) | def _load_buffer(self, suffix): method replay_capacity (line 138) | def replay_capacity(self): method reload_data (line 141) | def reload_data(self): FILE: dopamine/labs/offline_rl/jax/networks.py function preprocess_atari_inputs (line 34) | def preprocess_atari_inputs(x): function transform_and_concat_return (line 39) | def transform_and_concat_return(x, return_to_condition): class Stack (line 48) | class Stack(nn.Module): method __call__ (line 56) | def __call__(self, x): class CNNEncoder (line 86) | class CNNEncoder(nn.Module): method __call__ (line 95) | def __call__(self, x): class ImpalaEncoder (line 111) | class ImpalaEncoder(nn.Module): method setup (line 118) | def setup(self): method __call__ (line 125) | def __call__(self, x): class ImpalaNetworkWithRepresentations (line 132) | class ImpalaNetworkWithRepresentations(nn.Module): method setup (line 139) | def setup(self): method __call__ (line 143) | def __call__(self, x, stop_grad_representation=False): class JAXDQNNetworkWithRepresentations (line 176) | class JAXDQNNetworkWithRepresentations(nn.Module): method __call__ (line 183) | def __call__(self, x, stop_grad_representation=False): class ParameterizedRainbowNetwork (line 212) | class ParameterizedRainbowNetwork(nn.Module): method setup (line 234) | def setup(self): method __call__ (line 241) | def __call__( FILE: dopamine/labs/offline_rl/jax/offline_classy_cql_agent.py class ClassyLoss (line 41) | class ClassyLoss(enum.StrEnum): class TargetType (line 48) | class TargetType(enum.StrEnum): function get_network_outputs (line 55) | def get_network_outputs(model, states, rng): function target_output (line 67) | def target_output( function train (line 119) | def train( class OfflineClassyCQLAgent (line 209) | class OfflineClassyCQLAgent(full_rainbow_agent.JaxFullRainbowAgent): method __init__ (line 212) | def __init__( method train_step (line 315) | def train_step(self): method _training_step_update (line 320) | def _training_step_update(self): method _build_replay_buffer (line 385) | def _build_replay_buffer(self): method log_gradient_steps_per_epoch (line 413) | def log_gradient_steps_per_epoch(self): method _sample_from_replay_buffer (line 422) | def _sample_from_replay_buffer(self): method reload_data (line 428) | def reload_data(self): method step (line 432) | def step(self, reward, observation): FILE: dopamine/labs/offline_rl/jax/offline_dqn_agent.py class OfflineJaxDQNAgent (line 27) | class OfflineJaxDQNAgent(dqn_agent.JaxDQNAgent): method __init__ (line 30) | def __init__( method _build_replay_buffer (line 65) | def _build_replay_buffer(self): method _sample_from_replay_buffer (line 92) | def _sample_from_replay_buffer(self): method reload_data (line 98) | def reload_data(self): method step (line 102) | def step(self, reward, observation): method train_step (line 130) | def train_step(self): FILE: dopamine/labs/offline_rl/jax/offline_dr3_agent.py function train (line 33) | def train( function compute_dr3_loss (line 124) | def compute_dr3_loss(state_representations, next_state_representations): class OfflineJaxDR3Agent (line 134) | class OfflineJaxDR3Agent(offline_dqn_agent.OfflineJaxDQNAgent): method __init__ (line 137) | def __init__( method train_step (line 173) | def train_step(self): FILE: dopamine/labs/offline_rl/jax/offline_rainbow_agent.py function get_logits_and_q_values (line 37) | def get_logits_and_q_values(model, states, rng): function train (line 54) | def train( class OfflineJaxRainbowAgent (line 135) | class OfflineJaxRainbowAgent(full_rainbow_agent.JaxFullRainbowAgent): method __init__ (line 138) | def __init__( method _training_step_update (line 179) | def _training_step_update(self): method _build_replay_buffer (line 231) | def _build_replay_buffer(self): method log_gradient_steps_per_epoch (line 259) | def log_gradient_steps_per_epoch(self): method _sample_from_replay_buffer (line 268) | def _sample_from_replay_buffer(self): method reload_data (line 274) | def reload_data(self): method train_step (line 278) | def train_step(self): method step (line 281) | def step(self, reward, observation): FILE: dopamine/labs/offline_rl/jax/return_conditioned_bc_agent.py function get_q_values (line 36) | def get_q_values(model, states, returns_to_condition): function train (line 42) | def train( class WrappedNetworkDef (line 83) | class WrappedNetworkDef(object): method __init__ (line 86) | def __init__(self, network_def, min_return, max_return): method set_return_to_condition (line 92) | def set_return_to_condition(self, return_multiplier=1.0): method apply (line 100) | def apply(self, params, x, support, key=None): class JaxReturnConditionedBCAgent (line 111) | class JaxReturnConditionedBCAgent(offline_rainbow_agent.OfflineJaxRainbo... method __init__ (line 114) | def __init__( method _training_step_update (line 137) | def _training_step_update(self): method _build_networks_and_optimizer (line 166) | def _build_networks_and_optimizer(self): method _create_wrapped_network (line 178) | def _create_wrapped_network(self): method set_return_to_condition (line 184) | def set_return_to_condition(self, return_multiplier): method step (line 187) | def step(self, reward, observation): method _get_action (line 200) | def _get_action(self): method begin_episode (line 220) | def begin_episode(self, observation): FILE: dopamine/labs/offline_rl/jax/run_experiment.py class FixedReplayRunner (line 29) | class FixedReplayRunner(run_experiment.Runner): method __init__ (line 32) | def __init__(self, base_dir, create_agent_fn, num_epochs=None): method _run_train_phase (line 48) | def _run_train_phase(self): method _run_one_iteration (line 70) | def _run_one_iteration(self, iteration): method _save_tensorboard_summaries (line 96) | def _save_tensorboard_summaries( FILE: dopamine/labs/offline_rl/jax/train.py function create_offline_agent (line 62) | def create_offline_agent( function create_replay_dir (line 108) | def create_replay_dir(xm_parameters): function main (line 123) | def main(unused_argv): FILE: dopamine/labs/offline_rl/rlu_tfds/scaling_dataset_utils.py function _parallel_lookup (line 30) | def _parallel_lookup( function choose_indices (line 80) | def choose_indices( function _get_dataset_with_idxs (line 133) | def _get_dataset_with_idxs(game, idxs, num_datasets): function create_dataset_with_expertise (line 160) | def create_dataset_with_expertise( FILE: dopamine/labs/offline_rl/rlu_tfds/tfds_atari_utils.py class BatchToTransition (line 21) | class BatchToTransition(object): method __init__ (line 24) | def __init__( method create_transitions (line 39) | def create_transitions(self, batch, rtg_batch=None, episode_return=None): function get_transition_dataset_fn (line 62) | def get_transition_dataset_fn( function load_data_splits (line 101) | def load_data_splits(dataset_name, data_splits): function uniformly_subsampled_atari_data (line 117) | def uniformly_subsampled_atari_data(dataset_name, data_percent): function create_atari_ds_loader (line 132) | def create_atari_ds_loader( function create_ds_iterator (line 159) | def create_ds_iterator(ds, batch_size=32, repeat=True): function build_tfds_replay (line 167) | def build_tfds_replay( FILE: dopamine/labs/offline_rl/rlu_tfds/tfds_replay.py function get_atari_ds_name_from_replay (line 25) | def get_atari_ds_name_from_replay( function game_from_dataset_name (line 38) | def game_from_dataset_name(dataset_name: str) -> str: class JaxFixedReplayBufferTFDS (line 43) | class JaxFixedReplayBufferTFDS(object): method __init__ (line 46) | def __init__( method min_max_returns (line 110) | def min_max_returns(self): method _load_buffer (line 116) | def _load_buffer(self, suffix): method load_single_buffer (line 120) | def load_single_buffer(self, suffix): method _load_replay_buffers (line 123) | def _load_replay_buffers(self, unused_num_buffers): method get_transition_elements (line 126) | def get_transition_elements(self): method sample_transition_batch (line 129) | def sample_transition_batch(self): method load (line 133) | def load(self, *args, **kwargs): # pylint: disable=unused-argument method reload_buffer (line 136) | def reload_buffer(self, num_buffers): method save (line 139) | def save(self, *args, **kwargs): # pylint: disable=unused-argument method add (line 142) | def add(self, *args, **kwargs): # pylint: disable=unused-argument method add_count (line 146) | def add_count(self): method gradient_steps_per_epoch (line 150) | def gradient_steps_per_epoch(self): method replay_capacity (line 154) | def replay_capacity(self): method reload_data (line 157) | def reload_data(self): FILE: dopamine/labs/redo/networks.py class IdentityLayer (line 24) | class IdentityLayer(nn.Module): method __call__ (line 28) | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: class ScalableNatureDQNNetwork (line 33) | class ScalableNatureDQNNetwork(nn.Module): method _record_activations (line 40) | def _record_activations(self, x, layer): method __call__ (line 47) | def __call__(self, x): class Stack (line 88) | class Stack(nn.Module): method _record_activations (line 95) | def _record_activations(self, x, layer): method __call__ (line 102) | def __call__(self, x): class ScalableDQNResNet (line 138) | class ScalableDQNResNet(nn.Module): method _record_activations (line 145) | def _record_activations(self, x, layer): method __call__ (line 152) | def __call__(self, x): class ScalableRainbowNetwork (line 191) | class ScalableRainbowNetwork(nn.Module): method _record_activations (line 200) | def _record_activations(self, x, layer): method __call__ (line 207) | def __call__(self, x, support): class FullRainbowNetwork (line 256) | class FullRainbowNetwork(nn.Module): method _record_activations (line 275) | def _record_activations(self, x, layer): method __call__ (line 282) | def __call__(self, x, support, eval_mode=False, key=None): FILE: dopamine/labs/redo/recycled_atari100k_rainbow_agent.py class RecycledAtari100kRainbowAgent (line 29) | class RecycledAtari100kRainbowAgent( method __init__ (line 34) | def __init__( method _log_stats (line 86) | def _log_stats(self, log_dict, step): method _train_step (line 96) | def _train_step(self): method _training_step_update (line 118) | def _training_step_update(self): method _weight_recycle (line 172) | def _weight_recycle(self, update_step, online_params): method _sample_batch_for_statistics (line 203) | def _sample_batch_for_statistics(self): method get_intermediates (line 214) | def get_intermediates(self, online_params): FILE: dopamine/labs/redo/recycled_dqn_agents.py function loss_fn (line 30) | def loss_fn(params, target, state, action, apply_fn): function get_gradients (line 45) | def get_gradients( function apply_updates_jitted (line 79) | def apply_updates_jitted(online_params, grad, optimizer_state, optimizer): class RecycledDQNAgent (line 88) | class RecycledDQNAgent(dqn_agent.JaxDQNAgent): method __init__ (line 91) | def __init__( method _log_stats (line 145) | def _log_stats(self, log_dict, step): method _train_step (line 155) | def _train_step(self): method _training_step_update (line 177) | def _training_step_update(self): method _sample_batch_for_statistics (line 238) | def _sample_batch_for_statistics(self): method get_intermediates (line 249) | def get_intermediates(self, online_params): FILE: dopamine/labs/redo/recycled_rainbow_agent.py class RecycledRainbowAgent (line 29) | class RecycledRainbowAgent(rainbow_agent.JaxRainbowAgent): method __init__ (line 32) | def __init__( method _log_stats (line 80) | def _log_stats(self, log_dict, step): method _train_step (line 90) | def _train_step(self): method _training_step_update (line 112) | def _training_step_update(self): method _sample_batch_for_statistics (line 183) | def _sample_batch_for_statistics(self): method get_intermediates (line 194) | def get_intermediates(self, online_params): FILE: dopamine/labs/redo/tfagents/sac_train_eval.py function custom_call (line 81) | def custom_call(self, inputs): function get_all_layers (line 87) | def get_all_layers(model, filter_fn=lambda _: True): function is_dense_layer (line 109) | def is_dense_layer(layer): function scale_width (line 113) | def scale_width(num_units: int, width: float): function get_intermedieates (line 118) | def get_intermedieates(*networks): function create_fc_layers (line 134) | def create_fc_layers(layer_units, width=1.0, weight_decay=0): function create_identity_layer (line 145) | def create_identity_layer(): function create_sequential_critic_network (line 149) | def create_sequential_critic_network( class _TanhNormalProjectionNetworkWrapper (line 223) | class _TanhNormalProjectionNetworkWrapper( method __init__ (line 228) | def __init__(self, sample_spec, predefined_outer_rank=1, weight_decay=... method call (line 234) | def call(self, inputs, network_state=(), **kwargs): function create_sequential_actor_network (line 243) | def create_sequential_actor_network( class RecycledSacAgent (line 278) | class RecycledSacAgent(sac_agent.SacAgent): method __init__ (line 281) | def __init__( method _train (line 340) | def _train(self, experience, weights): method log_weights_mean (line 486) | def log_weights_mean(self, model_name, model): method is_dead_neurons_log_iter (line 494) | def is_dead_neurons_log_iter(self): method log_deadneurons_models (line 501) | def log_deadneurons_models(self): method calculate_neuron_score_all_layers (line 518) | def calculate_neuron_score_all_layers( method log_dead_neurons_count (line 563) | def log_dead_neurons_count( method log_histogram (line 606) | def log_histogram(self, name, activation): method is_reset_iter (line 609) | def is_reset_iter(self): method reset_models (line 616) | def reset_models(self): method reset_momentum (line 624) | def reset_momentum(self, optimizer, var, mask): method reset_model_weights (line 631) | def reset_model_weights(self): method reset_weights (line 644) | def reset_weights(self, optimizer, model, target_model=None): method create_new_weights (line 682) | def create_new_weights( method reset_model_neurons (line 717) | def reset_model_neurons(self): method get_dead_neurons (line 743) | def get_dead_neurons(self, neuron_score): method get_mask_dead_neurons_weights (line 752) | def get_mask_dead_neurons_weights(self, act, act_grad, model): method reset_dead_neurons (line 781) | def reset_dead_neurons( method _rescale_weights (line 893) | def _rescale_weights(self, neuron_mask, var, new_weights, axis, scaler): method create_mask_helper (line 905) | def create_mask_helper(self, neuron_mask, prev_layer_size, next_layer_... function train_eval (line 931) | def train_eval( function main (line 1208) | def main(_): FILE: dopamine/labs/redo/train.py function create_agent_recycled (line 60) | def create_agent_recycled( function create_runner_recycled (line 106) | def create_runner_recycled( function main (line 139) | def main(unused_argv): FILE: dopamine/labs/redo/weight_recyclers.py function leastk_mask (line 27) | def leastk_mask(scores, ones_fraction): function reset_momentum (line 51) | def reset_momentum(momentum, mask): function weight_reinit_zero (line 56) | def weight_reinit_zero(param, mask): function weight_reinit_random (line 65) | def weight_reinit_random( function _weight_normalization_per_neuron_norm (line 114) | def _weight_normalization_per_neuron_norm(param, axes): function _get_norm_per_neuron (line 121) | def _get_norm_per_neuron(param, axes): class BaseRecycler (line 126) | class BaseRecycler: method __init__ (line 144) | def __init__( method update_reset_layers (line 165) | def update_reset_layers(self, reset_start_layer_idx): method is_update_iter (line 168) | def is_update_iter(self, step): method update_weights (line 171) | def update_weights(self, intermediates, params, key, opt_state): method maybe_update_weights (line 174) | def maybe_update_weights( method is_reset (line 186) | def is_reset(self, update_step): method is_intermediated_required (line 190) | def is_intermediated_required(self, update_step): method is_logging_step (line 193) | def is_logging_step(self, step): method maybe_log_deadneurons (line 196) | def maybe_log_deadneurons(self, update_step, intermediates): method intersected_dead_neurons_with_last_reset (line 203) | def intersected_dead_neurons_with_last_reset( method log_intersected_dead_neurons (line 212) | def log_intersected_dead_neurons(self, intermediates): method log_dead_neurons_count (line 261) | def log_dead_neurons_count(self, intermediates): method estimate_neuron_score (line 303) | def estimate_neuron_score(self, activation, is_cbp=False): class LayerReset (line 330) | class LayerReset(BaseRecycler): method is_reset (line 339) | def is_reset(self, update_step): method update_weights (line 346) | def update_weights(self, intermediates, params, key, opt_state): class NeuronRecycler (line 393) | class NeuronRecycler(BaseRecycler): method __init__ (line 406) | def __init__( method intersected_dead_neurons_with_last_reset (line 442) | def intersected_dead_neurons_with_last_reset( method is_reset (line 451) | def is_reset(self, update_step): method is_intermediated_required (line 458) | def is_intermediated_required(self, update_step): method update_reset_layers (line 463) | def update_reset_layers(self, reset_start_layer_idx): method update_weights (line 467) | def update_weights(self, intermediates, params, key, opt_state): method recycle_dead_neurons (line 473) | def recycle_dead_neurons(self, intermedieates, params, key, opt_state): method _score2mask (line 567) | def _score2mask(self, activation, param, next_param, key): method create_masks (line 572) | def create_masks(self, param_dict, activations_dict, key): method create_mask_helper (line 652) | def create_mask_helper(self, neuron_mask, current_param, next_param): class NeuronRecyclerScheduled (line 701) | class NeuronRecyclerScheduled(NeuronRecycler): method __init__ (line 704) | def __init__( method _score2mask (line 715) | def _score2mask(self, activation, param, next_param, key): FILE: dopamine/labs/sac_from_pixels/continuous_networks.py class SACEncoderOutputs (line 28) | class SACEncoderOutputs: class SACEncoderNetwork (line 35) | class SACEncoderNetwork(nn.Module): method __call__ (line 47) | def __call__(self, x): class SACConvNetwork (line 90) | class SACConvNetwork(nn.Module): method setup (line 102) | def setup(self): method __call__ (line 111) | def __call__( method actor (line 141) | def actor( method critic (line 160) | def critic( FILE: dopamine/labs/sac_from_pixels/deepmind_control_lib.py function create_deepmind_control_environment (line 36) | def create_deepmind_control_environment( class DeepmindControlPreprocessing (line 67) | class DeepmindControlPreprocessing(gym.Env): method __init__ (line 75) | def __init__(self, environment: control.Environment, action_repeat: in... method observation_space (line 91) | def observation_space(self) -> spaces.Box: method action_space (line 111) | def action_space(self) -> spaces.Box: method reward_range (line 119) | def reward_range(self) -> Tuple[float, float]: method metadata (line 124) | def metadata(self) -> Mapping[Any, Any]: method reset (line 128) | def reset(self) -> np.ndarray: method step (line 137) | def step( method _get_observation (line 168) | def _get_observation(self, timestep: dm_env.TimeStep) -> np.ndarray: class DeepmindControlWithImagesPreprocessing (line 174) | class DeepmindControlWithImagesPreprocessing(DeepmindControlPreprocessing): method __init__ (line 177) | def __init__( method observation_space (line 195) | def observation_space(self) -> spaces.Box: method _get_observation (line 206) | def _get_observation(self, timestep: dm_env.TimeStep) -> np.ndarray: method _render_image (line 209) | def _render_image(self) -> np.ndarray: FILE: dopamine/labs/tandem_dqn/run_experiment.py function create_tandem_agents_and_checkpoints (line 32) | def create_tandem_agents_and_checkpoints( class TandemRunner (line 60) | class TandemRunner(run_experiment.Runner): method __init__ (line 63) | def __init__(self, base_dir, create_agent_fn, suite='atari'): method _initialize_episode (line 76) | def _initialize_episode(self, agent_type='active'): method _run_one_episode (line 88) | def _run_one_episode(self, agent_type='active'): method _run_one_phase (line 132) | def _run_one_phase( method _run_eval_phase (line 178) | def _run_eval_phase(self, statistics, agent_type='active'): method _run_one_iteration (line 204) | def _run_one_iteration(self, iteration): method _save_tensorboard_summaries (line 242) | def _save_tensorboard_summaries( FILE: dopamine/labs/tandem_dqn/tandem_dqn_agent.py function train (line 31) | def train( function target_q (line 83) | def target_q( class TandemDQNAgent (line 111) | class TandemDQNAgent(dqn_agent.JaxDQNAgent): method __init__ (line 114) | def __init__(self, num_actions, double_dqn=True, summary_writer=None): method _build_networks_and_optimizer (line 118) | def _build_networks_and_optimizer(self): method _sync_weights (line 137) | def _sync_weights(self): method _select_action (line 142) | def _select_action(self, params): method begin_episode (line 160) | def begin_episode(self, agent_type, observation): method step (line 175) | def step(self, agent_type, reward, observation): method _train_step (line 193) | def _train_step(self): method bundle_and_checkpoint (line 260) | def bundle_and_checkpoint(self, checkpoint_dir, iteration_number): method unbundle (line 292) | def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary): FILE: dopamine/labs/tandem_dqn/train.py function main (line 47) | def main(unused_argv): FILE: dopamine/metrics/collector.py class Collector (line 38) | class Collector(abc.ABC): method __init__ (line 41) | def __init__( method get_name (line 57) | def get_name(self) -> str: method check_type (line 60) | def check_type(self, data_type: str) -> bool: method write (line 64) | def write( method flush (line 69) | def flush(self) -> None: method close (line 72) | def close(self) -> None: FILE: dopamine/metrics/collector_dispatcher.py function add_collector (line 61) | def add_collector(name: str, constructor: CollectorConstructorType) -> N... class CollectorDispatcher (line 66) | class CollectorDispatcher(object): method __init__ (line 69) | def __init__( method write (line 83) | def write( method flush (line 102) | def flush(self) -> None: method close (line 106) | def close(self) -> None: FILE: dopamine/metrics/console_collector.py class ConsoleCollector (line 28) | class ConsoleCollector(collector.Collector): method __init__ (line 31) | def __init__(self, base_dir: Union[str, None], save_to_file: bool = Tr... method get_name (line 39) | def get_name(self) -> str: method write (line 42) | def write( method close (line 56) | def close(self) -> None: FILE: dopamine/metrics/pickle_collector.py class PickleCollector (line 28) | class PickleCollector(collector.Collector): method __init__ (line 31) | def __init__(self, base_dir: str): method get_name (line 39) | def get_name(self) -> str: method write (line 42) | def write( method flush (line 53) | def flush(self): FILE: dopamine/metrics/statistics_instance.py class StatisticsInstance (line 22) | class StatisticsInstance: FILE: dopamine/metrics/tensorboard_collector.py class TensorboardCollector (line 23) | class TensorboardCollector(collector.Collector): method __init__ (line 26) | def __init__(self, base_dir: str): method get_name (line 34) | def get_name(self) -> str: method write (line 37) | def write( method flush (line 46) | def flush(self): FILE: dopamine/tf/agents/dqn/dqn_agent.py function linearly_decaying_epsilon (line 41) | def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon): function identity_epsilon (line 66) | def identity_epsilon( class DQNAgent (line 73) | class DQNAgent(object): method __init__ (line 76) | def __init__( method _create_network (line 243) | def _create_network(self, name): method _build_networks (line 256) | def _build_networks(self): method _build_replay_buffer (line 284) | def _build_replay_buffer(self, use_staging): method _build_target_q_op (line 303) | def _build_target_q_op(self): method _build_train_op (line 324) | def _build_train_op(self): method _build_sync_op (line 348) | def _build_sync_op(self): method begin_episode (line 371) | def begin_episode(self, observation): method step (line 389) | def step(self, reward, observation): method end_episode (line 412) | def end_episode(self, reward): method _select_action (line 424) | def _select_action(self): method _train_step (line 449) | def _train_step(self): method _record_observation (line 477) | def _record_observation(self, observation): method _store_transition (line 493) | def _store_transition(self, last_observation, action, reward, is_termi... method _reset_state (line 511) | def _reset_state(self): method bundle_and_checkpoint (line 515) | def bundle_and_checkpoint(self, checkpoint_dir, iteration_number): method unbundle (line 546) | def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary): FILE: dopamine/tf/agents/implicit_quantile/implicit_quantile_agent.py class ImplicitQuantileAgent (line 34) | class ImplicitQuantileAgent(rainbow_agent.RainbowAgent): method __init__ (line 37) | def __init__( method _create_network (line 99) | def _create_network(self, name): method _build_networks (line 114) | def _build_networks(self): method _build_target_quantile_values_op (line 185) | def _build_target_quantile_values_op(self): method _build_train_op (line 229) | def _build_train_op(self): FILE: dopamine/tf/agents/rainbow/rainbow_agent.py class RainbowAgent (line 50) | class RainbowAgent(dqn_agent.DQNAgent): method __init__ (line 53) | def __init__( method _create_network (line 161) | def _create_network(self, name): method _build_replay_buffer (line 176) | def _build_replay_buffer(self, use_staging): method _build_target_distribution (line 202) | def _build_target_distribution(self): method _build_train_op (line 259) | def _build_train_op(self): method _store_transition (line 314) | def _store_transition( function project_distribution (line 344) | def project_distribution( FILE: dopamine/tf/replay_memory/circular_replay_buffer.py function modulo_range (line 50) | def modulo_range(start, length, modulo): function invalid_range (line 55) | def invalid_range(cursor, replay_capacity, stack_size, update_horizon): class OutOfGraphReplayBuffer (line 85) | class OutOfGraphReplayBuffer(object): method __init__ (line 105) | def __init__( method _episode_end_indices (line 209) | def _episode_end_indices(self): method _create_storage (line 217) | def _create_storage(self): method get_add_args_signature (line 226) | def get_add_args_signature(self): method get_storage_signature (line 237) | def get_storage_signature(self): method _add_zero_transition (line 258) | def _add_zero_transition(self): method add (line 268) | def add( method _add (line 324) | def _add(self, *args): method _add_transition (line 336) | def _add_transition(self, transition): method _check_args_length (line 355) | def _check_args_length(self, *args): method _check_add_types (line 371) | def _check_add_types(self, *args): method is_empty (line 398) | def is_empty(self): method is_full (line 402) | def is_full(self): method cursor (line 406) | def cursor(self): method get_range (line 410) | def get_range(self, array, start_index, end_index): method get_observation_stack (line 443) | def get_observation_stack(self, index): method _get_element_stack (line 446) | def _get_element_stack(self, index, element_name): method get_terminal_stack (line 453) | def get_terminal_stack(self, index): method is_valid_transition (line 458) | def is_valid_transition(self, index): method _create_batch_arrays (line 498) | def _create_batch_arrays(self, batch_size): method sample_index_batch (line 518) | def sample_index_batch(self, batch_size): method sample_transition_batch (line 567) | def sample_transition_batch(self, batch_size=None, indices=None): method get_transition_elements (line 652) | def get_transition_elements(self, batch_size=None): method _generate_filename (line 700) | def _generate_filename(self, checkpoint_dir, name, suffix): method _return_checkpointable_elements (line 703) | def _return_checkpointable_elements(self): method save (line 719) | def save(self, checkpoint_dir, iteration_number): method load (line 771) | def load(self, checkpoint_dir, suffix): class WrappedReplayBuffer (line 822) | class WrappedReplayBuffer(object): method __init__ (line 834) | def __init__( method add (line 920) | def add(self, observation, action, reward, terminal, *args): method create_sampling_ops (line 939) | def create_sampling_ops(self, use_staging): method _set_transition_shape (line 963) | def _set_transition_shape(self, transition, transition_type): method _set_up_staging (line 974) | def _set_up_staging(self, transition): method unpack_transition (line 991) | def unpack_transition(self, transition_tensors, transition_type): method save (line 1013) | def save(self, checkpoint_dir, iteration_number): method load (line 1024) | def load(self, checkpoint_dir, suffix): FILE: dopamine/tf/replay_memory/prioritized_replay_buffer.py class OutOfGraphPrioritizedReplayBuffer (line 36) | class OutOfGraphPrioritizedReplayBuffer( method __init__ (line 44) | def __init__( method get_add_args_signature (line 104) | def get_add_args_signature(self): method _add (line 122) | def _add(self, *args): method sample_index_batch (line 147) | def sample_index_batch(self, batch_size): method sample_transition_batch (line 180) | def sample_transition_batch(self, batch_size=None, indices=None): method set_priority (line 211) | def set_priority(self, indices, priorities): method get_priority (line 228) | def get_priority(self, indices): method get_transition_elements (line 250) | def get_transition_elements(self, batch_size=None): class WrappedPrioritizedReplayBuffer (line 272) | class WrappedPrioritizedReplayBuffer( method __init__ (line 286) | def __init__( method tf_set_priority (line 368) | def tf_set_priority(self, indices, priorities): method tf_get_priority (line 385) | def tf_get_priority(self, indices): FILE: dopamine/tf/replay_memory/sum_tree.py class SumTree (line 30) | class SumTree(object): method __init__ (line 65) | def __init__(self, capacity): method _total_priority (line 92) | def _total_priority(self): method sample (line 100) | def sample(self, query_value=None): method stratified_sample (line 144) | def stratified_sample(self, batch_size): method get (line 170) | def get(self, node_index): method set (line 181) | def set(self, node_index, value): FILE: dopamine/utils/agent_visualizer.py class AgentVisualizer (line 38) | class AgentVisualizer(object): method __init__ (line 41) | def __init__( method visualize (line 87) | def visualize(self): method save_frame (line 96) | def save_frame(self): method generate_video (line 115) | def generate_video(self, video_file='video.mp4'): FILE: dopamine/utils/atari_plotter.py class AtariPlotter (line 29) | class AtariPlotter(plotter.Plotter): method __init__ (line 41) | def __init__(self, parameter_dict=None): method draw (line 55) | def draw(self): FILE: dopamine/utils/bar_plotter.py class BarPlotter (line 41) | class BarPlotter(plotter.Plotter): method __init__ (line 62) | def __init__(self, parameter_dict=None): method draw (line 83) | def draw(self): FILE: dopamine/utils/example_viz.py function main (line 65) | def main(_): FILE: dopamine/utils/example_viz_lib.py class MyDQNAgent (line 51) | class MyDQNAgent(dqn_agent.DQNAgent): method __init__ (line 54) | def __init__(self, sess, num_actions, summary_writer=None): method step (line 61) | def step(self, reward, observation): method _select_action (line 65) | def _select_action(self): method reload_checkpoint (line 74) | def reload_checkpoint(self, checkpoint_path, use_legacy_checkpoint=Fal... method get_q_values (line 96) | def get_q_values(self): method get_rewards (line 99) | def get_rewards(self): class MyRainbowAgent (line 103) | class MyRainbowAgent(rainbow_agent.RainbowAgent): method __init__ (line 106) | def __init__(self, sess, num_actions, summary_writer=None): method step (line 112) | def step(self, reward, observation): method reload_checkpoint (line 116) | def reload_checkpoint(self, checkpoint_path, use_legacy_checkpoint=Fal... method get_probabilities (line 138) | def get_probabilities(self): method get_rewards (line 143) | def get_rewards(self): class MyRunner (line 147) | class MyRunner(run_experiment.Runner): method __init__ (line 150) | def __init__( method _initialize_checkpointer_and_maybe_resume (line 161) | def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_pr... method _run_one_iteration (line 167) | def _run_one_iteration(self, iteration): method visualize (line 173) | def visualize(self, record_path, num_global_steps=500): function create_dqn_agent (line 253) | def create_dqn_agent(sess, environment, summary_writer=None): function create_rainbow_agent (line 261) | def create_rainbow_agent(sess, environment, summary_writer=None): function create_runner (line 269) | def create_runner( function run (line 278) | def run( FILE: dopamine/utils/line_plotter.py class LinePlotter (line 41) | class LinePlotter(plotter.Plotter): method __init__ (line 63) | def __init__(self, parameter_dict=None): method draw (line 83) | def draw(self): FILE: dopamine/utils/plotter.py class Plotter (line 32) | class Plotter(object): method __init__ (line 37) | def __init__(self, parameter_dict=None): method _setup_plot (line 53) | def _setup_plot(self): method draw (line 81) | def draw(self): method x (line 90) | def x(self): method y (line 94) | def y(self): FILE: dopamine/utils/test_utils.py class MockReplayBuffer (line 27) | class MockReplayBuffer(object): method __init__ (line 30) | def __init__(self, is_jax=False): FILE: tests/dopamine/atari_init_test.py class AtariInitTest (line 31) | class AtariInitTest(tf.test.TestCase): method setUp (line 33) | def setUp(self): method test_atari_init (line 47) | def test_atari_init(self): FILE: tests/dopamine/continuous_domains/run_experiment_test.py class RunExperimentTest (line 30) | class RunExperimentTest(parameterized.TestCase): method setUp (line 32) | def setUp(self): method testCreateContinuousAgentReturnsAgent (line 53) | def testCreateContinuousAgentReturnsAgent(self): method testCreateContinuousAgentWithInvalidNameRaisesException (line 58) | def testCreateContinuousAgentWithInvalidNameRaisesException(self): method testCreateContinuousRunnerCreatesCorrectRunner (line 74) | def testCreateContinuousRunnerCreatesCorrectRunner( method testCreateContinuousRunnerFailsWithInvalidName (line 85) | def testCreateContinuousRunnerFailsWithInvalidName(self): FILE: tests/dopamine/discrete_domains/atari_lib_test.py class AtariLibTest (line 30) | class AtariLibTest(tf.test.TestCase): method testCreateAtariEnvironmentWithoutGameName (line 33) | def testCreateAtariEnvironmentWithoutGameName(self): method testCreateAtariEnvironment (line 39) | def testCreateAtariEnvironment(self, mock_gym_make, mock_atari_lib): class MockALE (line 57) | class MockALE(object): method __init__ (line 60) | def __init__(self): method lives (line 63) | def lives(self): method getScreenGrayscale (line 66) | def getScreenGrayscale(self, screen): # pylint: disable=invalid-name class MockEnvALEWrapper (line 70) | class MockEnvALEWrapper(object): method __init__ (line 73) | def __init__(self): class MockEnvironment (line 77) | class MockEnvironment(object): method __init__ (line 80) | def __init__(self, screen_size=10, max_steps=10): method reset (line 87) | def reset(self): method get_observation (line 92) | def get_observation(self): method step (line 96) | def step(self, action): method render (line 105) | def render(self, mode): class AtariPreprocessingTest (line 109) | class AtariPreprocessingTest(tf.test.TestCase): method testResetPassesObservation (line 111) | def testResetPassesObservation(self): method testTerminalPassedThrough (line 118) | def testTerminalPassedThrough(self): method testFrameSkipAccumulatesReward (line 132) | def testFrameSkipAccumulatesReward(self): method testMaxFramePooling (line 143) | def testMaxFramePooling(self): FILE: tests/dopamine/discrete_domains/checkpointer_test.py class CheckpointerTest (line 32) | class CheckpointerTest(tf.test.TestCase): method testCheckpointingInitialization (line 34) | def testCheckpointingInitialization(self): method testLogToFileWithValidDirectoryDefaultPrefix (line 54) | def testLogToFileWithValidDirectoryDefaultPrefix(self): method testLogToFileWithValidDirectoryCustomPrefix (line 64) | def testLogToFileWithValidDirectoryCustomPrefix(self): method testLoadLatestCheckpointWithInvalidDir (line 77) | def testLoadLatestCheckpointWithInvalidDir(self): method testLoadLatestCheckpointWithEmptyDir (line 82) | def testLoadLatestCheckpointWithEmptyDir(self): method testLoadLatestCheckpointWithOverride (line 86) | def testLoadLatestCheckpointWithOverride(self): method testLoadLatestCheckpoint (line 95) | def testLoadLatestCheckpoint(self): method testGarbageCollection (line 106) | def testGarbageCollection(self): method testGarbageCollectionWithCheckpointFrequency (line 126) | def testGarbageCollectionWithCheckpointFrequency(self): method testGarbageCollectionWithCheckpointDuration (line 158) | def testGarbageCollectionWithCheckpointDuration(self): method testGarbageCollectionWithKeepEvery (line 189) | def testGarbageCollectionWithKeepEvery(self): FILE: tests/dopamine/discrete_domains/gym_lib_test.py class MockGymEnvironment (line 22) | class MockGymEnvironment(object): method __init__ (line 25) | def __init__(self, legacy_gym_api): method reset (line 32) | def reset(self): method step (line 37) | def step(self, unused_action): class GymPreprocessingTest (line 43) | class GymPreprocessingTest(parameterized.TestCase): method testAll (line 46) | def testAll(self, use_legacy_gym): FILE: tests/dopamine/discrete_domains/iteration_statistics_test.py class IterationStatisticsTest (line 27) | class IterationStatisticsTest(tf.test.TestCase): method testMissingValue (line 29) | def testMissingValue(self): method testAddOneValue (line 34) | def testAddOneValue(self): method testAddManyValues (line 46) | def testAddManyValues(self): FILE: tests/dopamine/discrete_domains/logger_test.py class LoggerTest (line 34) | class LoggerTest(parameterized.TestCase): method setUp (line 36) | def setUp(self): method testLoggingDisabledWithEmptyDirectory (line 42) | def testLoggingDisabledWithEmptyDirectory(self): method testLoggingDisabledWithInvalidDirectory (line 46) | def testLoggingDisabledWithInvalidDirectory(self): method testLoggingEnabledWithValidDirectory (line 50) | def testLoggingEnabledWithValidDirectory(self): method testSetEntry (line 54) | def testSetEntry(self): method testLogToFileWithInvalidDirectory (line 70) | def testLogToFileWithInvalidDirectory(self): method testLogToFileWithValidDirectory (line 75) | def testLogToFileWithValidDirectory(self): method testGarbageCollectionWithDefaults (line 97) | def testGarbageCollectionWithDefaults(self, logs_duration): FILE: tests/dopamine/discrete_domains/run_experiment_test.py function _create_mock_checkpointer (line 43) | def _create_mock_checkpointer(): class MockEnvironment (line 50) | class MockEnvironment(object): method __init__ (line 53) | def __init__(self, max_steps=10): method reset (line 58) | def reset(self): method step (line 62) | def step(self, action): method render (line 72) | def render(self, mode): class MockLogger (line 76) | class MockLogger(object): method __init__ (line 79) | def __init__(self, test_cls=None, run_asserts=True, data=None): method __setitem__ (line 87) | def __setitem__(self, key, val): method log_to_file (line 94) | def log_to_file(self, filename_prefix, iteration_number): class RunExperimentTest (line 103) | class RunExperimentTest(tf.test.TestCase): method testLoadGinConfigs (line 106) | def testLoadGinConfigs(self, mock_parse_config_files_and_bindings): method testNoAgentName (line 116) | def testNoAgentName(self): method testCreateDQNAgent (line 121) | def testCreateDQNAgent(self, mock_dqn_agent): method testCreateRainbowAgent (line 137) | def testCreateRainbowAgent(self, mock_rainbow_agent): method testCreateImplicitQuantileAgent (line 153) | def testCreateImplicitQuantileAgent(self, mock_implicit_quantile_agent): method testCreateRunnerUnknown (line 168) | def testCreateRunnerUnknown(self): method testCreateRunner (line 175) | def testCreateRunner(self, mock_create_agent, mock_runner_constructor): method testCreateTrainRunner (line 185) | def testCreateTrainRunner(self, mock_create_agent, mock_runner_constru... class RunnerTest (line 194) | class RunnerTest(tf.test.TestCase): method _agent_step (line 196) | def _agent_step(self, reward, observation): method setUp (line 203) | def setUp(self): method testInitializeCheckpointingWithNoCheckpointFile (line 225) | def testInitializeCheckpointingWithNoCheckpointFile(self, mock_get_lat... method testInitializeCheckpointingWhenCheckpointUnbundleFails (line 234) | def testInitializeCheckpointingWhenCheckpointUnbundleFails( method testInitializeCheckpointingWhenCheckpointUnbundleSucceeds (line 261) | def testInitializeCheckpointingWhenCheckpointUnbundleSucceeds( method testRunOneEpisode (line 287) | def testRunOneEpisode(self): method testRunOneEpisodeWithLowMaxSteps (line 303) | def testRunOneEpisodeWithLowMaxSteps(self): method testRunOneContinuedEpisode (line 318) | def testRunOneContinuedEpisode(self): method testRunOneContinuedEpisodeWithLowMaxSteps (line 336) | def testRunOneContinuedEpisodeWithLowMaxSteps(self): method testRunOnePhase (line 353) | def testRunOnePhase(self): method testRunOneIteration (line 380) | def testRunOneIteration(self, mock_collector_dispatcher): method testLogExperiment (line 431) | def testLogExperiment(self, mock_logger_constructor): method testLogExperimentWithoutLegacyLogging (line 453) | def testLogExperimentWithoutLegacyLogging(self, mock_logger_constructor): method testCheckpointExperiment (line 477) | def testCheckpointExperiment( method testRunExperimentWithInconsistentRange (line 508) | def testRunExperimentWithInconsistentRange( method testRunExperiment (line 526) | def testRunExperiment( method testCollectorDispatcherSetup (line 574) | def testCollectorDispatcherSetup(self, mock_collector_dispatcher): FILE: tests/dopamine/jax/agents/dqn/dqn_agent_test.py class DQNAgentTest (line 37) | class DQNAgentTest(absltest.TestCase): method setUp (line 39) | def setUp(self): method _create_test_agent (line 57) | def _create_test_agent(self, allow_partial_reload=False): method testCreateAgentWithDefaults (line 104) | def testCreateAgentWithDefaults(self): method testPreprocessFnParam (line 112) | def testPreprocessFnParam(self): method testBeginEpisode (line 143) | def testBeginEpisode(self): method testStepEval (line 182) | def testStepEval(self): method testStepTrain (line 217) | def testStepTrain(self): method testNonTupleObservationShape (line 272) | def testNonTupleObservationShape(self): method _custom_shapes_test (line 277) | def _custom_shapes_test(self, shape, dtype, stack_size): method testStepTrainCustomObservationShapes (line 332) | def testStepTrainCustomObservationShapes(self): method testStepTrainCustomTypes (line 337) | def testStepTrainCustomTypes(self): method testStepTrainCustomStackSizes (line 342) | def testStepTrainCustomStackSizes(self): method testBundlingWithNonexistentDirectory (line 347) | def testBundlingWithNonexistentDirectory(self): method testUnbundlingWithFailingReplayBuffer (line 351) | def testUnbundlingWithFailingReplayBuffer(self): method testUnbundlingWithNoBundleDictionary (line 359) | def testUnbundlingWithNoBundleDictionary(self): method testPartialUnbundling (line 364) | def testPartialUnbundling(self): method testBundling (line 376) | def testBundling(self): method testLinearlyDecayingEpsilon (line 396) | def testLinearlyDecayingEpsilon(self): FILE: tests/dopamine/jax/agents/full_rainbow/full_rainbow_agent_test.py class FullRainbowAgentTest (line 30) | class FullRainbowAgentTest(absltest.TestCase): method setUp (line 32) | def setUp(self): method _create_test_agent (line 50) | def _create_test_agent(self): method testCreateAgentWithDefaults (line 111) | def testCreateAgentWithDefaults(self): method testShapesAndValues (line 119) | def testShapesAndValues(self): method testBeginEpisode (line 136) | def testBeginEpisode(self): method testStepEval (line 171) | def testStepEval(self): method testStepTrain (line 206) | def testStepTrain(self): method testStoreTransitionWithUniformSampling (line 241) | def testStoreTransitionWithUniformSampling(self): method testStoreTransitionWithPrioritizedSampling (line 256) | def testStoreTransitionWithPrioritizedSampling(self): FILE: tests/dopamine/jax/agents/implicit_quantile/implicit_quantile_agent_test.py class ImplicitQuantileAgentTest (line 34) | class ImplicitQuantileAgentTest(absltest.TestCase): method setUp (line 36) | def setUp(self): method _create_test_agent (line 49) | def _create_test_agent(self): method testCreateAgentWithDefaults (line 91) | def testCreateAgentWithDefaults(self): method testShapes (line 99) | def testShapes(self): method testQValueComputation (line 119) | def testQValueComputation(self): method testReplayQuantileValueShape (line 139) | def testReplayQuantileValueShape(self): method testBeginEpisode (line 163) | def testBeginEpisode(self): FILE: tests/dopamine/jax/agents/ppo/ppo_agent_test.py function create_agent (line 33) | def create_agent( function get_agent_params (line 52) | def get_agent_params( class PPOAgentTest (line 58) | class PPOAgentTest(parameterized.TestCase): method assertAgentParametersEqual (line 60) | def assertAgentParametersEqual( method assertAgentParametersNotEqual (line 77) | def assertAgentParametersNotEqual( method setUp (line 96) | def setUp(self): method test_integer_shaped_actions_match_shapes (line 105) | def test_integer_shaped_actions_match_shapes(self, eval_mode: bool): method test_tuple_shaped_actions_match_shapes (line 124) | def test_tuple_shaped_actions_match_shapes( method test_restore_agent_from_bundle_restores_parameters (line 138) | def test_restore_agent_from_bundle_restores_parameters(self): method test_calculate_advantages_and_returns_shapes (line 160) | def test_calculate_advantages_and_returns_shapes(self): method test_calculate_advantages_and_returns_discounting (line 213) | def test_calculate_advantages_and_returns_discounting( method test_calculate_advantages_and_returns_q_value_calculations (line 263) | def test_calculate_advantages_and_returns_q_value_calculations( method test_create_minibatches_and_shuffle_shapes (line 282) | def test_create_minibatches_and_shuffle_shapes(self): method test_create_minibatches_and_shuffle_incorrect_batch_size (line 320) | def test_create_minibatches_and_shuffle_incorrect_batch_size(self): method test_select_action_shapes (line 347) | def test_select_action_shapes(self, action_shape): FILE: tests/dopamine/jax/agents/quantile/quantile_agent_test.py class JaxQuantileAgentTest (line 34) | class JaxQuantileAgentTest(absltest.TestCase): method setUp (line 36) | def setUp(self): method _create_test_agent (line 52) | def _create_test_agent(self): method testCreateAgentWithDefaults (line 109) | def testCreateAgentWithDefaults(self): method testShapesAndValues (line 117) | def testShapesAndValues(self): method testBeginEpisode (line 139) | def testBeginEpisode(self): method testStepEval (line 174) | def testStepEval(self): method testStepTrain (line 209) | def testStepTrain(self): FILE: tests/dopamine/jax/agents/rainbow/rainbow_agent_test.py class ProjectDistributionTest (line 35) | class ProjectDistributionTest(absltest.TestCase): method _vmapped_projection (line 37) | def _vmapped_projection(self, supports, weights, target_support): method testProjectSingleIdenticalDistribution (line 42) | def testProjectSingleIdenticalDistribution(self): method testProjectSingleDifferentDistribution (line 50) | def testProjectSingleDifferentDistribution(self): method testProjectFromNonMonotonicSupport (line 58) | def testProjectFromNonMonotonicSupport(self): method testExampleFromCodeComments (line 66) | def testExampleFromCodeComments(self): method testProjectBatchOfDifferentDistributions (line 80) | def testProjectBatchOfDifferentDistributions(self): method testUsingPlaceholders (line 101) | def testUsingPlaceholders(self): method testProjectBatchOfDifferentDistributionsWithLargerDelta (line 117) | def testProjectBatchOfDifferentDistributionsWithLargerDelta(self): class RainbowAgentTest (line 134) | class RainbowAgentTest(absltest.TestCase): method setUp (line 136) | def setUp(self): method _create_test_agent (line 153) | def _create_test_agent(self): method testCreateAgentWithDefaults (line 210) | def testCreateAgentWithDefaults(self): method testShapesAndValues (line 218) | def testShapesAndValues(self): method testBeginEpisode (line 235) | def testBeginEpisode(self): method testStepEval (line 270) | def testStepEval(self): method testStepTrain (line 305) | def testStepTrain(self): method testStoreTransitionWithUniformSampling (line 340) | def testStoreTransitionWithUniformSampling(self): method testStoreTransitionWithPrioritizedSamplingy (line 355) | def testStoreTransitionWithPrioritizedSamplingy(self): FILE: tests/dopamine/jax/agents/sac/sac_agent_test.py function get_mock_batch (line 40) | def get_mock_batch( function create_agent (line 77) | def create_agent( function get_agent_params (line 100) | def get_agent_params( class SacAgentTest (line 106) | class SacAgentTest(parameterized.TestCase): method assertAgentParametersEqual (line 108) | def assertAgentParametersEqual( method assertAgentParametersNotEqual (line 125) | def assertAgentParametersNotEqual( method setUp (line 144) | def setUp(self): method testIntegerShapedActionsMatchShapes (line 154) | def testIntegerShapedActionsMatchShapes(self, eval_mode: bool): method testTupleShapedActionsMatchShapes (line 173) | def testTupleShapedActionsMatchShapes( method testAgentParametersUpdateWhenTrained (line 191) | def testAgentParametersUpdateWhenTrained(self, action_shape: Tuple[int... method testAgentParametersNotUpdatedDuringEval (line 215) | def testAgentParametersNotUpdatedDuringEval(self): method testRestoreAgentFromBundleRestoresParameters (line 238) | def testRestoreAgentFromBundleRestoresParameters(self): method testAgentTrainsWithImageObservations (line 260) | def testAgentTrainsWithImageObservations(self): FILE: tests/dopamine/jax/checkpointers_test.py class CheckpointTest (line 27) | class CheckpointTest(parameterized.TestCase): method setUp (line 29) | def setUp(self): method make_checkpoint_manager (line 33) | def make_checkpoint_manager(self) -> orbax.CheckpointManager: method testInvalidCheckpointProtocol (line 39) | def testInvalidCheckpointProtocol(self): method testCheckpointable (line 75) | def testCheckpointable(self, data: Any): FILE: tests/dopamine/jax/continuous_networks_test.py class ActorNetworkTest (line 26) | class ActorNetworkTest(parameterized.TestCase): method setUp (line 28) | def setUp(self): method test_actor_network_outputs_correct_shaped_values (line 34) | def test_actor_network_outputs_correct_shaped_values(self): method test_network_has_specified_number_of_layers (line 46) | def test_network_has_specified_number_of_layers(self): method test_network_activation_initializer (line 61) | def test_network_activation_initializer(self, activation): method test_network_kernel_initializer (line 72) | def test_network_kernel_initializer(self, kernel_initializer): class CriticNetworkTest (line 79) | class CriticNetworkTest(parameterized.TestCase): method setUp (line 81) | def setUp(self): method test_critic_network_outputs_single_value (line 88) | def test_critic_network_outputs_single_value(self): method test_network_has_specified_number_of_layers (line 96) | def test_network_has_specified_number_of_layers(self): method test_network_activation_initializer (line 110) | def test_network_activation_initializer(self, activation): method test_network_kernel_initializer (line 119) | def test_network_kernel_initializer(self, kernel_initializer): class ActorCriticNetworkTest (line 126) | class ActorCriticNetworkTest(absltest.TestCase): method setUp (line 128) | def setUp(self): method assert_actor_shapes_are_correct (line 142) | def assert_actor_shapes_are_correct( method assert_critic_shapes_are_correct (line 152) | def assert_critic_shapes_are_correct( method test_actor_critic_network_call_outputs_correct_shaped_values (line 160) | def test_actor_critic_network_call_outputs_correct_shaped_values(self): method test_actor_critic_network_actor_outputs_correct_shaped_values (line 167) | def test_actor_critic_network_actor_outputs_correct_shaped_values(self): method test_actor_critic_network_critic_outputs_correct_shaped_values (line 175) | def test_actor_critic_network_critic_outputs_correct_shaped_values(self): method test_network_has_specified_number_of_layers (line 186) | def test_network_has_specified_number_of_layers(self): class PPOActorNetworkTest (line 208) | class PPOActorNetworkTest(parameterized.TestCase): method setUp (line 210) | def setUp(self): method test_ppo_actor_network_outputs_correct_shaped_values (line 216) | def test_ppo_actor_network_outputs_correct_shaped_values(self): method test_ppo_network_has_specified_number_of_layers (line 228) | def test_ppo_network_has_specified_number_of_layers(self): method test_ppo_network_activation_initializer (line 244) | def test_ppo_network_activation_initializer(self, activation): method test_scale_diag_set_to_zero (line 250) | def test_scale_diag_set_to_zero(self): class PPOCriticNetworkTest (line 263) | class PPOCriticNetworkTest(parameterized.TestCase): method setUp (line 265) | def setUp(self): method test_ppo_critic_network_outputs_single_value (line 271) | def test_ppo_critic_network_outputs_single_value(self): method test_ppo_network_has_specified_number_of_layers (line 279) | def test_ppo_network_has_specified_number_of_layers(self): method test_ppo_network_activation_initializer (line 293) | def test_ppo_network_activation_initializer(self, activation): class PPOActorCriticNetworkTest (line 298) | class PPOActorCriticNetworkTest(absltest.TestCase): method setUp (line 300) | def setUp(self): method assert_ppo_actor_shapes_are_correct (line 313) | def assert_ppo_actor_shapes_are_correct( method assert_ppo_critic_shapes_are_correct (line 323) | def assert_ppo_critic_shapes_are_correct( method test_ppo_actor_critic_network_call_outputs_correct_shaped_values (line 329) | def test_ppo_actor_critic_network_call_outputs_correct_shaped_values(s... method test_ppo_actor_critic_network_actor_outputs_correct_shaped_values (line 336) | def test_ppo_actor_critic_network_actor_outputs_correct_shaped_values(... method test_ppo_actor_critic_network_critic_outputs_correct_shaped_values (line 344) | def test_ppo_actor_critic_network_critic_outputs_correct_shaped_values... method test_network_has_specified_number_of_layers (line 354) | def test_network_has_specified_number_of_layers(self): class ContinousNetworksHelperTest (line 370) | class ContinousNetworksHelperTest(parameterized.TestCase): method test_create_activation (line 373) | def test_create_activation(self, name, activation): method test_create_activation_raises_error (line 377) | def test_create_activation_raises_error(self): FILE: tests/dopamine/jax/losses_test.py class LossesTest (line 25) | class LossesTest(parameterized.TestCase): method testHuberLoss (line 57) | def testHuberLoss( method testMSELoss (line 84) | def testMSELoss( FILE: tests/dopamine/jax/networks_test.py class NetworksTest (line 25) | class NetworksTest(parameterized.TestCase): method setUp (line 27) | def setUp(self): method testOutputShape (line 51) | def testOutputShape(self, network: nn.Module): method testPPOOutputShape (line 80) | def testPPOOutputShape(self): FILE: tests/dopamine/jax/replay_memory/accumulator_test.py class TransitionAccumulatorTest (line 26) | class TransitionAccumulatorTest(parameterized.TestCase): method _verify_accumulator_transition (line 28) | def _verify_accumulator_transition( method _add (line 48) | def _add( method setUp (line 55) | def setUp(self): method testInitializer (line 65) | def testInitializer(self): method testReset (line 71) | def testReset(self): method testOneElementTrajectoriesAreInvalid (line 84) | def testOneElementTrajectoriesAreInvalid(self): method testAddSingleTrajectory (line 99) | def testAddSingleTrajectory(self, n: int, valid: bool): method testAccumulate (line 136) | def testAccumulate(self, n): method testAccumulateWithInvalidFirstTrajectory (line 151) | def testAccumulateWithInvalidFirstTrajectory(self): method testClear (line 196) | def testClear(self): FILE: tests/dopamine/jax/replay_memory/elements_test.py class ElementsTest (line 26) | class ElementsTest(parameterized.TestCase): method test_pack_unpack (line 28) | def test_pack_unpack(self) -> None: FILE: tests/dopamine/jax/replay_memory/replay_buffer_regression_test.py class ReplayBufferRegressionTest (line 37) | class ReplayBufferRegressionTest(parameterized.TestCase): method setUp (line 39) | def setUp(self): method testNSteprewards (line 54) | def testNSteprewards(self): method testGetStack (line 74) | def testGetStack(self): method testSampleTransitionBatch (line 107) | def testSampleTransitionBatch(self): method testSamplingWithTerminalInTrajectory (line 183) | def testSamplingWithTerminalInTrajectory(self): FILE: tests/dopamine/jax/replay_memory/replay_buffer_test.py class ReplayBufferTest (line 36) | class ReplayBufferTest(parameterized.TestCase): method setUp (line 38) | def setUp(self): method testWithInvalidCheckpointDuration (line 53) | def testWithInvalidCheckpointDuration(self, cd): method testCreateReplayBuffer (line 63) | def testCreateReplayBuffer(self): method testAddWithEmptyAccumulatorReturn (line 73) | def testAddWithEmptyAccumulatorReturn(self, mock_accumulator): method testAddWithoutCompress (line 87) | def testAddWithoutCompress(self, mock_accumulator): method testAddWithValidAccumulatorReturn (line 102) | def testAddWithValidAccumulatorReturn(self, mock_accumulator): method testAddUpToCapacity (line 128) | def testAddUpToCapacity(self, mock_accumulator): method testSampleWithNoElements (line 163) | def testSampleWithNoElements(self): method testSample (line 196) | def testSample(self, mock_accumulator, use_default_bs, compress): method testSave (line 256) | def testSave(self, compress): method testGarbageCollection (line 297) | def testGarbageCollection(self, cd): method testLoad (line 317) | def testLoad(self): method testSaveAndLoad (line 345) | def testSaveAndLoad(self): method testKeyMappingsForSampling (line 418) | def testKeyMappingsForSampling(self): method testClearBuffer (line 497) | def testClearBuffer(self): FILE: tests/dopamine/jax/replay_memory/samplers_test.py class UniformSamplingTest (line 23) | class UniformSamplingTest(parameterized.TestCase): method setUp (line 25) | def setUp(self): method test_update_does_not_raise_and_logs (line 29) | def test_update_does_not_raise_and_logs(self): method test_additional_kwargs_to_add_logs (line 35) | def test_additional_kwargs_to_add_logs(self): method test_sample_when_empty (line 41) | def test_sample_when_empty(self): method test_removal_of_invalid_key_raises (line 45) | def test_removal_of_invalid_key_raises(self): method test_invalid_sample_size_raises (line 50) | def test_invalid_sample_size_raises(self, size: int): method test_add_and_remove (line 54) | def test_add_and_remove(self): method test_sample (line 58) | def test_sample(self): method test_serializes (line 65) | def test_serializes(self): method test_clear_sampler (line 77) | def test_clear_sampler(self): class PrioritizedSamplingTest (line 86) | class PrioritizedSamplingTest(parameterized.TestCase): method setUp (line 88) | def setUp(self): method test_priorities_can_be_updated (line 94) | def test_priorities_can_be_updated(self): method test_priorities_can_be_removed (line 115) | def test_priorities_can_be_removed(self): method test_zero_priorities_is_uniform_sampling (line 129) | def test_zero_priorities_is_uniform_sampling(self): method test_positive_priorities_computes_probabilities (line 136) | def test_positive_priorities_computes_probabilities(self): method test_removal_wont_sample_removed_index (line 144) | def test_removal_wont_sample_removed_index(self, index_to_remove: int): method test_clear_sampler (line 152) | def test_clear_sampler(self): class SequentialSamplingTest (line 165) | class SequentialSamplingTest(parameterized.TestCase): method setUp (line 167) | def setUp(self): method test_sample_when_empty (line 171) | def test_sample_when_empty(self): method test_sample_when_not_empty (line 175) | def test_sample_when_not_empty(self): method test_order_is_sequential_after_add (line 185) | def test_order_is_sequential_after_add(self): method test_order_is_sequential_after_remove (line 197) | def test_order_is_sequential_after_remove(self): method test_order_with_unsorted_add (line 209) | def test_order_with_unsorted_add(self): method test_clear_sampler (line 220) | def test_clear_sampler(self): method test_unsorted_add_with_sort_samples_false (line 229) | def test_unsorted_add_with_sort_samples_false(self): FILE: tests/dopamine/jax/replay_memory/sum_tree_test.py class SumTreeTest (line 23) | class SumTreeTest(parameterized.TestCase): method setUp (line 25) | def setUp(self): method test_negative_capacity_raises (line 29) | def test_negative_capacity_raises(self): method test_negative_value_raises (line 33) | def test_negative_value_raises(self): method test_set_small_capacity (line 37) | def test_set_small_capacity(self): method test_set_and_get_value (line 42) | def test_set_and_get_value(self): method test_set_and_get_values_vectorized (line 52) | def test_set_and_get_values_vectorized(self): method test_set_with_duplicates (line 61) | def test_set_with_duplicates(self): method test_capacity_greater_than_requested (line 70) | def test_capacity_greater_than_requested(self): method test_query_empty_tree (line 73) | def test_query_empty_tree(self): method test_query_value (line 77) | def test_query_value(self): method test_query_values_vectorized (line 81) | def test_query_values_vectorized(self): method test_update_sum_values (line 102) | def test_update_sum_values(self): method test_query_values_vectorized_large_tree (line 121) | def test_query_values_vectorized_large_tree(self): method test_serialization (line 143) | def test_serialization(self): method test_clear (line 152) | def test_clear(self): method test_max_recorded_priority (line 158) | def test_max_recorded_priority(self): FILE: tests/dopamine/jax/serialization_test.py class SerializationTest (line 25) | class SerializationTest(parameterized.TestCase): method testEncodeNumpy (line 45) | def testEncodeNumpy(self, array: Union[np.ndarray, np.bool_, np.number]): method testEncodeLongIntegers (line 58) | def testEncodeLongIntegers(self, integer: int): FILE: tests/dopamine/labs/atari_100k/train_test.py function QuickAgentFlags (line 30) | def QuickAgentFlags(): function SetAgentConfig (line 44) | def SetAgentConfig(agent_name='der'): class RunnerIntegrationTest (line 52) | class RunnerIntegrationTest(parameterized.TestCase): method setUp (line 57) | def setUp(self): method VerifyFilesCreated (line 64) | def VerifyFilesCreated(self): method testIntegration (line 79) | def testIntegration(self, agent_name): FILE: tests/dopamine/labs/moes/architectures/networks_test.py class NetworksTest (line 26) | class NetworksTest(parameterized.TestCase): method setUp (line 28) | def setUp(self): method testExpertModel (line 41) | def testExpertModel(self, maintain_token_size): method testImpalaMoEWithInvalidMoEType (line 57) | def testImpalaMoEWithInvalidMoEType(self): method testImpalaMoEWithInvalidRoutingType (line 64) | def testImpalaMoEWithInvalidRoutingType(self): method testImpalaMoEBaseline (line 71) | def testImpalaMoEBaseline(self): method _create_network_and_apply (line 82) | def _create_network_and_apply(self, network_class, moe_type, state, method _create_impala_network_and_apply (line 97) | def _create_impala_network_and_apply(self, moe_type, routing_type): method _test_network_outputs (line 106) | def _test_network_outputs( method testImpalaMoE (line 153) | def testImpalaMoE(self, routing_type, expected_moe_output_shape): method testImpalaSoftMoE (line 184) | def testImpalaSoftMoE(self, routing_type, expected_moe_output_shape): method testNatureDQNMoEWithInvalidMoEType (line 197) | def testNatureDQNMoEWithInvalidMoEType(self): method testNatureDQNMoEWithInvalidRoutingType (line 204) | def testNatureDQNMoEWithInvalidRoutingType(self): method testNatureDQNMoEBaseline (line 211) | def testNatureDQNMoEBaseline(self): method _create_nature_network_and_apply (line 222) | def _create_nature_network_and_apply(self, moe_type, routing_type): method testNatureDQNMoE (line 253) | def testNatureDQNMoE(self, routing_type, expected_moe_output_shape): method testNatureDQNSoftMoE (line 284) | def testNatureDQNSoftMoE(self, routing_type, expected_moe_output_shape): method _create_full_rainbow_network_and_apply (line 294) | def _create_full_rainbow_network_and_apply(self, moe_type, routing_type): method testFullRainbowMoE (line 329) | def testFullRainbowMoE(self, routing_type, expected_moe_output_shape): method testFullRainbowSoftMoE (line 359) | def testFullRainbowSoftMoE(self, routing_type, expected_moe_output_sha... FILE: tests/dopamine/labs/offline_rl/jax/offline_agent_test.py class OfflineAgentTest (line 35) | class OfflineAgentTest(parameterized.TestCase): method setUp (line 37) | def setUp(self): method _create_agent_fn (line 71) | def _create_agent_fn(self, agent_name): method _test_train_step_updates_weights (line 86) | def _test_train_step_updates_weights(self, agent_name): method test_train_step_updates_weights (line 112) | def test_train_step_updates_weights(self, agent_name): method test_target_type (line 116) | def test_target_type(self, target_type): method test_hl_loss_type (line 123) | def test_hl_loss_type(self, hl_loss_type): method test_impala (line 127) | def test_impala(self): FILE: tests/dopamine/labs/sac_from_pixels/continuous_networks_test.py class SACEncoderNetworkTest (line 23) | class SACEncoderNetworkTest(absltest.TestCase): method setUp (line 25) | def setUp(self): method test_network_outputs_correct_shapes (line 31) | def test_network_outputs_correct_shapes(self): method test_network_correctly_handles_stacked_rgb_frames (line 43) | def test_network_correctly_handles_stacked_rgb_frames(self): method test_actor_cant_update_conv_weights (line 57) | def test_actor_cant_update_conv_weights(self): method test_critic_can_update_conv_weights (line 77) | def test_critic_can_update_conv_weights(self): FILE: tests/dopamine/labs/sac_from_pixels/deepmind_control_lib_test.py function get_mock_render (line 44) | def get_mock_render( # pytype: disable=annotation-type-mismatch # nump... class MockDeepmindControlSuiteEnvironment (line 55) | class MockDeepmindControlSuiteEnvironment(control.Environment): method __init__ (line 57) | def __init__( method action_spec (line 67) | def action_spec(self) -> specs.BoundedArray: method observation_spec (line 75) | def observation_spec(self) -> Mapping[str, specs.BoundedArray]: method reset (line 88) | def reset(self) -> dm_env.TimeStep: method step (line 99) | def step(self, action: np.ndarray) -> dm_env.TimeStep: class DeepmindControlLibTest (line 122) | class DeepmindControlLibTest(parameterized.TestCase): method setUp (line 124) | def setUp(self): method create_env (line 134) | def create_env(self) -> control.Environment: method test_preprocessing (line 142) | def test_preprocessing(self): method test_image_wrapper_returns_image_observation_on_reset (line 174) | def test_image_wrapper_returns_image_observation_on_reset(self): method test_image_wrapper_returns_image_observation_on_step (line 185) | def test_image_wrapper_returns_image_observation_on_step(self): method test_float_image_observation_is_scaled_correctly_to_uint (line 201) | def test_float_image_observation_is_scaled_correctly_to_uint( method test_action_repeats_successfully_apply_repeated_action (line 216) | def test_action_repeats_successfully_apply_repeated_action(self): method test_action_repeats_obey_end_of_episode (line 238) | def test_action_repeats_obey_end_of_episode(self): FILE: tests/dopamine/metrics/collector_dispatcher_test.py class CollectorDispatcherTest (line 25) | class CollectorDispatcherTest(parameterized.TestCase): method setUp (line 27) | def setUp(self): method test_with_no_collectors (line 31) | def test_with_no_collectors(self): method test_with_default_collectors (line 43) | def test_with_default_collectors(self): method test_with_simple_collector (line 57) | def test_with_simple_collector(self, allowlist): FILE: tests/dopamine/metrics/collector_test.py class SimpleCollector (line 25) | class SimpleCollector(collector.Collector): method get_name (line 27) | def get_name(self) -> str: method write (line 30) | def write(self, unused_statistics) -> None: method flush (line 33) | def flush(self, unused_statistics) -> None: method close (line 36) | def close(self) -> None: class CollectorTest (line 40) | class CollectorTest(absltest.TestCase): method setUp (line 42) | def setUp(self): method test_instantiate_abstract_class (line 46) | def test_instantiate_abstract_class(self): method test_valid_subclass (line 51) | def test_valid_subclass(self): method test_valid_subclass_with_no_basedir (line 58) | def test_valid_subclass_with_no_basedir(self): FILE: tests/dopamine/metrics/console_collector_test.py class ConsoleCollectorTest (line 27) | class ConsoleCollectorTest(absltest.TestCase): method setUp (line 29) | def setUp(self): method test_valid_creation (line 34) | def test_valid_creation(self): method test_valid_creation_no_base_dir (line 47) | def test_valid_creation_no_base_dir(self): method test_valid_creation_no_save_to_file (line 54) | def test_valid_creation_no_save_to_file(self): method test_step_with_fine_grained_logging (line 64) | def test_step_with_fine_grained_logging(self): method test_no_write_with_unsupported_type (line 79) | def test_no_write_with_unsupported_type(self): method test_full_run (line 95) | def test_full_run(self): FILE: tests/dopamine/metrics/pickle_collector_test.py class PickleCollectorTest (line 27) | class PickleCollectorTest(absltest.TestCase): method setUp (line 29) | def setUp(self): method test_with_none_base_dir (line 33) | def test_with_none_base_dir(self): method test_valid_creation (line 37) | def test_valid_creation(self): method test_write (line 44) | def test_write(self): method test_no_write_with_unsupported_type (line 56) | def test_no_write_with_unsupported_type(self): method test_flush (line 67) | def test_flush(self): method test_full_run (line 77) | def test_full_run(self): FILE: tests/dopamine/metrics/tensorboard_collector_test.py class TensorboardCollectorTest (line 26) | class TensorboardCollectorTest(absltest.TestCase): method test_with_invalid_base_dir_raises_value_error (line 28) | def test_with_invalid_base_dir_raises_value_error(self): method test_valid_creation_with_all_required_parameters (line 32) | def test_valid_creation_with_all_required_parameters(self): method test_write (line 45) | def test_write(self): method test_no_write_with_unsupported_type (line 65) | def test_no_write_with_unsupported_type(self): method test_full_run (line 79) | def test_full_run(self): FILE: tests/dopamine/tests/gin_config_test.py class GinConfigTest (line 33) | class GinConfigTest(tf.test.TestCase): method setUp (line 38) | def setUp(self): method testDefaultGinDqn (line 49) | def testDefaultGinDqn(self): method testOverrideRunnerParams (line 69) | def testOverrideRunnerParams(self): method testDefaultGinRmspropDqn (line 89) | def testDefaultGinRmspropDqn(self): method testOverrideGinDqn (line 109) | def testOverrideGinDqn(self): method testDefaultGinRainbow (line 130) | def testDefaultGinRainbow(self): method testOverrideGinRainbow (line 152) | def testOverrideGinRainbow(self): method testDefaultDQNConfig (line 174) | def testDefaultDQNConfig(self): method testDefaultC51Config (line 193) | def testDefaultC51Config(self): method testDefaultRainbowConfig (line 217) | def testDefaultRainbowConfig(self): method testDefaultGinImplicitQuantile (line 241) | def testDefaultGinImplicitQuantile(self): method testOverrideGinImplicitQuantile (line 265) | def testOverrideGinImplicitQuantile(self): FILE: tests/dopamine/tests/integration_test.py class AtariIntegrationTest (line 31) | class AtariIntegrationTest(tf.test.TestCase): method setUp (line 36) | def setUp(self): method quick_dqn_flags (line 47) | def quick_dqn_flags(self): method quick_rainbow_flags (line 59) | def quick_rainbow_flags(self): method verify_files_created (line 74) | def verify_files_created(self): method testIntegrationDqn (line 88) | def testIntegrationDqn(self): method testIntegrationRainbow (line 97) | def testIntegrationRainbow(self): FILE: tests/dopamine/tests/train_runner_integration_test.py class TrainRunnerIntegrationTest (line 29) | class TrainRunnerIntegrationTest(tf.test.TestCase): method setUp (line 34) | def setUp(self): method quick_dqn_flags (line 43) | def quick_dqn_flags(self): method verify_files_created (line 57) | def verify_files_created(self): method testIntegrationDqn (line 71) | def testIntegrationDqn(self): FILE: tests/dopamine/tf/agents/dqn/dqn_agent_test.py class DQNAgentTest (line 37) | class DQNAgentTest(tf.test.TestCase): method setUp (line 39) | def setUp(self): method _create_test_agent (line 58) | def _create_test_agent(self, sess, allow_partial_reload=False): method testCreateAgentWithDefaults (line 105) | def testCreateAgentWithDefaults(self): method testBeginEpisode (line 115) | def testBeginEpisode(self): method testStepEval (line 151) | def testStepEval(self): method testStepTrain (line 188) | def testStepTrain(self): method testNonTupleObservationShape (line 241) | def testNonTupleObservationShape(self): method _test_custom_shapes (line 247) | def _test_custom_shapes(self, shape, dtype, stack_size): method testStepTrainCustomObservationShapes (line 300) | def testStepTrainCustomObservationShapes(self): method testStepTrainCustomTypes (line 305) | def testStepTrainCustomTypes(self): method testStepTrainCustomStackSizes (line 310) | def testStepTrainCustomStackSizes(self): method testLinearlyDecayingEpsilon (line 315) | def testLinearlyDecayingEpsilon(self): method testBundlingWithNonexistentDirectory (line 334) | def testBundlingWithNonexistentDirectory(self): method testUnbundlingWithFailingReplayBuffer (line 339) | def testUnbundlingWithFailingReplayBuffer(self): method testUnbundlingWithNoBundleDictionary (line 348) | def testUnbundlingWithNoBundleDictionary(self): method testPartialUnbundling (line 354) | def testPartialUnbundling(self): method testBundling (line 367) | def testBundling(self): method testSyncOpWithNameScopes (line 386) | def testSyncOpWithNameScopes(self): FILE: tests/dopamine/tf/agents/implicit_quantile/implicit_quantile_agent_test.py class ImplicitQuantileAgentTest (line 30) | class ImplicitQuantileAgentTest(tf.test.TestCase): method setUp (line 32) | def setUp(self): method _create_test_agent (line 42) | def _create_test_agent(self, sess): method testCreateAgentWithDefaults (line 90) | def testCreateAgentWithDefaults(self): method testShapes (line 100) | def testShapes(self): method test_q_value_computation (line 145) | def test_q_value_computation(self): method test_replay_quantile_value_computation (line 170) | def test_replay_quantile_value_computation(self): FILE: tests/dopamine/tf/agents/rainbow/rainbow_agent_test.py class ProjectDistributionTest (line 31) | class ProjectDistributionTest(tf.test.TestCase): method testInconsistentSupportsAndWeightsParameters (line 33) | def testInconsistentSupportsAndWeightsParameters(self): method testInconsistentSupportsAndWeightsWithPlaceholders (line 42) | def testInconsistentSupportsAndWeightsWithPlaceholders(self): method testInconsistentSupportsAndTargetSupportParameters (line 66) | def testInconsistentSupportsAndTargetSupportParameters(self): method testInconsistentSupportsAndTargetSupportWithPlaceholders (line 75) | def testInconsistentSupportsAndTargetSupportWithPlaceholders(self): method testZeroDimensionalTargetSupport (line 99) | def testZeroDimensionalTargetSupport(self): method testZeroDimensionalTargetSupportWithPlaceholders (line 108) | def testZeroDimensionalTargetSupportWithPlaceholders(self): method testMultiDimensionalTargetSupport (line 130) | def testMultiDimensionalTargetSupport(self): method testMultiDimensionalTargetSupportWithPlaceholders (line 139) | def testMultiDimensionalTargetSupportWithPlaceholders(self): method testProjectWithNonMonotonicTargetSupport (line 161) | def testProjectWithNonMonotonicTargetSupport(self): method testProjectNewSupportHasInconsistentDeltask (line 177) | def testProjectNewSupportHasInconsistentDeltask(self): method testProjectSingleIdenticalDistribution (line 193) | def testProjectSingleIdenticalDistribution(self): method testProjectSingleDifferentDistribution (line 206) | def testProjectSingleDifferentDistribution(self): method testProjectFromNonMonotonicSupport (line 219) | def testProjectFromNonMonotonicSupport(self): method testExampleFromCodeComments (line 232) | def testExampleFromCodeComments(self): method testProjectBatchOfDifferentDistributions (line 250) | def testProjectBatchOfDifferentDistributions(self): method testUsingPlaceholders (line 276) | def testUsingPlaceholders(self): method testProjectBatchOfDifferentDistributionsWithLargerDelta (line 307) | def testProjectBatchOfDifferentDistributionsWithLargerDelta(self): class RainbowAgentTest (line 328) | class RainbowAgentTest(tf.test.TestCase): method setUp (line 330) | def setUp(self): method _create_test_agent (line 344) | def _create_test_agent(self, sess): method testCreateAgentWithDefaults (line 405) | def testCreateAgentWithDefaults(self): method testShapesAndValues (line 415) | def testShapesAndValues(self): method testBeginEpisode (line 448) | def testBeginEpisode(self): method testStepEval (line 484) | def testStepEval(self): method testStepTrain (line 521) | def testStepTrain(self): method testStoreTransitionWithUniformSampling (line 561) | def testStoreTransitionWithUniformSampling(self): method testStoreTransitionWithPrioritizedSamplingy (line 577) | def testStoreTransitionWithPrioritizedSamplingy(self): FILE: tests/dopamine/tf/replay_memory/circular_replay_buffer_test.py class CheckpointableClass (line 41) | class CheckpointableClass(object): method __init__ (line 43) | def __init__(self): class OutOfGraphReplayBufferTest (line 47) | class OutOfGraphReplayBufferTest(tf.test.TestCase): method setUp (line 49) | def setUp(self): method testWithNontupleObservationShape (line 62) | def testWithNontupleObservationShape(self): method testConstructor (line 71) | def testConstructor(self): method testAdd (line 98) | def testAdd(self): method testExtraAdd (line 111) | def testExtraAdd(self): method testCheckAddTypes (line 131) | def testCheckAddTypes(self): method testLowCapacity (line 149) | def testLowCapacity(self): method testGetRangeInvalidIndexOrder (line 181) | def testGetRangeInvalidIndexOrder(self): method testGetRangeNoWraparound (line 204) | def testGetRangeNoWraparound(self): method testGetRangeWithWraparound (line 232) | def testGetRangeWithWraparound(self): method testNSteprewardum (line 263) | def testNSteprewardum(self): method testGetStack (line 281) | def testGetStack(self): method testSampleTransitionBatch (line 310) | def testSampleTransitionBatch(self): method testSampleTransitionBatchExtra (line 373) | def testSampleTransitionBatchExtra(self): method testSamplingWithterminalInTrajectory (line 445) | def testSamplingWithterminalInTrajectory(self): method testInvalidRange (line 487) | def testInvalidRange(self): method testIsTransitionValid (line 515) | def testIsTransitionValid(self): method testSave (line 538) | def testSave(self): method testEpisodeEndIndicesAreCorrectlySaved (line 571) | def testEpisodeEndIndicesAreCorrectlySaved(self): method testEpisodeEndIndicesAreCorrectlyLoaded (line 596) | def testEpisodeEndIndicesAreCorrectlyLoaded(self): method testSaveWithKeepEvery (line 626) | def testSaveWithKeepEvery(self): method testSaveNonNDArrayAttributes (line 658) | def testSaveNonNDArrayAttributes(self): method testLoadFromNonexistentDirectory (line 694) | def testLoadFromNonexistentDirectory(self): method testPartialLoadFails (line 718) | def testPartialLoadFails(self): method testLoad (line 764) | def testLoad(self): class WrappedReplayBufferTest (line 806) | class WrappedReplayBufferTest(tf.test.TestCase): method setUp (line 808) | def setUp(self): method testConstructorCapacityNotLargeEnough (line 821) | def testConstructorCapacityNotLargeEnough(self): method testConstructorWithZeroUpdateHorizon (line 834) | def testConstructorWithZeroUpdateHorizon(self): method testConstructorWithOutOfBoundsDiscountFactor (line 844) | def testConstructorWithOutOfBoundsDiscountFactor(self): method testConstructorWithExtraStorageTypes (line 855) | def testConstructorWithExtraStorageTypes(self): method _verify_sampled_trajectories (line 865) | def _verify_sampled_trajectories(self, batch): method testConstructorWithNoStaging (line 899) | def testConstructorWithNoStaging(self): method testConstructorWithStaging (line 913) | def testConstructorWithStaging(self): method testWrapperSave (line 927) | def testWrapperSave(self): method testWrapperLoad (line 947) | def testWrapperLoad(self): method testDefaultObsDataType (line 998) | def testDefaultObsDataType(self): method testCustomObsDataType (line 1008) | def testCustomObsDataType(self): FILE: tests/dopamine/tf/replay_memory/prioritized_replay_buffer_test.py class OutOfGraphPrioritizedReplayBufferTest (line 35) | class OutOfGraphPrioritizedReplayBufferTest(tf.test.TestCase): method create_default_memory (line 37) | def create_default_memory(self, extra_storage_types=None): method add_blank (line 47) | def add_blank(self, memory, action=0, reward=0.0, terminal=0, priority... method testAddWithAndWithoutPriority (line 66) | def testAddWithAndWithoutPriority(self): method testAddWithAdditionalArgsAndPriority (line 80) | def testAddWithAdditionalArgsAndPriority(self): method testDummyScreensAddedToNewMemory (line 98) | def testDummyScreensAddedToNewMemory(self): method testGetPriorityWithInvalidIndices (line 104) | def testGetPriorityWithInvalidIndices(self): method testSetAndGetPriority (line 114) | def testSetAndGetPriority(self): method testNewElementHasHighPriority (line 128) | def testNewElementHasHighPriority(self): method testLowPriorityElementNotFrequentlySampled (line 135) | def testLowPriorityElementNotFrequentlySampled(self): method testSampleIndexBatchTooManyFailedRetries (line 150) | def testSampleIndexBatchTooManyFailedRetries(self): method testSampleIndexBatch (line 166) | def testSampleIndexBatch(self): class WrappedPrioritizedReplayBufferTest (line 187) | class WrappedPrioritizedReplayBufferTest(tf.test.TestCase): method create_default_memory (line 190) | def create_default_memory(self): method add_blank (line 200) | def add_blank(self, replay): method testSetAndGetPriority (line 203) | def testSetAndGetPriority(self): method testSampleBatch (line 221) | def testSampleBatch(self): method testConstructorWithExtraStorageTypes (line 232) | def testConstructorWithExtraStorageTypes(self): FILE: tests/dopamine/tf/replay_memory/sum_tree_test.py class SumTreeTest (line 28) | class SumTreeTest(tf.test.TestCase): method setUp (line 30) | def setUp(self): method testNegativeCapacity (line 34) | def testNegativeCapacity(self): method testSetNegativeValue (line 40) | def testSetNegativeValue(self): method testSmallCapacityConstructor (line 46) | def testSmallCapacityConstructor(self): method testSetValueSmallCapacity (line 52) | def testSetValueSmallCapacity(self): method testSetValue (line 57) | def testSetValue(self): method testCapacityGreaterThanRequested (line 68) | def testCapacityGreaterThanRequested(self): method testSampleFromEmptyTree (line 71) | def testSampleFromEmptyTree(self): method testSampleWithInvalidQueryValue (line 77) | def testSampleWithInvalidQueryValue(self): method testSampleSingleton (line 84) | def testSampleSingleton(self): method testSamplePairWithUnevenProbabilities (line 90) | def testSamplePairWithUnevenProbabilities(self): method testSamplePairWithUnevenProbabilitiesWithQueryValue (line 98) | def testSamplePairWithUnevenProbabilitiesWithQueryValue(self): method testSamplingWithSeedDoesNotAffectFutureCalls (line 105) | def testSamplingWithSeedDoesNotAffectFutureCalls(self): method testStratifiedSamplingFromEmptyTree (line 138) | def testStratifiedSamplingFromEmptyTree(self): method testStratifiedSampling (line 144) | def testStratifiedSampling(self): method testMaxRecordedProbability (line 153) | def testMaxRecordedProbability(self): FILE: tests/dopamine/utils/agent_visualizer_test.py class AgentVisualizerTest (line 36) | class AgentVisualizerTest(tf.test.TestCase): method setUp (line 38) | def setUp(self): method test_agent_visualizer_save_frame (line 44) | def test_agent_visualizer_save_frame(self):