SYMBOL INDEX (148 symbols across 18 files) FILE: leanrl/dqn.py class Args (line 21) | class Args: function make_env (line 67) | def make_env(env_id, seed, idx, capture_video, run_name): class QNetwork (line 83) | class QNetwork(nn.Module): method __init__ (line 84) | def __init__(self, env): method forward (line 94) | def forward(self, x): function linear_schedule (line 98) | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): FILE: leanrl/dqn_jax.py class Args (line 23) | class Args: function make_env (line 65) | def make_env(env_id, seed, idx, capture_video, run_name): class QNetwork (line 81) | class QNetwork(nn.Module): method __call__ (line 85) | def __call__(self, x: jnp.ndarray): class TrainState (line 94) | class TrainState(TrainState): function linear_schedule (line 98) | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): function update (line 158) | def update(q_state, observations, actions, next_observations, rewards, d... FILE: leanrl/dqn_torchcompile.py class Args (line 24) | class Args: function make_env (line 75) | def make_env(env_id, seed, idx, capture_video, run_name): class QNetwork (line 91) | class QNetwork(nn.Module): method __init__ (line 92) | def __init__(self, n_obs, n_act, device=None): method forward (line 102) | def forward(self, x): function linear_schedule (line 106) | def linear_schedule(start_e: float, end_e: float, duration: int): function update (line 153) | def update(data): function policy (line 166) | def policy(obs, epsilon): FILE: leanrl/ppo_atari_envpool.py class Args (line 21) | class Args: class RecordEpisodeStatistics (line 81) | class RecordEpisodeStatistics(gym.Wrapper): method __init__ (line 82) | def __init__(self, env, deque_size=100): method reset (line 88) | def reset(self, **kwargs): method step (line 97) | def step(self, action): function layer_init (line 115) | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class Agent (line 121) | class Agent(nn.Module): method __init__ (line 122) | def __init__(self, envs): method get_value (line 138) | def get_value(self, x): method get_action_and_value (line 141) | def get_action_and_value(self, x, action=None): FILE: leanrl/ppo_atari_envpool_torchcompile.py class Args (line 38) | class Args: class RecordEpisodeStatistics (line 103) | class RecordEpisodeStatistics(gym.Wrapper): method __init__ (line 104) | def __init__(self, env, deque_size=100): method reset (line 110) | def reset(self, **kwargs): method step (line 119) | def step(self, action): function layer_init (line 137) | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class Agent (line 143) | class Agent(nn.Module): method __init__ (line 144) | def __init__(self, envs, device=None): method get_value (line 160) | def get_value(self, x): method get_action_and_value (line 163) | def get_action_and_value(self, obs, action=None): function gae (line 172) | def gae(next_obs, next_done, container): function rollout (line 198) | def rollout(obs, done, avg_returns=[]): function update (line 235) | def update(obs, actions, logprobs, advantages, returns, vals): FILE: leanrl/ppo_atari_envpool_xla_jax.py class Args (line 30) | class Args: class Network (line 89) | class Network(nn.Module): method __call__ (line 91) | def __call__(self, x): class Critic (line 127) | class Critic(nn.Module): method __call__ (line 129) | def __call__(self, x): class Actor (line 133) | class Actor(nn.Module): method __call__ (line 137) | def __call__(self, x): class AgentParams (line 142) | class AgentParams: class Storage (line 149) | class Storage: class EpisodeStatistics (line 161) | class EpisodeStatistics: function step_env_wrappeed (line 210) | def step_env_wrappeed(episode_stats, handle, action): function linear_schedule (line 229) | def linear_schedule(count): function get_action_and_value (line 270) | def get_action_and_value( function get_action_and_value2 (line 298) | def get_action_and_value2( function compute_gae (line 316) | def compute_gae( function update_ppo (line 341) | def update_ppo( function rollout (line 398) | def rollout(agent_state, episode_stats, next_obs, next_done, storage, ke... FILE: leanrl/ppo_continuous_action.py class Args (line 20) | class Args: function make_env (line 79) | def make_env(env_id, idx, capture_video, run_name, gamma): function layer_init (line 98) | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class Agent (line 104) | class Agent(nn.Module): method __init__ (line 105) | def __init__(self, envs): method get_value (line 123) | def get_value(self, x): method get_action_and_value (line 126) | def get_action_and_value(self, x, action=None): FILE: leanrl/ppo_continuous_action_torchcompile.py class Args (line 29) | class Args: function make_env (line 94) | def make_env(env_id, idx, capture_video, run_name, gamma): function layer_init (line 113) | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class Agent (line 119) | class Agent(nn.Module): method __init__ (line 120) | def __init__(self, n_obs, n_act, device=None): method get_value (line 138) | def get_value(self, x): method get_action_and_value (line 141) | def get_action_and_value(self, obs, action=None): function gae (line 151) | def gae(next_obs, next_done, container): function rollout (line 177) | def rollout(obs, done, avg_returns=[]): function update (line 213) | def update(obs, actions, logprobs, advantages, returns, vals): function step_func (line 298) | def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,... FILE: leanrl/sac_continuous_action.py class Args (line 21) | class Args: function make_env (line 64) | def make_env(env_id, seed, idx, capture_video, run_name): class SoftQNetwork (line 79) | class SoftQNetwork(nn.Module): method __init__ (line 80) | def __init__(self, env): method forward (line 86) | def forward(self, x, a): class Actor (line 98) | class Actor(nn.Module): method __init__ (line 99) | def __init__(self, env): method forward (line 113) | def forward(self, x): method get_action (line 123) | def get_action(self, x): FILE: leanrl/sac_continuous_action_torchcompile.py class Args (line 30) | class Args: function make_env (line 79) | def make_env(env_id, seed, idx, capture_video, run_name): class SoftQNetwork (line 94) | class SoftQNetwork(nn.Module): method __init__ (line 95) | def __init__(self, env, n_act, n_obs, device=None): method forward (line 101) | def forward(self, x, a): class Actor (line 113) | class Actor(nn.Module): method __init__ (line 114) | def __init__(self, env, n_obs, n_act, device=None): method forward (line 130) | def forward(self, x): method get_action (line 140) | def get_action(self, x): function get_q_params (line 188) | def get_q_params(): function batched_qf (line 217) | def batched_qf(params, obs, action, next_q_value=None): function update_main (line 225) | def update_main(data): function update_pol (line 247) | def update_pol(data): function extend_and_sample (line 267) | def extend_and_sample(transition): FILE: leanrl/td3_continuous_action.py class Args (line 21) | class Args: function make_env (line 62) | def make_env(env_id, seed, idx, capture_video, run_name): class QNetwork (line 77) | class QNetwork(nn.Module): method __init__ (line 78) | def __init__(self, env): method forward (line 84) | def forward(self, x, a): class Actor (line 92) | class Actor(nn.Module): method __init__ (line 93) | def __init__(self, env): method forward (line 106) | def forward(self, x): FILE: leanrl/td3_continuous_action_jax.py class Args (line 23) | class Args: function make_env (line 60) | def make_env(env_id, seed, idx, capture_video, run_name): class QNetwork (line 75) | class QNetwork(nn.Module): method __call__ (line 77) | def __call__(self, x: jnp.ndarray, a: jnp.ndarray): class Actor (line 87) | class Actor(nn.Module): method __call__ (line 93) | def __call__(self, x): class TrainState (line 104) | class TrainState(TrainState): function update_critic (line 178) | def update_critic( function update_actor (line 222) | def update_actor( FILE: leanrl/td3_continuous_action_torchcompile.py class Args (line 28) | class Args: function make_env (line 75) | def make_env(env_id, seed, idx, capture_video, run_name): class QNetwork (line 90) | class QNetwork(nn.Module): method __init__ (line 91) | def __init__(self, n_obs, n_act, device=None): method forward (line 97) | def forward(self, x, a): class Actor (line 105) | class Actor(nn.Module): method __init__ (line 106) | def __init__(self, n_obs, n_act, env, exploration_noise=1, device=None): method forward (line 122) | def forward(self, obs): method explore (line 128) | def explore(self, obs): function get_params_qnet (line 165) | def get_params_qnet(): function get_params_actor (line 178) | def get_params_actor(actor): function batched_qf (line 200) | def batched_qf(params, obs, action, next_q_value=None): function update_main (line 212) | def update_main(data): function update_pol (line 236) | def update_pol(data): function extend_and_sample (line 245) | def extend_and_sample(transition): FILE: tests/test_atari.py function test_ppo (line 4) | def test_ppo(): function test_ppo_envpool (line 12) | def test_ppo_envpool(): function test_ppo_atari_envpool_torchcompile (line 20) | def test_ppo_atari_envpool_torchcompile(): function test_ppo_atari_envpool_xla_jax (line 28) | def test_ppo_atari_envpool_xla_jax(): FILE: tests/test_dqn.py function test_dqn (line 4) | def test_dqn(): function test_dqn_jax (line 12) | def test_dqn_jax(): function test_dqn_torchcompile (line 20) | def test_dqn_torchcompile(): FILE: tests/test_ppo_continuous.py function test_ppo_continuous_action (line 4) | def test_ppo_continuous_action(): function test_ppo_continuous_action_torchcompile (line 12) | def test_ppo_continuous_action_torchcompile(): FILE: tests/test_sac_continuous.py function test_sac_continuous_action (line 4) | def test_sac_continuous_action(): function test_sac_continuous_action_torchcompile (line 12) | def test_sac_continuous_action_torchcompile(): FILE: tests/test_td3_continuous.py function test_td3_continuous_action (line 4) | def test_td3_continuous_action(): function test_td3_continuous_action_jax (line 12) | def test_td3_continuous_action_jax(): function test_td3_continuous_action_torchcompile (line 20) | def test_td3_continuous_action_torchcompile():