SYMBOL INDEX (900 symbols across 70 files) FILE: examples/envs/chained_cue_navigation.py function build_agent (line 29) | def build_agent(env: CueChainingEnv, policy_len: int = 4) -> Agent: function rollout_locations (line 54) | def rollout_locations( function interpolate_locations (line 77) | def interpolate_locations( function load_assets (line 92) | def load_assets(asset_dir: Path) -> dict[str, np.ndarray]: function make_animation (line 106) | def make_animation( function main (line 317) | def main(): FILE: examples/model_fitting/tmaze_recoverability.py class RecoverabilityConfig (line 40) | class RecoverabilityConfig: function _assert_parameterization (line 59) | def _assert_parameterization(parameterization: str) -> None: function _build_task (line 65) | def _build_task(cfg: RecoverabilityConfig) -> TMaze: function _latent_grid (line 75) | def _latent_grid(cfg: RecoverabilityConfig, num_params: int) -> jnp.ndar... function _build_three_param_transform (line 89) | def _build_three_param_transform(task: TMaze): function _build_reward_only_transform (line 139) | def _build_reward_only_transform(task: TMaze): function _build_model_and_truth (line 192) | def _build_model_and_truth(cfg: RecoverabilityConfig) -> tuple[NumpyroMo... function _simulate_measurements (line 215) | def _simulate_measurements(model: NumpyroModel, z_true: jnp.ndarray, see... function _fit_svi (line 228) | def _fit_svi( function _pearson (line 246) | def _pearson(x: np.ndarray, y: np.ndarray) -> float: function _bimodality_score (line 252) | def _bimodality_score(probabilities: np.ndarray) -> float: function run_recoverability (line 261) | def run_recoverability(cfg: RecoverabilityConfig) -> dict[str, Any]: function _save_scatter (line 289) | def _save_scatter(results: dict[str, Any], path: Path) -> None: function _to_jsonable (line 308) | def _to_jsonable(results: dict[str, Any]) -> str: function parse_args (line 312) | def parse_args() -> argparse.Namespace: function main (line 327) | def main() -> None: FILE: pymdp/agent.py class Agent (line 19) | class Agent(Module): method __init__ (line 128) | def __init__( method unique_multiactions (line 397) | def unique_multiactions(self) -> Array: method _get_num_states_from_B (line 401) | def _get_num_states_from_B(self, B: list[Array], B_dependencies: list[... method infer_parameters (line 416) | def infer_parameters( method process_obs (line 622) | def process_obs(self, observations: list[Array] | list[int]) -> list[A... method make_categorical (line 659) | def make_categorical(self, observations: list[Array] | list[int]) -> l... method infer_states (line 675) | def infer_states( method update_empirical_prior (line 825) | def update_empirical_prior(self, action: Array, qs: list[Array]) -> li... method infer_policies (line 853) | def infer_policies(self, qs: list[Array]) -> tuple[Array, Array]: method multiaction_probabilities (line 906) | def multiaction_probabilities(self, q_pi: Array) -> Array: method sample_action (line 937) | def sample_action(self, q_pi: Array, rng_key: Array | None = None) -> ... method decode_multi_actions (line 967) | def decode_multi_actions(self, action: Array) -> Array: method encode_multi_actions (line 993) | def encode_multi_actions(self, action_multi: Array) -> Array: method get_model_dimensions (line 1022) | def get_model_dimensions(self) -> dict[str, Any]: method _construct_dependencies (line 1055) | def _construct_dependencies( method _flatten_B_action_dims (line 1086) | def _flatten_B_action_dims( method _construct_flattend_policies (line 1123) | def _construct_flattend_policies(self, policies: Array, action_maps: l... method _get_default_params (line 1139) | def _get_default_params(self) -> dict[str, Any] | None: method _validate (line 1157) | def _validate(self) -> None: FILE: pymdp/algos.py function add (line 18) | def add(x: Array, y: Array) -> Array: function _matmul_high_precision (line 22) | def _matmul_high_precision(lhs: Array, rhs: Array) -> Array: function marginal_log_likelihood (line 26) | def marginal_log_likelihood(qs: list[Array], log_likelihood: Array, i: i... function all_marginal_log_likelihood (line 30) | def all_marginal_log_likelihood( function mll_factors (line 46) | def mll_factors(qs: list[Array], ll_m: Array, factor_list_m: list[int]) ... function run_vanilla_fpi (line 53) | def run_vanilla_fpi( function run_factorized_fpi (line 88) | def run_factorized_fpi( function mirror_gradient_descent_step (line 128) | def mirror_gradient_descent_step( function update_marginals (line 141) | def update_marginals( function variational_filtering_step (line 212) | def variational_filtering_step( function update_variational_filtering (line 243) | def update_variational_filtering( function get_vmp_messages (line 283) | def get_vmp_messages( function run_vmp (line 363) | def run_vmp( function get_mmp_messages (line 403) | def get_mmp_messages( function run_mmp (line 483) | def run_mmp( function run_online_filtering (line 544) | def run_online_filtering( function run_factorized_fpi_hybrid (line 560) | def run_factorized_fpi_hybrid( function get_qs_padded (line 584) | def get_qs_padded(qs: list[Array], max_state_dim: int) -> list[Array]: function compute_qL_marginals (line 590) | def compute_qL_marginals( function qL_flatten (line 635) | def qL_flatten(qL_marginals_padded: list[list[Array]]) -> list[list[Arra... function compute_qL_all (line 646) | def compute_qL_all( function run_factorized_fpi_end2end_padded (line 658) | def run_factorized_fpi_end2end_padded( class FilterMessage (line 689) | class FilterMessage(NamedTuple): function _normalize_preserve_zeros (line 696) | def _normalize_preserve_zeros(u: Array, function _log_predictive_normalizer (line 712) | def _log_predictive_normalizer(predicted_probs: Array, function _condition_on (line 721) | def _condition_on(A: Array, function _hmm_filter_scan_row_oriented (line 749) | def _hmm_filter_scan_row_oriented( function hmm_filter_scan_rowstoch (line 817) | def hmm_filter_scan_rowstoch( function _hmm_smoother_scan_row_oriented (line 845) | def _hmm_smoother_scan_row_oriented( function hmm_smoother_scan_rowstoch (line 904) | def hmm_smoother_scan_rowstoch( function hmm_filter_scan_colstoch (line 930) | def hmm_filter_scan_colstoch( function hmm_smoother_scan_colstoch (line 961) | def hmm_smoother_scan_colstoch( function hmm_smoother_from_filtered_colstoch (line 1020) | def hmm_smoother_from_filtered_colstoch( function run_exact_single_factor_hmm_scan (line 1103) | def run_exact_single_factor_hmm_scan( function sum_prod (line 1183) | def sum_prod(prior: list[Array]) -> Array: FILE: pymdp/control.py class Policies (line 21) | class Policies(eqx.Module): method __init__ (line 30) | def __init__(self, policy_arr: Array) -> None: method __getitem__ (line 35) | def __getitem__(self, idx: int) -> Array: method __len__ (line 38) | def __len__(self) -> int: function get_marginals (line 41) | def get_marginals(q_pi: Array, policies: Array, num_controls: Sequence[i... function sample_action (line 68) | def sample_action( function sample_policy (line 114) | def sample_policy( function construct_policies (line 151) | def construct_policies( function update_posterior_policies (line 197) | def update_posterior_policies( function compute_expected_state (line 290) | def compute_expected_state( function compute_expected_state_and_Bs (line 330) | def compute_expected_state_and_Bs( function compute_expected_obs (line 359) | def compute_expected_obs( function compute_info_gain (line 388) | def compute_info_gain( function compute_expected_utility (line 422) | def compute_expected_utility(qo: list[Array], C: list[Array], t: int = 0... function calc_negative_pA_info_gain (line 450) | def calc_negative_pA_info_gain( function calc_negative_pB_info_gain (line 493) | def calc_negative_pB_info_gain( function compute_neg_efe_policy (line 543) | def compute_neg_efe_policy( function compute_neg_efe_policy_inductive (line 631) | def compute_neg_efe_policy_inductive( function update_posterior_policies_inductive (line 734) | def update_posterior_policies_inductive( function generate_I_matrix (line 843) | def generate_I_matrix(H: list[Array], B: list[Array], threshold: float, ... function calc_inductive_value_t (line 899) | def calc_inductive_value_t( FILE: pymdp/distribution.py class Distribution (line 6) | class Distribution: method __init__ (line 8) | def __init__(self, event: dict, batch: dict = {}, data: np.ndarray | N... method data (line 32) | def data(self) -> np.ndarray: method data (line 36) | def data(self, value: np.ndarray) -> None: method get (line 42) | def get(self, batch: dict | None = None, event: dict | None = None) ->... method set (line 49) | def set( method _get_slices (line 58) | def _get_slices(self, keys: dict | None, indices: dict, full_indices: ... method _get_index (line 74) | def _get_index(self, key: Any, index_map: dict) -> int: method _get_index_from_axis (line 80) | def _get_index_from_axis(self, axis: int, element: Any) -> int | slice: method __getitem__ (line 91) | def __getitem__(self, indices: Any) -> np.ndarray: method __setitem__ (line 99) | def __setitem__(self, indices: Any, value: Any) -> None: method normalize (line 107) | def normalize(self) -> None: method __repr__ (line 110) | def __repr__(self) -> str: class DistributionIndexer (line 114) | class DistributionIndexer(dict): method __init__ (line 120) | def __init__(self, distributions: list[Distribution]) -> None: method __getitem__ (line 127) | def __getitem__(self, key: str | int) -> Distribution: method __iter__ (line 137) | def __iter__(self) -> Iterator[Distribution]: class Model (line 141) | class Model(dict): method __init__ (line 143) | def __init__( method __getattr__ (line 162) | def __getattr__(self, key: str) -> Any: function compile_model (line 172) | def compile_model(config: dict[str, Any]) -> Model: function get_dependencies (line 315) | def get_dependencies( FILE: pymdp/envs/cue_chaining.py class CueChainingEnv (line 21) | class CueChainingEnv(PymdpEnv): method __init__ (line 43) | def __init__( method coords_to_index (line 160) | def coords_to_index(self, coord: tuple[int, int]) -> int: method index_to_coords (line 166) | def index_to_coords(self, index: int) -> tuple[int, int]: method _validate_coord (line 172) | def _validate_coord(self, coord: tuple[int, int]) -> None: method _generate_A (line 178) | def _generate_A(self) -> tuple[list[jnp.ndarray], list[list[int]]]: method _generate_B (line 214) | def _generate_B(self) -> tuple[list[jnp.ndarray], list[list[int]]]: method _generate_D (line 254) | def _generate_D( FILE: pymdp/envs/env.py function _float_to_int_index (line 20) | def _float_to_int_index(x: Array) -> Array: function select_probs (line 36) | def select_probs( function cat_sample (line 67) | def cat_sample(key: Array, p: Array) -> Array: function make (line 91) | def make( class Env (line 150) | class Env(ABC): method reset (line 154) | def reset( method step (line 179) | def step( method generate_env_params (line 206) | def generate_env_params( class PymdpEnv (line 226) | class PymdpEnv(Env): method __init__ (line 236) | def __init__( method generate_env_params (line 305) | def generate_env_params( method reset (line 333) | def reset( method step (line 352) | def step( method _sample_obs (line 380) | def _sample_obs( FILE: pymdp/envs/generalized_tmaze.py function get_maze_matrix (line 15) | def get_maze_matrix(small: bool = False) -> np.ndarray: function parse_maze (line 71) | def parse_maze(maze: np.ndarray, rng_key: PRNGKeyArray) -> dict[str, Any]: function generate_A (line 171) | def generate_A(maze_info: dict[str, Any]) -> tuple[list[jnp.ndarray], li... function generate_B (line 253) | def generate_B(maze_info: dict[str, Any]) -> tuple[list[jnp.ndarray], li... function generate_D (line 323) | def generate_D(maze_info: dict[str, Any]) -> list[jnp.ndarray]: class GeneralizedTMazeEnv (line 355) | class GeneralizedTMazeEnv(PymdpEnv): method __init__ (line 361) | def __init__(self, env_info: dict[str, Any], categorical_obs: bool = F... method render (line 387) | def render(self, states: list[jnp.ndarray], mode: str = "human") -> jn... FILE: pymdp/envs/graph_worlds.py function generate_connected_clusters (line 8) | def generate_connected_clusters( class GraphEnv (line 25) | class GraphEnv(PymdpEnv): method __init__ (line 31) | def __init__( method generate_A (line 81) | def generate_A(self, graph: nx.Graph) -> tuple[list[jnp.ndarray], list... method generate_B (line 113) | def generate_B(self, graph: nx.Graph) -> tuple[list[jnp.ndarray], list... method generate_D (line 141) | def generate_D( method generate_env_params (line 157) | def generate_env_params( FILE: pymdp/envs/grid_world.py class GridWorld (line 12) | class GridWorld(PymdpEnv): method __init__ (line 73) | def __init__( method coords_to_index (line 109) | def coords_to_index(shape: Tuple[int, int], coord: Tuple[int, int]) ->... method index_to_coords (line 114) | def index_to_coords(shape: Tuple[int, int], idx: int) -> Tuple[int, int]: function _flatten_walls (line 123) | def _flatten_walls(shape: Tuple[int, int], walls: Optional[Iterable[Tupl... function _generate_A (line 135) | def _generate_A(n_states: int) -> tuple[jnp.ndarray, list[list[int]]]: function _neighbors (line 147) | def _neighbors(shape: Tuple[int, int], s: int) -> Tuple[int, int, int, i... function _generate_B (line 162) | def _generate_B( function _generate_D (line 214) | def _generate_D( FILE: pymdp/envs/rollout.py function _append_to_window (line 27) | def _append_to_window(window: Array, value: Array) -> Array: function _resolve_history_len (line 38) | def _resolve_history_len(agent: Agent, num_timesteps: int, use_windowing... function default_policy_search (line 58) | def default_policy_search(agent: Agent, qs: list[Array], rng_key: Array)... function _resolve_empirical_prior (line 65) | def _resolve_empirical_prior( function update_parameters_online (line 85) | def update_parameters_online( function _compute_sequence_empirical_prior_next (line 170) | def _compute_sequence_empirical_prior_next( function _run_sequence_fixed_window_step (line 195) | def _run_sequence_fixed_window_step( function _run_smoothing_fixed_window_step (line 233) | def _run_smoothing_fixed_window_step( function _run_non_window_step (line 259) | def _run_non_window_step( function _update_window_buffers (line 280) | def _update_window_buffers( function _init_observation_history (line 294) | def _init_observation_history(obs: Array, history_len: int, categorical_... function _init_windowed_carry (line 307) | def _init_windowed_carry( function _init_non_windowed_carry (line 355) | def _init_non_windowed_carry( function infer_and_plan (line 370) | def infer_and_plan( function rollout (line 468) | def rollout( FILE: pymdp/envs/tmaze.py class BaseTMaze (line 24) | class BaseTMaze(PymdpEnv): method __init__ (line 44) | def __init__( method _set_reward_outcome (line 141) | def _set_reward_outcome(self, A_reward: jnp.ndarray, loc: int, reward_... method generate_A (line 157) | def generate_A(self) -> tuple[list[jnp.ndarray], list[list[int]]]: method _generate_A_separate (line 164) | def _generate_A_separate(self) -> tuple[list[jnp.ndarray], list[list[i... method _generate_A_embedded (line 204) | def _generate_A_embedded(self) -> tuple[list[jnp.ndarray], list[list[i... method _valid_connections (line 240) | def _valid_connections(self) -> list[tuple[int, int]]: method generate_B (line 267) | def generate_B(self) -> tuple[list[jnp.ndarray], list[list[int]]]: method generate_D (line 302) | def generate_D(self) -> list[jnp.ndarray]: method render (line 328) | def render( class TMaze (line 494) | class TMaze(BaseTMaze): method __init__ (line 499) | def __init__( class SimplifiedTMaze (line 522) | class SimplifiedTMaze(BaseTMaze): method __init__ (line 527) | def __init__( FILE: pymdp/inference.py class VFEInfo (line 37) | class VFEInfo(TypedDict): function _select_current_obs (line 43) | def _select_current_obs(obs: list[Array] | Array, distr_obs: bool) -> li... function _truncate_for_horizon (line 56) | def _truncate_for_horizon( function _ensure_action_history_shape (line 76) | def _ensure_action_history_shape(past_actions: Array | None, num_factors... function _build_sequence_validity_masks (line 101) | def _build_sequence_validity_masks( function _condition_transitions_on_actions (line 124) | def _condition_transitions_on_actions( function _run_one_step_inference (line 171) | def _run_one_step_inference( function _run_sequence_inference (line 194) | def _run_sequence_inference( function _update_qs_history (line 274) | def _update_qs_history( function _assemble_vfe_kwargs (line 298) | def _assemble_vfe_kwargs( function update_posterior_states (line 338) | def update_posterior_states( function _joint_dist_factor (line 487) | def _joint_dist_factor( function joint_dist_factor (line 524) | def joint_dist_factor( function smoothing_ovf (line 559) | def smoothing_ovf( function smoothing_exact (line 595) | def smoothing_exact( FILE: pymdp/learning.py function update_obs_likelihood_dirichlet_m (line 11) | def update_obs_likelihood_dirichlet_m( function update_obs_likelihood_dirichlet (line 57) | def update_obs_likelihood_dirichlet( function update_state_transition_dirichlet_f (line 119) | def update_state_transition_dirichlet_f( function update_state_transition_dirichlet (line 159) | def update_state_transition_dirichlet( FILE: pymdp/legacy/agent.py class Agent (line 16) | class Agent(object): method __init__ (line 33) | def __init__( method _construct_C_prior (line 318) | def _construct_C_prior(self): method _construct_D_prior (line 324) | def _construct_D_prior(self): method _construct_policies (line 330) | def _construct_policies(self): method _construct_num_controls (line 338) | def _construct_num_controls(self): method _construct_E_prior (line 345) | def _construct_E_prior(self): method reset (line 349) | def reset(self, init_qs=None): method step_time (line 396) | def step_time(self): method set_latest_beliefs (line 420) | def set_latest_beliefs(self,last_belief=None): method get_future_qs (line 455) | def get_future_qs(self): method infer_states (line 478) | def infer_states(self, observation, distr_obs=False): method _infer_states_test (line 552) | def _infer_states_test(self, observation, distr_obs=False): method infer_policies (line 608) | def infer_policies(self): method sample_action (line 695) | def sample_action(self): method _sample_action_test (line 723) | def _sample_action_test(self): method update_A (line 749) | def update_A(self, obs): method _update_A_old (line 780) | def _update_A_old(self, obs): method update_B (line 810) | def update_B(self, qs_prev): method _update_B_old (line 841) | def _update_B_old(self, qs_prev): method update_D (line 871) | def update_D(self, qs_t0 = None): method _get_default_params (line 923) | def _get_default_params(self): FILE: pymdp/legacy/algos/fpi.py function run_vanilla_fpi (line 10) | def run_vanilla_fpi(A, obs, num_obs, num_states, prior=None, num_iter=10... function run_vanilla_fpi_factorized (line 159) | def run_vanilla_fpi_factorized(A, obs, num_obs, num_states, mb_dict, pri... function _run_vanilla_fpi_faster (line 324) | def _run_vanilla_fpi_faster(A, obs, n_observations, n_states, prior=None... FILE: pymdp/legacy/algos/mmp.py function run_mmp (line 9) | def run_mmp( function run_mmp_factorized (line 133) | def run_mmp_factorized( function _run_mmp_testing (line 297) | def _run_mmp_testing( FILE: pymdp/legacy/algos/mmp_old.py function run_mmp_old (line 11) | def run_mmp_old( FILE: pymdp/legacy/control.py function update_posterior_policies_full (line 13) | def update_posterior_policies_full( function update_posterior_policies_full_factorized (line 135) | def update_posterior_policies_full_factorized( function update_posterior_policies (line 266) | def update_posterior_policies( function update_posterior_policies_factorized (line 364) | def update_posterior_policies_factorized( function get_expected_states (line 470) | def get_expected_states(qs, B, policy): function get_expected_states_interactions (line 505) | def get_expected_states_interactions(qs, B, B_factor_list, policy): function get_expected_obs (line 543) | def get_expected_obs(qs_pi, A): function get_expected_obs_factorized (line 580) | def get_expected_obs_factorized(qs_pi, A, A_factor_list): function calc_expected_utility (line 619) | def calc_expected_utility(qo_pi, C): function calc_states_info_gain (line 664) | def calc_states_info_gain(A, qs_pi): function calc_states_info_gain_factorized (line 693) | def calc_states_info_gain_factorized(A, qs_pi, A_factor_list): function calc_pA_info_gain (line 727) | def calc_pA_info_gain(pA, qo_pi, qs_pi): function calc_pA_info_gain_factorized (line 764) | def calc_pA_info_gain_factorized(pA, qo_pi, qs_pi, A_factor_list): function calc_pB_info_gain (line 805) | def calc_pB_info_gain(pB, qs_pi, qs_prev, policy): function calc_pB_info_gain_interactions (line 856) | def calc_pB_info_gain_interactions(pB, qs_pi, qs_prev, B_factor_list, po... function calc_inductive_cost (line 910) | def calc_inductive_cost(qs, qs_pi, I, epsilon=1e-3): function construct_policies (line 953) | def construct_policies(num_states, num_controls = None, policy_len=1, co... function get_num_controls_from_policies (line 995) | def get_num_controls_from_policies(policies): function sample_action (line 1017) | def sample_action(q_pi, policies, num_controls, action_selection="determ... function _sample_action_test (line 1068) | def _sample_action_test(q_pi, policies, num_controls, action_selection="... function sample_policy (line 1126) | def sample_policy(q_pi, policies, num_controls, action_selection="determ... function _sample_policy_test (line 1168) | def _sample_policy_test(q_pi, policies, num_controls, action_selection="... function select_highest (line 1215) | def select_highest(options_array): function _select_highest_test (line 1236) | def _select_highest_test(options_array, seed=None): function backwards_induction (line 1259) | def backwards_induction(H, B, B_factor_list, threshold, depth): function calc_ambiguity_factorized (line 1316) | def calc_ambiguity_factorized(qs_pi, A, A_factor_list): function sophisticated_inference_search (line 1353) | def sophisticated_inference_search(qs, policies, A, B, C, A_factor_list,... FILE: pymdp/legacy/default_models.py function generate_epistemic_MAB_model (line 5) | def generate_epistemic_MAB_model(): function generate_grid_world_transitions (line 72) | def generate_grid_world_transitions(action_labels, num_rows = 3, num_col... FILE: pymdp/legacy/envs/env.py class Env (line 11) | class Env(object): method reset (line 30) | def reset(self, state=None): method step (line 36) | def step(self, action): method render (line 52) | def render(self): method sample_action (line 58) | def sample_action(self): method get_likelihood_dist (line 61) | def get_likelihood_dist(self): method get_transition_dist (line 66) | def get_transition_dist(self): method get_uniform_posterior (line 71) | def get_uniform_posterior(self): method get_rand_likelihood_dist (line 76) | def get_rand_likelihood_dist(self): method get_rand_transition_dist (line 81) | def get_rand_transition_dist(self): method __str__ (line 86) | def __str__(self): FILE: pymdp/legacy/envs/grid_worlds.py class GridWorldEnv (line 18) | class GridWorldEnv(Env): method __init__ (line 29) | def __init__(self, shape=[2, 2], init_state=None): method reset (line 52) | def reset(self, init_state=None): method set_state (line 71) | def set_state(self, state): method step (line 89) | def step(self, action): method render (line 108) | def render(self, title=None): method set_init_state (line 130) | def set_init_state(self, init_state=None): method _build (line 141) | def _build(self): method get_init_state_dist (line 167) | def get_init_state_dist(self, init_state=None): method get_transition_dist (line 174) | def get_transition_dist(self): method get_likelihood_dist (line 182) | def get_likelihood_dist(self): method sample_action (line 186) | def sample_action(self): method position (line 190) | def position(self): class DGridWorldEnv (line 195) | class DGridWorldEnv(object): method __init__ (line 204) | def __init__(self, shape=[2, 2], init_state=None): method reset (line 215) | def reset(self, init_state=None): method set_state (line 220) | def set_state(self, state): method step (line 224) | def step(self, action): method render (line 230) | def render(self, title=None): method set_init_state (line 244) | def set_init_state(self, init_state=None): method _build (line 255) | def _build(self): method get_init_state_dist (line 277) | def get_init_state_dist(self, init_state=None): method get_transition_dist (line 284) | def get_transition_dist(self): method get_likelihood_dist (line 292) | def get_likelihood_dist(self): method sample_action (line 296) | def sample_action(self): method position (line 300) | def position(self): FILE: pymdp/legacy/envs/tmaze.py class TMazeEnv (line 25) | class TMazeEnv(Env): method __init__ (line 27) | def __init__(self, reward_probs=None): method reset (line 56) | def reset(self, state=None): method step (line 71) | def step(self, actions): method render (line 79) | def render(self): method sample_action (line 82) | def sample_action(self): method get_likelihood_dist (line 85) | def get_likelihood_dist(self): method get_transition_dist (line 88) | def get_transition_dist(self): method get_rand_likelihood_dist (line 92) | def get_rand_likelihood_dist(self): method get_rand_transition_dist (line 95) | def get_rand_transition_dist(self): method _get_observation (line 98) | def _get_observation(self): method _construct_transition_dist (line 105) | def _construct_transition_dist(self): method _construct_likelihood_dist (line 119) | def _construct_likelihood_dist(self): method _construct_state (line 175) | def _construct_state(self, state_tuple): method state (line 184) | def state(self): method reward_condition (line 188) | def reward_condition(self): class TMazeEnvNullOutcome (line 192) | class TMazeEnvNullOutcome(Env): method __init__ (line 196) | def __init__(self, reward_probs=None): method reset (line 225) | def reset(self, state=None): method step (line 240) | def step(self, actions): method sample_action (line 249) | def sample_action(self): method get_likelihood_dist (line 252) | def get_likelihood_dist(self): method get_transition_dist (line 255) | def get_transition_dist(self): method _get_observation (line 258) | def _get_observation(self): method _construct_transition_dist (line 265) | def _construct_transition_dist(self): method _construct_likelihood_dist (line 279) | def _construct_likelihood_dist(self): method _construct_state (line 331) | def _construct_state(self, state_tuple): method state (line 341) | def state(self): method reward_condition (line 345) | def reward_condition(self): FILE: pymdp/legacy/envs/visual_foraging.py class SceneConstruction (line 28) | class SceneConstruction(Env): method __init__ (line 30) | def __init__(self, starting_loc = 'start', scene_name = 'UP_RIGHT', co... method step (line 45) | def step(self,action_label): method reset (line 72) | def reset(self): method _create_visual_array (line 79) | def _create_visual_array(self): class RandomDotMotion (line 92) | class RandomDotMotion(Env): method __init__ (line 97) | def __init__(self, precision = 1.0, dot_direction = None, sampling_sta... method reset (line 121) | def reset(self, dot_direction = None, sampling_state = None): method step (line 132) | def step(self, action): method _generate_dot_dist (line 138) | def _generate_dot_dist(self): method _get_observation (line 150) | def _get_observation(self): method _set_sampling_state (line 158) | def _set_sampling_state(self, action): method dot_direction (line 163) | def dot_direction(self): method num_directions (line 167) | def num_directions(self): method precision (line 171) | def precision(self): method coherence (line 175) | def coherence(self): function create_2x2_array (line 179) | def create_2x2_array(scene_name, config): function initialize_scene_construction_GM (line 193) | def initialize_scene_construction_GM(T = 6, reward = 2.0, punishment = -... function initialize_RDM_GM (line 284) | def initialize_RDM_GM(T=16, A_precis = 1.0, break_reward = 0.001): FILE: pymdp/legacy/inference.py function update_posterior_states_full (line 18) | def update_posterior_states_full( function update_posterior_states_full_factorized (line 89) | def update_posterior_states_full_factorized( function _update_posterior_states_full_test (line 169) | def _update_posterior_states_full_test( function average_states_over_policies (line 247) | def average_states_over_policies(qs_pi, q_pi): function update_posterior_states (line 282) | def update_posterior_states(A, obs, prior=None, **kwargs): function update_posterior_states_factorized (line 324) | def update_posterior_states_factorized(A, obs, num_obs, num_states, mb_d... FILE: pymdp/legacy/learning.py function update_obs_likelihood_dirichlet (line 9) | def update_obs_likelihood_dirichlet(pA, A, obs, qs, lr=1.0, modalities="... function update_obs_likelihood_dirichlet_factorized (line 60) | def update_obs_likelihood_dirichlet_factorized(pA, A, obs, qs, A_factor_... function update_state_likelihood_dirichlet (line 113) | def update_state_likelihood_dirichlet( function update_state_likelihood_dirichlet_interactions (line 161) | def update_state_likelihood_dirichlet_interactions( function update_state_prior_dirichlet (line 212) | def update_state_prior_dirichlet( function _prune_prior (line 251) | def _prune_prior(prior, levels_to_remove, dirichlet = False): function _prune_A (line 312) | def _prune_A(A, obs_levels_to_prune, state_levels_to_prune, dirichlet = ... function _prune_B (line 383) | def _prune_B(B, state_levels_to_prune, action_levels_to_prune, dirichlet... FILE: pymdp/legacy/maths.py function spm_dot (line 19) | def spm_dot(X, x, dims_to_omit=None): function spm_dot_classic (line 59) | def spm_dot_classic(X, x, dims_to_omit=None): function factor_dot_flex (line 109) | def factor_dot_flex(M, xs, dims, keep_dims=None): function spm_dot_old (line 131) | def spm_dot_old(X, x, dims_to_omit=None, obs_mode=False): function spm_cross (line 197) | def spm_cross(x, y=None, *args): function dot_likelihood (line 239) | def dot_likelihood(A,obs): function get_joint_likelihood (line 255) | def get_joint_likelihood(A, obs, num_states): function get_joint_likelihood_seq (line 267) | def get_joint_likelihood_seq(A, obs, num_states): function get_joint_likelihood_seq_by_modality (line 273) | def get_joint_likelihood_seq_by_modality(A, obs, num_states): function spm_norm (line 291) | def spm_norm(A): function spm_log_single (line 300) | def spm_log_single(arr): function spm_log_obj_array (line 306) | def spm_log_obj_array(obj_arr): function spm_wnorm (line 317) | def spm_wnorm(A): function spm_betaln (line 329) | def spm_betaln(z): function dirichlet_log_evidence (line 335) | def dirichlet_log_evidence(q_dir, p_dir, r_dir): function softmax (line 363) | def softmax(dist): function softmax_obj_arr (line 373) | def softmax_obj_arr(arr): function compute_accuracy (line 382) | def compute_accuracy(log_likelihood, qs): function calc_free_energy (line 396) | def calc_free_energy(qs, prior, n_factors, likelihood=None): function spm_calc_qo_entropy (line 412) | def spm_calc_qo_entropy(A, x): function spm_calc_neg_ambig (line 464) | def spm_calc_neg_ambig(A, x): function spm_MDP_G (line 517) | def spm_MDP_G(A, x): function kl_div (line 575) | def kl_div(P,Q): function entropy (line 592) | def entropy(A): FILE: pymdp/legacy/utils.py class Dimensions (line 18) | class Dimensions(object): method __init__ (line 22) | def __init__( function sample (line 39) | def sample(probabilities): function sample_obj_array (line 44) | def sample_obj_array(arr): function obj_array (line 53) | def obj_array(num_arr): function obj_array_zeros (line 59) | def obj_array_zeros(shape_list): function initialize_empty_A (line 69) | def initialize_empty_A(num_obs, num_states): function initialize_empty_B (line 78) | def initialize_empty_B(num_states, num_controls): function obj_array_uniform (line 87) | def obj_array_uniform(shape_list): function obj_array_ones (line 98) | def obj_array_ones(shape_list, scale = 1.0): function onehot (line 105) | def onehot(value, num_values): function random_A_matrix (line 110) | def random_A_matrix(num_obs, num_states, A_factor_list=None): function random_B_matrix (line 130) | def random_B_matrix(num_states, num_controls, B_factor_list=None, B_fact... function get_combination_index (line 175) | def get_combination_index(x, dims): function index_to_combination (line 199) | def index_to_combination(index, dims): function random_single_categorical (line 221) | def random_single_categorical(shape_list): function construct_controllable_B (line 235) | def construct_controllable_B(num_states, num_controls): function dirichlet_like (line 252) | def dirichlet_like(template_categorical, scale = 1.0): function get_model_dimensions (line 273) | def get_model_dimensions(A=None, B=None, factorized=False): function get_model_dimensions_from_labels (line 303) | def get_model_dimensions_from_labels(model_labels): function norm_dist (line 324) | def norm_dist(dist): function norm_dist_obj_arr (line 328) | def norm_dist_obj_arr(obj_arr): function is_normalized (line 337) | def is_normalized(dist): function is_obj_array (line 355) | def is_obj_array(arr): function to_obj_array (line 358) | def to_obj_array(arr): function obj_array_from_list (line 365) | def obj_array_from_list(list_input): function process_observation_seq (line 374) | def process_observation_seq(obs_seq, n_modalities, n_observations): function process_observation (line 388) | def process_observation(obs, num_modalities, num_observations): function convert_observation_array (line 419) | def convert_observation_array(obs, num_obs): function insert_multiple (line 467) | def insert_multiple(s, indices, items): function reduce_a_matrix (line 472) | def reduce_a_matrix(A): function construct_full_a (line 517) | def construct_full_a(A_reduced, original_factor_idx, num_states): function build_xn_vn_array (line 589) | def build_xn_vn_array(xn): function plot_beliefs (line 619) | def plot_beliefs(belief_dist, title=""): function plot_likelihood (line 632) | def plot_likelihood(A, title=""): FILE: pymdp/likelihoods.py function evolve_trials (line 7) | def evolve_trials(agent: Any, data: Any) -> Any: function aif_likelihood (line 29) | def aif_likelihood(Nb: int, Nt: int, Na: int, data: Any, agent: Any) -> ... FILE: pymdp/maths.py function stable_xlogx (line 17) | def stable_xlogx(x: ArrayLike) -> ArrayLike: function stable_entropy (line 32) | def stable_entropy(x: ArrayLike) -> ArrayLike: function stable_cross_entropy (line 47) | def stable_cross_entropy(x: ArrayLike, y: ArrayLike) -> ArrayLike: function log_stable (line 64) | def log_stable(x: ArrayLike) -> ArrayLike: function factor_dot (line 83) | def factor_dot( function factor_dot (line 110) | def factor_dot( function spm_dot_sparse (line 136) | def spm_dot_sparse( function factor_dot_flex (line 180) | def factor_dot_flex( function compute_log_likelihood_single_modality (line 214) | def compute_log_likelihood_single_modality( function compute_log_likelihood (line 241) | def compute_log_likelihood( function compute_log_likelihood_per_modality (line 266) | def compute_log_likelihood_per_modality( function _to_dense_if_sparse (line 290) | def _to_dense_if_sparse(x: ArrayLike) -> ArrayLike: function _expected_log_prob (line 296) | def _expected_log_prob(log_prob: ArrayLike, marginals: list[ArrayLike]) ... function _expected_log_prob_tensor (line 302) | def _expected_log_prob_tensor(log_prob: ArrayLike, belief: ArrayLike) ->... function _pad_sequence_with_initial_zeros (line 312) | def _pad_sequence_with_initial_zeros(x: ArrayLike) -> ArrayLike: function _ensure_vfe_action_history_shape (line 318) | def _ensure_vfe_action_history_shape( function _sum_dirichlet_kl (line 348) | def _sum_dirichlet_kl( function compute_accuracy (line 369) | def compute_accuracy( function dirichlet_kl_divergence (line 408) | def dirichlet_kl_divergence( function calc_vfe (line 448) | def calc_vfe( function multidimensional_outer (line 763) | def multidimensional_outer(arrs: list[ArrayLike]) -> ArrayLike: function _exact_wnorm (line 783) | def _exact_wnorm(A: ArrayLike) -> ArrayLike: function spm_wnorm (line 811) | def spm_wnorm(A: ArrayLike, exact_param_info_gain: bool = True) -> Array... function dirichlet_expected_value (line 843) | def dirichlet_expected_value(dir_arr: ArrayLike, event_dim: int = 0) -> ... function compute_log_likelihoods_padded (line 871) | def compute_log_likelihoods_padded(obs_padded: ArrayLike, A_padded: Arra... function deconstruct_lls (line 890) | def deconstruct_lls( function compute_log_likelihoods_flat_block_diag_einsum (line 924) | def compute_log_likelihoods_flat_block_diag_einsum( function compute_log_likelihoods_flat_block_diag (line 948) | def compute_log_likelihoods_flat_block_diag(A_big: ArrayLike, obs_big: A... function deconstruct_log_likelihoods_block_diag (line 969) | def deconstruct_log_likelihoods_block_diag( function compute_log_likelihoods_block_diag (line 993) | def compute_log_likelihoods_block_diag( function log_stable_sparse (line 1032) | def log_stable_sparse(x: ArrayLike) -> ArrayLike: function compute_log_likelihood_per_modality_end2end2_padded (line 1050) | def compute_log_likelihood_per_modality_end2end2_padded( FILE: pymdp/planning/mcts.py function _require_mctx (line 21) | def _require_mctx() -> None: function mcts_policy_search (line 30) | def mcts_policy_search( function compute_neg_efe (line 84) | def compute_neg_efe( function get_prob_single_modality (line 144) | def get_prob_single_modality(o_m: jnp.ndarray, po_m: jnp.ndarray, distr_... function make_aif_recurrent_fn (line 149) | def make_aif_recurrent_fn() -> Callable[[Any, jnp.ndarray, jnp.ndarray, ... function rollout (line 201) | def rollout( FILE: pymdp/planning/si.py function predict_fn (line 23) | def predict_fn(agent: Any, qs: list[jnp.ndarray]) -> tuple[list[jnp.ndar... function infer_fn (line 63) | def infer_fn(agent: Any, obs: list[jnp.ndarray], qs: list[jnp.ndarray]) ... function si_policy_search (line 87) | def si_policy_search( class Tree (line 208) | class Tree(eqx.Module): method __init__ (line 241) | def __init__( method __getitem__ (line 273) | def __getitem__(self, index: int) -> dict[str, Any]: method root (line 309) | def root(self) -> dict[str, Any]: function root_idx (line 319) | def root_idx(tree: Tree) -> jnp.ndarray: function _do_nothing (line 342) | def _do_nothing(tree: Tree, idx: jnp.ndarray) -> Tree: function _update_node (line 346) | def _update_node( function _remove_orphans (line 481) | def _remove_orphans(tree: Tree) -> Tree: function _calculate_probabilities (line 548) | def _calculate_probabilities(return_size: int, topk_probs: list[jnp.ndar... function _generate_observations (line 580) | def _generate_observations( function optimized_tree_search (line 609) | def optimized_tree_search( FILE: pymdp/planning/visualize.py function action_to_string (line 22) | def action_to_string(action: Any, model: Any = None) -> str: function observation_to_string (line 50) | def observation_to_string(observation: Any, model: Any = None) -> str: function formatting_jax (line 65) | def formatting_jax(value: Any, format_str: str = ".2f") -> str: function plot_plan_tree (line 86) | def plot_plan_tree( function visualize_plan_tree (line 259) | def visualize_plan_tree( function visualize_beliefs (line 343) | def visualize_beliefs(info: dict[str, Any], agent_idx: int = 0, model: A... function visualize_env (line 377) | def visualize_env( FILE: pymdp/utils.py function norm_dist (line 27) | def norm_dist(dist: Array) -> Array: function list_array_norm_dist (line 42) | def list_array_norm_dist(dist_list: list[Array]) -> list[Array]: function resolve_a_dependencies (line 58) | def resolve_a_dependencies( function resolve_b_dependencies (line 69) | def resolve_b_dependencies( function resolve_b_action_dependencies (line 79) | def resolve_b_action_dependencies( function validate_normalization (line 88) | def validate_normalization(tensor: Array, axis: int = 1, tensor_name: st... function random_factorized_categorical (line 147) | def random_factorized_categorical(key: Array, dims_per_var: Sequence[int... function random_A_array (line 169) | def random_A_array( function random_B_array (line 209) | def random_B_array( function create_controllable_B (line 262) | def create_controllable_B( function list_array_uniform (line 296) | def list_array_uniform(shape_list: Sequence[Sequence[int]]) -> list[Array]: function list_array_zeros (line 316) | def list_array_zeros(shape_list: Sequence[Sequence[int]]) -> list[Array]: function list_array_scaled (line 335) | def list_array_scaled( function get_combination_index (line 359) | def get_combination_index(x: jax.Array | np.ndarray, dims: Sequence[int]... function index_to_combination (line 386) | def index_to_combination(index: jax.Array | np.ndarray, dims: Sequence[i... function make_A_full (line 410) | def make_A_full( function fig2img (line 450) | def fig2img(fig: Any) -> np.ndarray: function A_dep_factors_dist (line 479) | def A_dep_factors_dist(num_states: Sequence[int], A_dep_len: int) -> Array: function A_dep_len_dist (line 504) | def A_dep_len_dist(choices: Array, curr_sf_dim: int, max_sf_dim: int) ->... function A_dep_len_dist_unconditional (line 526) | def A_dep_len_dist_unconditional(choices: Array) -> Array: function generate_agent_spec (line 544) | def generate_agent_spec( function generate_agent_specs_from_parameter_sets (line 730) | def generate_agent_specs_from_parameter_sets( function apply_padding_batched (line 821) | def apply_padding_batched(xs: list[Array]) -> Array: function get_sample_obs (line 854) | def get_sample_obs(num_obs: Sequence[int], batch_size: int = 1) -> list[... function init_A_and_D_from_spec (line 872) | def init_A_and_D_from_spec( function build_block_diag_A (line 943) | def build_block_diag_A( function preprocess_A_for_block_diag (line 976) | def preprocess_A_for_block_diag( function prepare_obs_for_block_diag (line 994) | def prepare_obs_for_block_diag(obs: list[Array], num_obs: Sequence[int])... function concatenate_observations_block_diag (line 1014) | def concatenate_observations_block_diag(obs_list: list[Array]) -> Array: function apply_A_end2end_padding_batched (line 1029) | def apply_A_end2end_padding_batched(A: list[Array]) -> Array: function apply_obs_end2end_padding_batched (line 1059) | def apply_obs_end2end_padding_batched(obs: list[Array], max_obs_dim: int... FILE: scripts/notebook_precommit.py function load_manifest (line 22) | def load_manifest(path: Path) -> set[str]: function strip_top_level_metadata (line 30) | def strip_top_level_metadata(notebook: nbformat.NotebookNode) -> bool: function canonicalize_execution_counts (line 39) | def canonicalize_execution_counts(notebook: nbformat.NotebookNode) -> bool: function to_repo_relative (line 77) | def to_repo_relative(path_str: str) -> str | None: function classify_notebooks (line 88) | def classify_notebooks(path_args: list[str]) -> tuple[list[Path], list[P... function report_unclassified_notebooks (line 119) | def report_unclassified_notebooks(paths: list[str]) -> int: function sanitize_ci_notebooks (line 133) | def sanitize_ci_notebooks(paths: list[Path]) -> list[str]: function sanitize_nightly_notebooks (line 148) | def sanitize_nightly_notebooks(paths: list[Path]) -> None: function validate_manifest_notebooks (line 170) | def validate_manifest_notebooks(paths: list[Path]) -> list[str]: function run_sanitize (line 201) | def run_sanitize(path_args: list[str]) -> int: function run_validate_counts (line 222) | def run_validate_counts(path_args: list[str]) -> int: function parse_args (line 241) | def parse_args() -> argparse.Namespace: function main (line 254) | def main() -> int: FILE: scripts/run_notebook_manifest.py function load_manifest (line 14) | def load_manifest(path: Path) -> list[str]: function parse_args (line 30) | def parse_args() -> argparse.Namespace: function has_explicit_numprocesses (line 48) | def has_explicit_numprocesses(pytest_args: list[str]) -> bool: function main (line 55) | def main() -> int: FILE: test/conftest.py function pytest_configure (line 4) | def pytest_configure(config): FILE: test/test_SPM_validation.py class TestSPM (line 13) | class TestSPM(unittest.TestCase): method test_active_inference_SPM_1a (line 15) | def test_active_inference_SPM_1a(self): method test_BMR_SPM_a (line 72) | def test_BMR_SPM_a(self): method test_BMR_SPM_b (line 98) | def test_BMR_SPM_b(self): FILE: test/test_agent.py class TestAgent (line 19) | class TestAgent(unittest.TestCase): method test_agent_init_without_control_fac_idx (line 21) | def test_agent_init_without_control_fac_idx(self): method test_reset_agent_VANILLA (line 39) | def test_reset_agent_VANILLA(self): method test_reset_agent_MMP_wBMA (line 57) | def test_reset_agent_MMP_wBMA(self): method test_reset_agent_MMP_wPSP (line 75) | def test_reset_agent_MMP_wPSP(self): method test_agent_infer_states (line 91) | def test_agent_infer_states(self): method test_mmp_active_inference (line 172) | def test_mmp_active_inference(self): method test_agent_with_A_learning_vanilla (line 202) | def test_agent_with_A_learning_vanilla(self): method test_agent_with_A_learning_vanilla_factorized (line 248) | def test_agent_with_A_learning_vanilla_factorized(self): method test_agent_with_B_learning_vanilla (line 293) | def test_agent_with_B_learning_vanilla(self): method test_agent_with_D_learning_vanilla (line 343) | def test_agent_with_D_learning_vanilla(self): method test_agent_with_D_learning_MMP (line 431) | def test_agent_with_D_learning_MMP(self): method test_agent_with_input_alpha (line 502) | def test_agent_with_input_alpha(self): method test_agent_with_sampling_mode (line 531) | def test_agent_with_sampling_mode(self): method test_agent_with_stochastic_action_unidimensional_control (line 562) | def test_agent_with_stochastic_action_unidimensional_control(self): method test_agent_distributional_obs (line 587) | def test_agent_distributional_obs(self): method test_agent_with_factorized_inference (line 669) | def test_agent_with_factorized_inference(self): method test_agent_with_interactions_in_B (line 704) | def test_agent_with_interactions_in_B(self): method test_actinfloop_factorized (line 735) | def test_actinfloop_factorized(self): FILE: test/test_agent_jax.py class TestAgentJax (line 23) | class TestAgentJax(unittest.TestCase): method test_no_desired_batch_no_batched_input_construction (line 25) | def test_no_desired_batch_no_batched_input_construction(self): method test_desired_batch_no_batched_input_construction (line 105) | def test_desired_batch_no_batched_input_construction(self): method test_desired_batch_and_batched_input_construction (line 188) | def test_desired_batch_and_batched_input_construction(self): method test_vmappable_agent_methods (line 334) | def test_vmappable_agent_methods(self): method test_agent_complex_action (line 372) | def test_agent_complex_action(self): method test_infer_policies_neg_efe_sign_convention (line 452) | def test_infer_policies_neg_efe_sign_convention(self): method test_agent_validate_normalization_ok (line 480) | def test_agent_validate_normalization_ok(self): method test_agent_validate_normalization_raises_on_bad_A (line 506) | def test_agent_validate_normalization_raises_on_bad_A(self): method test_agent_validate_normalization_raises_on_bad_B (line 536) | def test_agent_validate_normalization_raises_on_bad_B(self): method test_agent_with_A_learning_requires_pA (line 566) | def test_agent_with_A_learning_requires_pA(self): method test_agent_construction_jittable (line 581) | def test_agent_construction_jittable(self): method test_b_learning_updates_inductive_matrix (line 642) | def test_b_learning_updates_inductive_matrix(self): method test_valid_gradients_one_step_ahead (line 696) | def test_valid_gradients_one_step_ahead(self): method test_smoothing_ovf_updates_A_when_learn_B_false (line 837) | def test_smoothing_ovf_updates_A_when_learn_B_false(self): FILE: test/test_categorical_observations.py class TestCategoricalObservationsCore (line 13) | class TestCategoricalObservationsCore(unittest.TestCase): method test_uncertain_observation_inference (line 16) | def test_uncertain_observation_inference(self): method test_categorical_multimodality (line 44) | def test_categorical_multimodality(self): method test_multi_factor_categorical (line 66) | def test_multi_factor_categorical(self): class TestCategoricalObservationsEdgeCases (line 85) | class TestCategoricalObservationsEdgeCases(unittest.TestCase): method test_near_zero_probabilities (line 88) | def test_near_zero_probabilities(self): method test_very_peaked_distribution (line 107) | def test_very_peaked_distribution(self): method test_uniform_categorical_observation (line 126) | def test_uniform_categorical_observation(self): class TestCategoricalObservationsInferenceAlgorithms (line 145) | class TestCategoricalObservationsInferenceAlgorithms(unittest.TestCase): method setUp (line 148) | def setUp(self): method test_fpi_with_categorical_obs (line 168) | def test_fpi_with_categorical_obs(self): method test_fpi_factorized_with_categorical_obs (line 186) | def test_fpi_factorized_with_categorical_obs(self): method test_update_posterior_states_with_categorical (line 202) | def test_update_posterior_states_with_categorical(self): class TestCategoricalObservationsAgent (line 229) | class TestCategoricalObservationsAgent(unittest.TestCase): method test_agent_categorical_flag_false_discrete_obs (line 232) | def test_agent_categorical_flag_false_discrete_obs(self): method test_agent_categorical_flag_true_categorical_obs (line 255) | def test_agent_categorical_flag_true_categorical_obs(self): method test_agent_categorical_override (line 278) | def test_agent_categorical_override(self): method test_agent_preprocess_fn_default_and_warning (line 301) | def test_agent_preprocess_fn_default_and_warning(self): method test_agent_full_loop_categorical (line 340) | def test_agent_full_loop_categorical(self): class TestCategoricalObservationsControl (line 372) | class TestCategoricalObservationsControl(unittest.TestCase): method test_policy_inference_with_categorical_obs (line 375) | def test_policy_inference_with_categorical_obs(self): method test_info_gain_with_categorical_obs (line 398) | def test_info_gain_with_categorical_obs(self): method test_parameter_info_gain_with_categorical_obs (line 424) | def test_parameter_info_gain_with_categorical_obs(self): class TestCategoricalObservationsLearning (line 451) | class TestCategoricalObservationsLearning(unittest.TestCase): method test_learning_A_matrix_with_categorical (line 454) | def test_learning_A_matrix_with_categorical(self): method test_learning_with_uncertain_observations (line 494) | def test_learning_with_uncertain_observations(self): class TestCategoricalObservationsBatched (line 552) | class TestCategoricalObservationsBatched(unittest.TestCase): method test_batched_categorical_observations (line 555) | def test_batched_categorical_observations(self): FILE: test/test_control.py class TestControl (line 15) | class TestControl(unittest.TestCase): method test_get_expected_states (line 17) | def test_get_expected_states(self): method test_get_expected_states_interactions_single_factor (line 101) | def test_get_expected_states_interactions_single_factor(self): method test_get_expected_states_interactions_multi_factor (line 120) | def test_get_expected_states_interactions_multi_factor(self): method test_get_expected_states_interactions_multi_factor_independent (line 143) | def test_get_expected_states_interactions_multi_factor_independent(self): method test_get_expected_obs_factorized (line 166) | def test_get_expected_obs_factorized(self): method test_get_expected_states_and_obs (line 213) | def test_get_expected_states_and_obs(self): method test_expected_utility (line 320) | def test_expected_utility(self): method test_state_info_gain (line 390) | def test_state_info_gain(self): method test_state_info_gain_factorized (line 466) | def test_state_info_gain_factorized(self): method test_pA_info_gain (line 566) | def test_pA_info_gain(self): method test_pB_info_gain (line 618) | def test_pB_info_gain(self): method test_update_posterior_policies_utility (line 661) | def test_update_posterior_policies_utility(self): method test_temporal_C_matrix (line 809) | def test_temporal_C_matrix(self): method test_update_posterior_policies_states_infogain (line 964) | def test_update_posterior_policies_states_infogain(self): method test_update_posterior_policies_pA_infogain (line 1094) | def test_update_posterior_policies_pA_infogain(self): method test_update_posterior_policies_pB_infogain (line 1230) | def test_update_posterior_policies_pB_infogain(self): method test_update_posterior_policies_factorized (line 1363) | def test_update_posterior_policies_factorized(self): method test_sample_action (line 1400) | def test_sample_action(self): method test_sample_policy (line 1564) | def test_sample_policy(self): method test_update_posterior_policies_withE_vector (line 1586) | def test_update_posterior_policies_withE_vector(self): method test_stochastic_action_unidimensional_control (line 1627) | def test_stochastic_action_unidimensional_control(self): method test_deterministic_action_sampling_equal_value (line 1644) | def test_deterministic_action_sampling_equal_value(self): method test_deterministic_policy_selection_equal_value (line 1663) | def test_deterministic_policy_selection_equal_value(self): FILE: test/test_control_jax.py function generate_model_params (line 28) | def generate_model_params(): class TestControlJax (line 53) | class TestControlJax(unittest.TestCase): method test_get_expected_obs_factorized (line 55) | def test_get_expected_obs_factorized(self): method test_info_gain_factorized (line 76) | def test_info_gain_factorized(self): method test_update_posterior_policies_accepts_partial_param_posteriors (line 146) | def test_update_posterior_policies_accepts_partial_param_posteriors(se... method test_update_posterior_policies_requires_param_posterior_when_enabled (line 197) | def test_update_posterior_policies_requires_param_posterior_when_enabl... FILE: test/test_cue_chaining_env.py class TestCueChainingEnv (line 11) | class TestCueChainingEnv(unittest.TestCase): method test_shapes_and_dependencies (line 12) | def test_shapes_and_dependencies(self): method test_cue_and_reward_likelihood_semantics (line 34) | def test_cue_and_reward_likelihood_semantics(self): method test_env_params_broadcast (line 66) | def test_env_params_broadcast(self): method test_rollout_smoke (line 75) | def test_rollout_smoke(self): FILE: test/test_demos.py class TestDemos (line 15) | class TestDemos(unittest.TestCase): method test_agent_demo (line 17) | def test_agent_demo(self): method test_tmaze_demo (line 63) | def test_tmaze_demo(self): method test_tmaze_learning_demo (line 117) | def test_tmaze_learning_demo(self): method test_gridworld_genmodel_construction (line 187) | def test_gridworld_genmodel_construction(self): method test_gridworld_activeinference (line 253) | def test_gridworld_activeinference(self): FILE: test/test_distribution.py class TestDists (line 4) | class TestDists(unittest.TestCase): method test_distribution_slice (line 6) | def test_distribution_slice(self): method test_distribution_get_set (line 29) | def test_distribution_get_set(self): method test_agent_compile (line 74) | def test_agent_compile(self): method test_tensor_shape_change_protection (line 115) | def test_tensor_shape_change_protection(self): FILE: test/test_env.py function _stack_params (line 16) | def _stack_params(params_list): function _make_deterministic_params (line 20) | def _make_deterministic_params(toggle_action_one=True): function _make_stochastic_params (line 37) | def _make_stochastic_params(): class TestPymdpEnv (line 52) | class TestPymdpEnv(unittest.TestCase): method setUp (line 53) | def setUp(self): method test_reset_respects_state_override (line 57) | def test_reset_respects_state_override(self): method test_step_action_none_keeps_state (line 82) | def test_step_action_none_keeps_state(self): method test_env_params_override_defaults (line 107) | def test_env_params_override_defaults(self): method test_generate_env_params_batches (line 137) | def test_generate_env_params_batches(self): method test_vmap_over_keys_matches_manual (line 157) | def test_vmap_over_keys_matches_manual(self): method test_vmap_over_env_params_matches_manual (line 190) | def test_vmap_over_env_params_matches_manual(self): method test_vmap_over_state_action_and_keys (line 223) | def test_vmap_over_state_action_and_keys(self): method test_make_env_params_respects_input_batch (line 253) | def test_make_env_params_respects_input_batch(self): FILE: test/test_fpi.py class TestFPI (line 15) | class TestFPI(unittest.TestCase): method test_factorized_fpi_one_factor_one_modality (line 17) | def test_factorized_fpi_one_factor_one_modality(self): method test_factorized_fpi_one_factor_multi_modality (line 43) | def test_factorized_fpi_one_factor_multi_modality(self): method test_factorized_fpi_multi_factor_one_modality (line 68) | def test_factorized_fpi_multi_factor_one_modality(self): method test_factorized_fpi_multi_factor_multi_modality (line 93) | def test_factorized_fpi_multi_factor_multi_modality(self): method test_factorized_fpi_multi_factor_multi_modality_with_condind (line 126) | def test_factorized_fpi_multi_factor_multi_modality_with_condind(self): method test_factorized_fpi_multi_factor_single_modality_with_condind (line 162) | def test_factorized_fpi_multi_factor_single_modality_with_condind(self): FILE: test/test_grid_world_parity.py class TestGridWorldParity (line 17) | class TestGridWorldParity(unittest.TestCase): method test_jax_matches_legacy_transition_and_observation (line 20) | def test_jax_matches_legacy_transition_and_observation(self): method test_step_outputs_match (line 54) | def test_step_outputs_match(self): method test_batched_jax_env_matches_series_of_numpy_envs (line 106) | def test_batched_jax_env_matches_series_of_numpy_envs(self): FILE: test/test_hmm_associative_scan.py function _normalize_rows (line 23) | def _normalize_rows(x, axis=-1): function _random_simplex (line 28) | def _random_simplex(key, shape): function _normalize_cols (line 33) | def _normalize_cols(x, axis=0): function _random_col_stochastic (line 38) | def _random_col_stochastic(key, shape): function _sparse_absorbing_rowstoch (line 44) | def _sparse_absorbing_rowstoch(K): function _near_zero_rowstoch (line 56) | def _near_zero_rowstoch(K, tiny=1e-20): function _reference_filter (line 66) | def _reference_filter(initial_probs, transition_mats, log_likelihoods): function _reference_smoother (line 111) | def _reference_smoother(initial_probs, transition_mats, log_likelihoods): function _reference_filter_col (line 157) | def _reference_filter_col(initial_probs, B_mats, log_likelihoods): function _reference_smoother_col (line 200) | def _reference_smoother_col(initial_probs, B_mats, log_likelihoods): class TestAssociativeScanHMM (line 243) | class TestAssociativeScanHMM(unittest.TestCase): method _make_case (line 244) | def _make_case(self, key, T, K, time_varying=False): method _make_case_col (line 254) | def _make_case_col(self, key, T, K, time_varying=False): method _make_batch_row (line 264) | def _make_batch_row(self, key, batch_size, T, K, time_varying=False): method _make_batch_col (line 271) | def _make_batch_col(self, key, batch_size, T, K, time_varying=False): method _assert_close (line 278) | def _assert_close(self, a, b, atol=1e-5, rtol=1e-5): method _assert_all_finite_tree (line 281) | def _assert_all_finite_tree(self, tree): method test_filter_scan_matches_reference_stationary (line 285) | def test_filter_scan_matches_reference_stationary(self): method test_smoother_scan_matches_reference_time_varying (line 302) | def test_smoother_scan_matches_reference_time_varying(self): method test_smoother_scan_T1_edge (line 322) | def test_smoother_scan_T1_edge(self): method test_filter_scan_colstoch_matches_reference_stationary (line 342) | def test_filter_scan_colstoch_matches_reference_stationary(self): method test_smoother_scan_colstoch_matches_reference_time_varying (line 359) | def test_smoother_scan_colstoch_matches_reference_time_varying(self): method test_row_col_equivalence (line 379) | def test_row_col_equivalence(self): method test_smoother_scan_colstoch_stability_shifted_ll (line 402) | def test_smoother_scan_colstoch_stability_shifted_ll(self): method test_vmap_filter_scan_row (line 423) | def test_vmap_filter_scan_row(self): method test_vmap_smoother_scan_row (line 435) | def test_vmap_smoother_scan_row(self): method test_vmap_filter_scan_col (line 447) | def test_vmap_filter_scan_col(self): method test_vmap_smoother_scan_col (line 459) | def test_vmap_smoother_scan_col(self): method test_gradients_finite_rowstoch_smoother (line 472) | def test_gradients_finite_rowstoch_smoother(self): method test_gradients_finite_colstoch_smoother (line 506) | def test_gradients_finite_colstoch_smoother(self): method test_strict_zero_absorbing_transitions_finite_outputs_and_grads (line 542) | def test_strict_zero_absorbing_transitions_finite_outputs_and_grads(se... method test_near_zero_transitions_finite_outputs_and_grads (line 592) | def test_near_zero_transitions_finite_outputs_and_grads(self): method test_directional_derivative_rowstoch_filter_mll (line 642) | def test_directional_derivative_rowstoch_filter_mll(self): method test_update_posterior_states_exact_inference_matches_wrapper (line 679) | def test_update_posterior_states_exact_inference_matches_wrapper(self): method test_exact_inference_respects_inference_horizon (line 715) | def test_exact_inference_respects_inference_horizon(self): method test_exact_inference_loop_with_policy_inference_and_learning (line 757) | def test_exact_inference_loop_with_policy_inference_and_learning(self): method test_exact_inference_accepts_missing_past_actions (line 823) | def test_exact_inference_accepts_missing_past_actions(self): method test_exact_inference_rejects_multi_factor_models (line 855) | def test_exact_inference_rejects_multi_factor_models(self): FILE: test/test_inductive_inference_jax.py function _chain_transition (line 15) | def _chain_transition(num_states: int) -> jnp.ndarray: function _advance_or_stay_transition (line 24) | def _advance_or_stay_transition() -> jnp.ndarray: function _manual_chain_I (line 41) | def _manual_chain_I() -> list[jnp.ndarray]: class TestInductiveInferenceJax (line 53) | class TestInductiveInferenceJax(unittest.TestCase): method test_generate_I_matrix_rejects_nonpositive_depth (line 55) | def test_generate_I_matrix_rejects_nonpositive_depth(self): method test_generate_I_matrix_matches_chain_reachability (line 62) | def test_generate_I_matrix_matches_chain_reachability(self): method test_generate_I_matrix_respects_threshold_pruning (line 79) | def test_generate_I_matrix_respects_threshold_pruning(self): method test_generate_I_matrix_respects_depth_truncation (line 104) | def test_generate_I_matrix_respects_depth_truncation(self): method test_calc_inductive_value_on_path_zero_and_off_path_logeps (line 114) | def test_calc_inductive_value_on_path_zero_and_off_path_logeps(self): method test_calc_inductive_value_scales_with_off_path_mass (line 125) | def test_calc_inductive_value_scales_with_off_path_mass(self): method test_calc_inductive_value_is_zero_when_goal_unreachable (line 135) | def test_calc_inductive_value_is_zero_when_goal_unreachable(self): method test_calc_inductive_value_depends_on_map_current_state (line 153) | def test_calc_inductive_value_depends_on_map_current_state(self): method test_one_step_policy_ranking_prefers_goal_directed_action (line 168) | def test_one_step_policy_ranking_prefers_goal_directed_action(self): method test_compute_neg_efe_policy_inductive_matches_non_inductive_when_disabled (line 203) | def test_compute_neg_efe_policy_inductive_matches_non_inductive_when_d... method test_multistep_inductive_scoring_stays_anchored_to_qs_init (line 244) | def test_multistep_inductive_scoring_stays_anchored_to_qs_init(self): FILE: test/test_infer_states_optimized.py class TestInferStatesComparison (line 49) | class TestInferStatesComparison(unittest.TestCase): method setUpClass (line 60) | def setUpClass(cls): method should_skip_spec (line 100) | def should_skip_spec(cls, spec): method get_specs_subset (line 126) | def get_specs_subset(cls, max_specs=None, filter_fn=None): method _compare_results (line 146) | def _compare_results(self, r1, r2, m1, m2, spec): method _test_single_spec_with_batch (line 191) | def _test_single_spec_with_batch(self, spec, batch_size=4, A_sparsity_... method test_first_spec_with_batch (line 282) | def test_first_spec_with_batch(self): method test_small_subset_with_batch (line 286) | def test_small_subset_with_batch(self): method test_different_batch_sizes (line 308) | def test_different_batch_sizes(self): method test_low_complexity_specs_with_batch (line 316) | def test_low_complexity_specs_with_batch(self): method test_sparsity_with_batch (line 344) | def test_sparsity_with_batch(self): method test_all_agents_with_batch (line 360) | def test_all_agents_with_batch(self): FILE: test/test_inference.py class TestInference (line 15) | class TestInference(unittest.TestCase): method test_update_posterior_states (line 17) | def test_update_posterior_states(self): method test_update_posterior_states_factorized_single_factor (line 110) | def test_update_posterior_states_factorized_single_factor(self): method test_update_posterior_states_factorized (line 141) | def test_update_posterior_states_factorized(self): method test_update_posterior_states_factorized_noVFE_compute (line 175) | def test_update_posterior_states_factorized_noVFE_compute(self): FILE: test/test_inference_jax.py class TestInferenceJax (line 19) | class TestInferenceJax(unittest.TestCase): method test_fixed_point_iteration_singlestate_singleobs (line 21) | def test_fixed_point_iteration_singlestate_singleobs(self): method test_fixed_point_iteration_singlestate_multiobs (line 62) | def test_fixed_point_iteration_singlestate_multiobs(self): method test_fixed_point_iteration_multistate_singleobs (line 104) | def test_fixed_point_iteration_multistate_singleobs(self): method test_fixed_point_iteration_multistate_multiobs (line 146) | def test_fixed_point_iteration_multistate_multiobs(self): method test_fixed_point_iteration_index_observations (line 191) | def test_fixed_point_iteration_index_observations(self): FILE: test/test_jax_sparse_backend.py function make_model_configs (line 25) | def make_model_configs(source_seed=0, num_models=4) -> Dict: class TestJaxSparseOperations (line 77) | class TestJaxSparseOperations(unittest.TestCase): method test_sparse_smoothing (line 79) | def test_sparse_smoothing(self): method test_sparse_smoothing_with_invalid_actions (line 148) | def test_sparse_smoothing_with_invalid_actions(self): FILE: test/test_learning.py class TestLearning (line 8) | class TestLearning(unittest.TestCase): method test_update_pA_single_factor_all (line 10) | def test_update_pA_single_factor_all(self): method test_update_pA_single_factor_one_modality (line 47) | def test_update_pA_single_factor_one_modality(self): method test_update_pA_single_factor_some_modalities (line 78) | def test_update_pA_single_factor_some_modalities(self): method test_update_pA_multi_factor_all (line 107) | def test_update_pA_multi_factor_all(self): method test_update_pA_multi_factor_one_modality (line 140) | def test_update_pA_multi_factor_one_modality(self): method test_update_pA_multi_factor_some_modalities (line 167) | def test_update_pA_multi_factor_some_modalities(self): method test_update_pA_diff_observation_formats (line 194) | def test_update_pA_diff_observation_formats(self): method test_update_pA_factorized (line 253) | def test_update_pA_factorized(self): method test_update_pB_single_factor_no_actions (line 300) | def test_update_pB_single_factor_no_actions(self): method test_update_pB_single_factor_with_actions (line 328) | def test_update_pB_single_factor_with_actions(self): method test_update_pB_multi_factor_no_actions_all_factors (line 356) | def test_update_pB_multi_factor_no_actions_all_factors(self): method test_update_pB_multi_factor_no_actions_one_factor (line 388) | def test_update_pB_multi_factor_no_actions_one_factor(self): method test_update_pB_multi_factor_no_actions_some_factors (line 424) | def test_update_pB_multi_factor_no_actions_some_factors(self): method test_update_pB_multi_factor_with_actions_all_factors (line 460) | def test_update_pB_multi_factor_with_actions_all_factors(self): method test_update_pB_multi_factor_with_actions_one_factor (line 493) | def test_update_pB_multi_factor_with_actions_one_factor(self): method test_update_pB_multi_factor_with_actions_some_factors (line 529) | def test_update_pB_multi_factor_with_actions_some_factors(self): method test_update_pB_multi_factor_some_controllable_some_factors (line 565) | def test_update_pB_multi_factor_some_controllable_some_factors(self): method test_update_pB_interactions (line 601) | def test_update_pB_interactions(self): method test_update_pD (line 665) | def test_update_pD(self): method test_prune_prior (line 730) | def test_prune_prior(self): method test_prune_likelihoods (line 775) | def test_prune_likelihoods(self): FILE: test/test_learning_jax.py function _to_numpy_list_of_arrs (line 27) | def _to_numpy_list_of_arrs(jax_tree): class TestLearningJax (line 31) | class TestLearningJax(unittest.TestCase): method test_update_observation_likelihood_fullyconnected (line 33) | def test_update_observation_likelihood_fullyconnected(self): method test_update_observation_likelihood_factorized (line 88) | def test_update_observation_likelihood_factorized(self): method test_update_state_likelihood_single_factor_no_actions (line 142) | def test_update_state_likelihood_single_factor_no_actions(self): method test_update_state_likelihood_single_factor_with_actions (line 182) | def test_update_state_likelihood_single_factor_with_actions(self): method test_update_state_likelihood_multi_factor_all_factors_no_actions (line 222) | def test_update_state_likelihood_multi_factor_all_factors_no_actions(s... method test_update_state_likelihood_multi_factor_all_factors_with_actions (line 261) | def test_update_state_likelihood_multi_factor_all_factors_with_actions... method test_update_state_likelihood_multi_factor_some_factors_no_action (line 298) | def test_update_state_likelihood_multi_factor_some_factors_no_action(s... method test_update_state_likelihood_with_interactions (line 342) | def test_update_state_likelihood_with_interactions(self): method test_update_state_likelihood_single_factor_sequence_joints (line 389) | def test_update_state_likelihood_single_factor_sequence_joints(self): FILE: test/test_message_passing_jax.py function make_model_configs (line 27) | def make_model_configs(source_seed=0, num_models=3) -> Dict: class TestMessagePassing (line 59) | class TestMessagePassing(unittest.TestCase): method test_fixed_point_iteration (line 61) | def test_fixed_point_iteration(self): method test_fixed_point_iteration_factorized_fullyconnected (line 90) | def test_fixed_point_iteration_factorized_fullyconnected(self): method test_fixed_point_iteration_factorized_sparsegraph (line 120) | def test_fixed_point_iteration_factorized_sparsegraph(self): method test_marginal_message_passing (line 153) | def test_marginal_message_passing(self): method test_variational_message_passing_with_transition_dependencies (line 204) | def test_variational_message_passing_with_transition_dependencies(self): FILE: test/test_mmp.py class MMP (line 22) | class MMP(unittest.TestCase): method test_mmp_a (line 24) | def test_mmp_a(self): method test_mmp_b (line 64) | def test_mmp_b(self): method test_mmp_c (line 97) | def test_mmp_c(self): method test_mmp_d (line 131) | def test_mmp_d(self): FILE: test/test_param_info_gain_jax.py function test_exact_wnorm_finite (line 11) | def test_exact_wnorm_finite(scale): function test_exact_wnorm_mathematical_correctness (line 30) | def test_exact_wnorm_mathematical_correctness(): function test_calc_negative_pA_info_gain_precision (line 60) | def test_calc_negative_pA_info_gain_precision(seed): FILE: test/test_pybefit_model_fitting.py function _build_tmaze_agent_transform (line 19) | def _build_tmaze_agent_transform(task): function test_pybefit_tmaze_predictive_smoke (line 72) | def test_pybefit_tmaze_predictive_smoke(): FILE: test/test_rollout_function.py class TestRolloutFunction (line 28) | class TestRolloutFunction(unittest.TestCase): method setUp (line 29) | def setUp(self): method build_agent_env (line 37) | def build_agent_env( method manual_windowed_rollout_reference (line 105) | def manual_windowed_rollout_reference(self, agent, env, num_steps, rng... method test_rollout_collects_time_series (line 254) | def test_rollout_collects_time_series(self): method test_rollout_env_state_matches_manual_steps (line 279) | def test_rollout_env_state_matches_manual_steps(self): method test_online_learning_updates_A_during_scan (line 347) | def test_online_learning_updates_A_during_scan(self): method test_offline_learning_defers_A_update (line 364) | def test_offline_learning_defers_A_update(self): method test_online_learning_updates_B (line 380) | def test_online_learning_updates_B(self): method test_online_learning_updates_B_for_sequence_inference (line 399) | def test_online_learning_updates_B_for_sequence_inference(self): method test_online_learning_updates_for_smoothing_inference_with_horizon (line 426) | def test_online_learning_updates_for_smoothing_inference_with_horizon(... method test_online_learning_updates_A_only_for_sequence_inference (line 477) | def test_online_learning_updates_A_only_for_sequence_inference(self): method test_sequence_rollout_updates_empirical_prior_with_finite_horizon (line 506) | def test_sequence_rollout_updates_empirical_prior_with_finite_horizon(... method test_sequence_rollout_keeps_empirical_prior_fixed_during_warmup (line 527) | def test_sequence_rollout_keeps_empirical_prior_fixed_during_warmup(se... method test_rollout_caps_history_without_inference_horizon (line 551) | def test_rollout_caps_history_without_inference_horizon(self): method test_rollout_supports_multiple_inference_algorithms (line 570) | def test_rollout_supports_multiple_inference_algorithms(self): method test_rollout_categorical_obs_matches_discrete_semantics (line 593) | def test_rollout_categorical_obs_matches_discrete_semantics(self): method test_tmaze_rollout_supports_categorical_observations (line 642) | def test_tmaze_rollout_supports_categorical_observations(self): method test_sequence_rollout_supports_interacting_transition_dependencies (line 672) | def test_sequence_rollout_supports_interacting_transition_dependencies... method test_rollout_modes_for_ovf_and_exact (line 728) | def test_rollout_modes_for_ovf_and_exact(self): method _assert_rollout_matches_manual_reference (line 786) | def _assert_rollout_matches_manual_reference(self, algo, seed, include... method test_sequence_rollout_matches_manual_window_branch_reference (line 819) | def test_sequence_rollout_matches_manual_window_branch_reference(self): method test_smoothing_rollout_matches_manual_window_branch_reference (line 824) | def test_smoothing_rollout_matches_manual_window_branch_reference(self): method test_rollout_with_custom_policy_search_and_initial_carry (line 830) | def test_rollout_with_custom_policy_search_and_initial_carry(self): method test_offline_B_learning_matches_outer_products (line 874) | def test_offline_B_learning_matches_outer_products(self): FILE: test/test_sophisticated_inference_jax.py function _build_single_cue_model (line 17) | def _build_single_cue_model(): function _build_dual_cue_model (line 60) | def _build_dual_cue_model(): class TestSophisticatedInferenceJax (line 116) | class TestSophisticatedInferenceJax(unittest.TestCase): method _run_si_search (line 117) | def _run_si_search(self, agent, horizon): method _run_vanilla_search (line 132) | def _run_vanilla_search(self, agent): method test_si_accepts_costly_informative_cue (line 137) | def test_si_accepts_costly_informative_cue(self): method test_si_ignores_irrelevant_distractor_cue (line 200) | def test_si_ignores_irrelevant_distractor_cue(self): FILE: test/test_tmaze_envs.py class TestTMazeVariants (line 8) | class TestTMazeVariants(unittest.TestCase): method test_classic_shapes (line 9) | def test_classic_shapes(self): method test_simplified_shapes (line 24) | def test_simplified_shapes(self): method test_classic_cue_validity (line 38) | def test_classic_cue_validity(self): method test_simplified_cue_validity (line 48) | def test_simplified_cue_validity(self): method test_reward_outcomes_independent (line 58) | def test_reward_outcomes_independent(self): method test_reward_outcomes_dependent (line 72) | def test_reward_outcomes_dependent(self): method test_simplified_reward_and_punishment_probabilities (line 86) | def test_simplified_reward_and_punishment_probabilities(self): method test_classic_transition_connectivity (line 100) | def test_classic_transition_connectivity(self): method test_simplified_transition_connectivity (line 124) | def test_simplified_transition_connectivity(self): method test_render_accepts_singleton_discrete_and_categorical_observations (line 135) | def test_render_accepts_singleton_discrete_and_categorical_observation... FILE: test/test_tmaze_recoverability.py function _small_cfg (line 8) | def _small_cfg(parameterization: str) -> RecoverabilityConfig: function test_tmaze_recoverability_three_param_smoke (line 21) | def test_tmaze_recoverability_three_param_smoke(): function test_tmaze_recoverability_reward_only_smoke (line 32) | def test_tmaze_recoverability_reward_only_smoke(): FILE: test/test_utils.py class TestUtils (line 15) | class TestUtils(unittest.TestCase): method test_obj_array_from_list (line 16) | def test_obj_array_from_list(self): FILE: test/test_utils_jax.py class TestUtils (line 19) | class TestUtils(unittest.TestCase): method test_random_factorized_categorical (line 21) | def test_random_factorized_categorical(self): method test_random_A_array_shapes_and_normalization (line 39) | def test_random_A_array_shapes_and_normalization(self): method test_random_A_array_defaults_to_all_factors (line 63) | def test_random_A_array_defaults_to_all_factors(self): method test_random_B_array_shapes_and_normalization (line 80) | def test_random_B_array_shapes_and_normalization(self): method test_random_B_array_defaults_to_self_dependencies (line 109) | def test_random_B_array_defaults_to_self_dependencies(self): method test_random_B_array_accepts_out_of_order_control_dependency_lists (line 130) | def test_random_B_array_accepts_out_of_order_control_dependency_lists(... method test_norm_dist_list_version (line 150) | def test_norm_dist_list_version(self): method test_get_combination_index (line 164) | def test_get_combination_index(self): method test_index_to_combination (line 184) | def test_index_to_combination(self): method test_validate_normalization_ok (line 204) | def test_validate_normalization_ok(self): method test_validate_normalization_zero_filled_raises (line 213) | def test_validate_normalization_zero_filled_raises(self): method test_validate_normalization_not_normalised_raises (line 221) | def test_validate_normalization_not_normalised_raises(self): method test_validate_normalization_axis_argument (line 230) | def test_validate_normalization_axis_argument(self): method test_create_controllable_B_matches_legacy (line 242) | def test_create_controllable_B_matches_legacy(self): FILE: test/test_vfe_jax.py function _manual_dirichlet_kl (line 20) | def _manual_dirichlet_kl(q_dir: np.ndarray, p_dir: np.ndarray) -> float: function _bruteforce_sequence_vfe (line 34) | def _bruteforce_sequence_vfe( function _filtered_history_single_factor (line 57) | def _filtered_history_single_factor( function _factorized_fpi_vfe_history (line 86) | def _factorized_fpi_vfe_history( class TestCanonicalVFE (line 134) | class TestCanonicalVFE(unittest.TestCase): method test_calc_vfe_single_step_matches_legacy_full_model (line 135) | def test_calc_vfe_single_step_matches_legacy_full_model(self): method test_calc_vfe_sequence_matches_manual_terms_and_parameter_kls (line 184) | def test_calc_vfe_sequence_matches_manual_terms_and_parameter_kls(self): method test_calc_vfe_rejects_multifactor_1d_past_actions (line 261) | def test_calc_vfe_rejects_multifactor_1d_past_actions(self): method test_calc_vfe_accepts_multifactor_single_transition_action_history_as_2d (line 293) | def test_calc_vfe_accepts_multifactor_single_transition_action_history... method test_calc_vfe_rejects_mismatched_past_action_history_length_when_transitions_used (line 335) | def test_calc_vfe_rejects_mismatched_past_action_history_length_when_t... method test_calc_vfe_accepts_mismatched_past_action_history_when_no_transitions_used (line 364) | def test_calc_vfe_accepts_mismatched_past_action_history_when_no_trans... method test_update_posterior_states_rejects_multifactor_1d_past_actions (line 384) | def test_update_posterior_states_rejects_multifactor_1d_past_actions(s... method test_update_posterior_states_accepts_multifactor_single_transition_action_history_as_2d (line 416) | def test_update_posterior_states_accepts_multifactor_single_transition... method test_update_posterior_states_rejects_mismatched_past_action_history_length (line 459) | def test_update_posterior_states_rejects_mismatched_past_action_histor... method test_sequence_inference_uses_singleton_control_transitions_without_past_actions (line 488) | def test_sequence_inference_uses_singleton_control_transitions_without... method test_sequence_return_info_uses_singleton_control_transitions_without_past_actions (line 516) | def test_sequence_return_info_uses_singleton_control_transitions_witho... method test_return_info_requires_prior (line 556) | def test_return_info_requires_prior(self): method test_calc_vfe_gradients_are_finite (line 577) | def test_calc_vfe_gradients_are_finite(self): method test_calc_vfe_sequence_gradients_are_finite_without_joint_qs (line 595) | def test_calc_vfe_sequence_gradients_are_finite_without_joint_qs(self): method test_calc_vfe_sequence_is_jittable_without_joint_qs (line 642) | def test_calc_vfe_sequence_is_jittable_without_joint_qs(self): method test_calc_vfe_sequence_gradients_are_finite_with_joint_qs (line 675) | def test_calc_vfe_sequence_gradients_are_finite_with_joint_qs(self): method test_calc_vfe_sequence_is_jittable_with_joint_qs (line 732) | def test_calc_vfe_sequence_is_jittable_with_joint_qs(self): method test_factorized_fpi_vfe_decreases_monotonically (line 772) | def test_factorized_fpi_vfe_decreases_monotonically(self): method test_update_posterior_states_return_info_includes_vfe (line 815) | def test_update_posterior_states_return_info_includes_vfe(self): method test_agent_infer_states_return_info_batches_sequence_vfe (line 848) | def test_agent_infer_states_return_info_batches_sequence_vfe(self): method test_sequence_return_info_accepts_missing_past_actions (line 906) | def test_sequence_return_info_accepts_missing_past_actions(self): method test_sequence_inference_requires_past_actions_for_multiaction_transitions (line 937) | def test_sequence_inference_requires_past_actions_for_multiaction_tran... method test_calc_vfe_with_smoothed_exact_posterior_matches_bruteforce_sequence_vfe (line 967) | def test_calc_vfe_with_smoothed_exact_posterior_matches_bruteforce_seq... method test_calc_vfe_with_smoothed_ovf_posterior_matches_bruteforce_sequence_vfe (line 1006) | def test_calc_vfe_with_smoothed_ovf_posterior_matches_bruteforce_seque... FILE: test/test_wrappers.py class TestWrappers (line 4) | class TestWrappers(unittest.TestCase): method test_get_model_dimensions_from_labels (line 6) | def test_get_model_dimensions_from_labels(self):