SYMBOL INDEX (127 symbols across 14 files) FILE: agents/fql.py class FQLAgent (line 15) | class FQLAgent(flax.struct.PyTreeNode): method critic_loss (line 22) | def critic_loss(self, batch, grad_params, rng): method actor_loss (line 46) | def actor_loss(self, batch, grad_params, rng): method total_loss (line 95) | def total_loss(self, batch, grad_params, rng=None): method target_update (line 113) | def target_update(self, network, module_name): method update (line 123) | def update(self, batch): method sample_actions (line 136) | def sample_actions( method compute_flow_actions (line 156) | def compute_flow_actions( method create (line 174) | def create( function get_config (line 249) | def get_config(): FILE: agents/ifql.py class IFQLAgent (line 15) | class IFQLAgent(flax.struct.PyTreeNode): method expectile_loss (line 26) | def expectile_loss(adv, diff, expectile): method value_loss (line 31) | def value_loss(self, batch, grad_params): method critic_loss (line 45) | def critic_loss(self, batch, grad_params): method actor_loss (line 60) | def actor_loss(self, batch, grad_params, rng=None): method total_loss (line 79) | def total_loss(self, batch, grad_params, rng=None): method target_update (line 100) | def target_update(self, network, module_name): method update (line 110) | def update(self, batch): method sample_actions (line 123) | def sample_actions( method create (line 158) | def create( function get_config (line 231) | def get_config(): FILE: agents/iql.py class IQLAgent (line 15) | class IQLAgent(flax.struct.PyTreeNode): method expectile_loss (line 23) | def expectile_loss(adv, diff, expectile): method value_loss (line 28) | def value_loss(self, batch, grad_params): method critic_loss (line 42) | def critic_loss(self, batch, grad_params): method actor_loss (line 57) | def actor_loss(self, batch, grad_params, rng=None): method total_loss (line 115) | def total_loss(self, batch, grad_params, rng=None): method target_update (line 136) | def target_update(self, network, module_name): method update (line 146) | def update(self, batch): method sample_actions (line 159) | def sample_actions( method create (line 172) | def create( function get_config (line 242) | def get_config(): FILE: agents/rebrac.py class ReBRACAgent (line 16) | class ReBRACAgent(flax.struct.PyTreeNode): method critic_loss (line 26) | def critic_loss(self, batch, grad_params, rng): method actor_loss (line 56) | def actor_loss(self, batch, grad_params, rng): method total_loss (line 89) | def total_loss(self, batch, grad_params, full_update=True, rng=None): method target_update (line 112) | def target_update(self, network, module_name): method update (line 122) | def update(self, batch, full_update=True): method sample_actions (line 138) | def sample_actions( method create (line 156) | def create( function get_config (line 222) | def get_config(): FILE: agents/sac.py class SACAgent (line 14) | class SACAgent(flax.struct.PyTreeNode): method critic_loss (line 24) | def critic_loss(self, batch, grad_params, rng): method actor_loss (line 53) | def actor_loss(self, batch, grad_params, rng): method total_loss (line 87) | def total_loss(self, batch, grad_params, rng=None): method target_update (line 105) | def target_update(self, network, module_name): method update (line 115) | def update(self, batch): method sample_actions (line 128) | def sample_actions( method create (line 141) | def create( function get_config (line 203) | def get_config(): FILE: envs/d4rl_utils.py function make_env (line 9) | def make_env(env_name): function get_dataset (line 16) | def get_dataset( FILE: envs/env_utils.py class EpisodeMonitor (line 13) | class EpisodeMonitor(gymnasium.Wrapper): method __init__ (line 16) | def __init__(self, env, filter_regexes=None): method _reset_stats (line 22) | def _reset_stats(self): method step (line 27) | def step(self, action): method reset (line 55) | def reset(self, *args, **kwargs): class FrameStackWrapper (line 60) | class FrameStackWrapper(gymnasium.Wrapper): method __init__ (line 63) | def __init__(self, env, num_stack): method get_observation (line 73) | def get_observation(self): method reset (line 77) | def reset(self, **kwargs): method step (line 85) | def step(self, action): function make_env_and_datasets (line 91) | def make_env_and_datasets(env_name, frame_stack=None, action_clip_eps=1e... FILE: main.py function main (line 49) | def main(_): FILE: utils/datasets.py function get_size (line 9) | def get_size(data): function random_crop (line 16) | def random_crop(img, crop_from, padding): function batched_random_crop (line 29) | def batched_random_crop(imgs, crop_froms, padding): class Dataset (line 34) | class Dataset(FrozenDict): method create (line 38) | def create(cls, freeze=True, **fields): method __init__ (line 51) | def __init__(self, *args, **kwargs): method get_random_idxs (line 62) | def get_random_idxs(self, num_idxs): method sample (line 66) | def sample(self, batch_size: int, idxs=None): method get_subset (line 92) | def get_subset(self, idxs): method augment (line 100) | def augment(self, batch, keys): class ReplayBuffer (line 113) | class ReplayBuffer(Dataset): method create (line 120) | def create(cls, transition, size): method create_from_initial_dataset (line 136) | def create_from_initial_dataset(cls, init_dataset, size): method __init__ (line 154) | def __init__(self, *args, **kwargs): method add_transition (line 161) | def add_transition(self, transition): method clear (line 171) | def clear(self): FILE: utils/encoders.py class ResnetStack (line 10) | class ResnetStack(nn.Module): method __call__ (line 18) | def __call__(self, x): class ImpalaEncoder (line 60) | class ImpalaEncoder(nn.Module): method setup (line 70) | def setup(self): method __call__ (line 83) | def __call__(self, x, train=True, cond_var=None): FILE: utils/evaluation.py function supply_rng (line 8) | def supply_rng(f, rng=jax.random.PRNGKey(0)): function flatten (line 19) | def flatten(d, parent_key='', sep='.'): function add_to (line 31) | def add_to(dict_of_lists, single_dict): function evaluate (line 37) | def evaluate( FILE: utils/flax_utils.py class ModuleDict (line 16) | class ModuleDict(nn.Module): method __call__ (line 28) | def __call__(self, *args, name=None, **kwargs): class TrainState (line 53) | class TrainState(flax.struct.PyTreeNode): method create (line 73) | def create(cls, model_def, params, tx=None, **kwargs): method __call__ (line 90) | def __call__(self, *args, params=None, method=None, **kwargs): method select (line 116) | def select(self, name): method apply_gradients (line 120) | def apply_gradients(self, grads, **kwargs): method apply_loss_fn (line 132) | def apply_loss_fn(self, loss_fn): function save_agent (line 162) | def save_agent(agent, save_dir, epoch): function restore_agent (line 181) | def restore_agent(agent, restore_path, restore_epoch): FILE: utils/log_utils.py class CsvLogger (line 12) | class CsvLogger: method __init__ (line 15) | def __init__(self, path): method log (line 21) | def log(self, row, step): method close (line 35) | def close(self): function get_exp_name (line 40) | def get_exp_name(seed): function get_flag_dict (line 53) | def get_flag_dict(): function setup_wandb (line 62) | def setup_wandb( function reshape_video (line 94) | def reshape_video(v, n_cols=None): function get_wandb_video (line 116) | def get_wandb_video(renders=None, n_cols=None, fps=15): FILE: utils/networks.py function default_init (line 8) | def default_init(scale=1.0): function ensemblize (line 13) | def ensemblize(cls, num_qs, in_axes=None, out_axes=0, **kwargs): class Identity (line 26) | class Identity(nn.Module): method __call__ (line 29) | def __call__(self, x): class MLP (line 33) | class MLP(nn.Module): method __call__ (line 51) | def __call__(self, x): class LogParam (line 63) | class LogParam(nn.Module): method __call__ (line 69) | def __call__(self): class TransformedWithMode (line 74) | class TransformedWithMode(distrax.Transformed): method mode (line 77) | def mode(self): class Actor (line 81) | class Actor(nn.Module): method setup (line 108) | def setup(self): method __call__ (line 117) | def __call__( class Value (line 152) | class Value(nn.Module): method setup (line 169) | def setup(self): method __call__ (line 177) | def __call__(self, observations, actions=None): class ActorVectorField (line 197) | class ActorVectorField(nn.Module): method setup (line 212) | def setup(self) -> None: method __call__ (line 216) | def __call__(self, observations, actions, times=None, is_encoded=False):