SYMBOL INDEX (75 symbols across 6 files) FILE: src/eval_flow.py class NaiveMethodConfig (line 25) | class NaiveMethodConfig: class RealtimeMethodConfig (line 30) | class RealtimeMethodConfig: class BIDMethodConfig (line 36) | class BIDMethodConfig: class EvalConfig (line 42) | class EvalConfig: function eval (line 55) | def eval( function main (line 171) | def main( FILE: src/generate_data.py class Config (line 25) | class Config: class Data (line 54) | class Data: class StepCarry (line 64) | class StepCarry: function main (line 71) | def main(config: Config): FILE: src/model.py class ModelConfig (line 12) | class ModelConfig: function posemb_sincos (line 21) | def posemb_sincos(pos: jax.Array, embedding_dim: int, min_period: float,... function get_prefix_weights (line 40) | def get_prefix_weights(start: int, end: int, total: int, schedule: Prefi... class MLPMixerBlock (line 66) | class MLPMixerBlock(nnx.Module): method __init__ (line 67) | def __init__( method __call__ (line 79) | def __call__(self, x: jax.Array, adaln_cond: jax.Array) -> jax.Array: class FlowPolicy (line 103) | class FlowPolicy(nnx.Module): method __init__ (line 104) | def __init__( method __call__ (line 140) | def __call__(self, obs: jax.Array, x_t: jax.Array, time: jax.Array) ->... method action (line 160) | def action(self, rng: jax.Array, obs: jax.Array, num_steps: int) -> ja... method bid_action (line 173) | def bid_action( method realtime_action (line 219) | def realtime_action( method loss (line 267) | def loss(self, rng: jax.Array, obs: jax.Array, action: jax.Array): FILE: src/render_levels.py function load_levels (line 23) | def load_levels(paths): function main (line 38) | def main(): FILE: src/train_expert.py class Config (line 26) | class Config: class BatchEnvWrapper (line 78) | class BatchEnvWrapper(wrappers.GymnaxWrapper): method __init__ (line 81) | def __init__(self, env, num: int): method reset (line 85) | def reset(self, rng, params): method reset_to_level (line 88) | def reset_to_level(self, rng, level, params): method step (line 93) | def step(self, rng, state, action, params): class DenseRewardState (line 98) | class DenseRewardState: class DenseRewardWrapper (line 104) | class DenseRewardWrapper(wrappers.GymnaxWrapper): method __init__ (line 105) | def __init__(self, env): method step (line 108) | def step(self, key, state, action, params=None): method reset (line 116) | def reset(self, rng, params=None): method reset_to_level (line 120) | def reset_to_level(self, rng, level, params=None): class ActionHistoryWrapper (line 125) | class ActionHistoryWrapper(wrappers.UnderspecifiedEnvWrapper): method __init__ (line 126) | def __init__(self, env): method step_env (line 129) | def step_env(self, key, state, action, params): method reset_to_level (line 134) | def reset_to_level(self, rng, level, params): method action_space (line 139) | def action_space(self, params): class NoisyActionWrapper (line 143) | class NoisyActionWrapper(wrappers.UnderspecifiedEnvWrapper): method __init__ (line 144) | def __init__(self, env): method step_env (line 147) | def step_env(self, key, state, action, params): method reset_to_level (line 152) | def reset_to_level(self, rng, level, params): method action_space (line 155) | def action_space(self, params): class StickyActionState (line 160) | class StickyActionState: class StickyActionWrapper (line 165) | class StickyActionWrapper(wrappers.UnderspecifiedEnvWrapper): method __init__ (line 166) | def __init__(self, env, stickiness: float): method step_env (line 170) | def step_env(self, key, state, action, params): method reset_to_level (line 176) | def reset_to_level(self, rng, level, params): method action_space (line 186) | def action_space(self, params): class ObsHistoryState (line 191) | class ObsHistoryState: class ObsHistoryWrapper (line 197) | class ObsHistoryWrapper(wrappers.UnderspecifiedEnvWrapper): method __init__ (line 198) | def __init__(self, env, history_length: int): method step_env (line 202) | def step_env(self, key, state, action, params): method reset_to_level (line 207) | def reset_to_level(self, rng, level, params): method action_space (line 212) | def action_space(self, params): method get_original_obs (line 216) | def get_original_obs(env_state) -> jax.Array: function make_squashed_normal_diag (line 222) | def make_squashed_normal_diag(mean, std, num_motor_bindings: int): class Agent (line 232) | class Agent(nnx.Module): method __init__ (line 233) | def __init__(self, obs_dim: int, action_dim: int, layer_width: int, *,... method value (line 250) | def value(self, obs: jax.Array) -> jax.Array: method action (line 253) | def action(self, obs: jax.Array): class Transition (line 260) | class Transition: class StepCarry (line 272) | class StepCarry: class UpdateCarry (line 282) | class UpdateCarry: class TrainCarry (line 292) | class TrainCarry: function make_render_video (line 297) | def make_render_video(render_pixels): function load_levels (line 307) | def load_levels(paths: Sequence[str], static_env_params: kenv_state.Stat... function main (line 319) | def main(config: Config): FILE: src/train_flow.py class Config (line 32) | class Config: class EpochCarry (line 63) | class EpochCarry: function main (line 69) | def main(config: Config):