SYMBOL INDEX (837 symbols across 84 files) FILE: examples/aloha_real/convert_aloha_data_to_lerobot.py class DatasetConfig (line 23) | class DatasetConfig: function create_empty_dataset (line 34) | def create_empty_dataset( function get_cameras (line 128) | def get_cameras(hdf5_files: list[Path]) -> list[str]: function has_velocity (line 134) | def has_velocity(hdf5_files: list[Path]) -> bool: function has_effort (line 139) | def has_effort(hdf5_files: list[Path]) -> bool: function load_raw_images_per_camera (line 144) | def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dic... function load_raw_episode_data (line 165) | def load_raw_episode_data( function populate_dataset (line 193) | def populate_dataset( function port_aloha (line 229) | def port_aloha( FILE: examples/aloha_real/env.py class AlohaRealEnvironment (line 11) | class AlohaRealEnvironment(_environment.Environment): method __init__ (line 14) | def __init__( method reset (line 27) | def reset(self) -> None: method is_episode_complete (line 31) | def is_episode_complete(self) -> bool: method get_observation (line 35) | def get_observation(self) -> dict: method apply_action (line 56) | def apply_action(self, action: dict) -> None: FILE: examples/aloha_real/main.py class Args (line 14) | class Args: function main (line 24) | def main(args: Args) -> None: FILE: examples/aloha_real/real_env.py class RealEnv (line 18) | class RealEnv: method __init__ (line 40) | def __init__(self, init_node, *, reset_position: Optional[List[float]]... method setup_robots (line 62) | def setup_robots(self): method get_qpos (line 66) | def get_qpos(self): method get_qvel (line 79) | def get_qvel(self): method get_effort (line 88) | def get_effort(self): method get_images (line 95) | def get_images(self): method set_gripper_pose (line 98) | def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_... method _reset_joints (line 109) | def _reset_joints(self): method _reset_gripper (line 114) | def _reset_gripper(self): method get_observation (line 128) | def get_observation(self): method get_reward (line 136) | def get_reward(self): method reset (line 139) | def reset(self, *, fake=False): method step (line 150) | def step(self, action): function get_action (line 163) | def get_action(master_bot_left, master_bot_right): function make_real_env (line 175) | def make_real_env(init_node, *, reset_position: Optional[List[float]] = ... FILE: examples/aloha_real/robot_utils.py class ImageRecorder (line 19) | class ImageRecorder: method __init__ (line 20) | def __init__(self, init_node=True, is_debug=False): method image_cb (line 48) | def image_cb(self, cam_name, data): method image_cb_cam_high (line 72) | def image_cb_cam_high(self, data): method image_cb_cam_low (line 76) | def image_cb_cam_low(self, data): method image_cb_cam_left_wrist (line 80) | def image_cb_cam_left_wrist(self, data): method image_cb_cam_right_wrist (line 84) | def image_cb_cam_right_wrist(self, data): method get_images (line 88) | def get_images(self): method print_diagnostics (line 100) | def print_diagnostics(self): class Recorder (line 112) | class Recorder: method __init__ (line 113) | def __init__(self, side, init_node=True, is_debug=False): method puppet_state_cb (line 141) | def puppet_state_cb(self, data): method puppet_arm_commands_cb (line 149) | def puppet_arm_commands_cb(self, data): method puppet_gripper_commands_cb (line 154) | def puppet_gripper_commands_cb(self, data): method print_diagnostics (line 159) | def print_diagnostics(self): function get_arm_joint_positions (line 172) | def get_arm_joint_positions(bot): function get_arm_gripper_positions (line 176) | def get_arm_gripper_positions(bot): function move_arms (line 180) | def move_arms(bot_list, target_pose_list, move_time=1): function move_grippers (line 193) | def move_grippers(bot_list, target_pose_list, move_time): function setup_puppet_bot (line 214) | def setup_puppet_bot(bot): function setup_master_bot (line 221) | def setup_master_bot(bot): function set_standard_pid_gains (line 227) | def set_standard_pid_gains(bot): function set_low_pid_gains (line 232) | def set_low_pid_gains(bot): function torque_off (line 237) | def torque_off(bot): function torque_on (line 242) | def torque_on(bot): function sync_puppet_to_master (line 248) | def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_... FILE: examples/aloha_real/video_display.py class VideoDisplay (line 7) | class VideoDisplay(_subscriber.Subscriber): method __init__ (line 10) | def __init__(self) -> None: method on_episode_start (line 15) | def on_episode_start(self) -> None: method on_step (line 21) | def on_step(self, observation: dict, action: dict) -> None: method on_episode_end (line 34) | def on_episode_end(self) -> None: FILE: examples/aloha_sim/env.py class AlohaSimEnvironment (line 9) | class AlohaSimEnvironment(_environment.Environment): method __init__ (line 12) | def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed... method reset (line 23) | def reset(self) -> None: method is_episode_complete (line 30) | def is_episode_complete(self) -> bool: method get_observation (line 34) | def get_observation(self) -> dict: method apply_action (line 41) | def apply_action(self, action: dict) -> None: method _convert_observation (line 47) | def _convert_observation(self, gym_obs: dict) -> dict: FILE: examples/aloha_sim/main.py class Args (line 15) | class Args: function main (line 29) | def main(args: Args) -> None: FILE: examples/aloha_sim/saver.py class VideoSaver (line 10) | class VideoSaver(_subscriber.Subscriber): method __init__ (line 13) | def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None: method on_episode_start (line 20) | def on_episode_start(self) -> None: method on_step (line 24) | def on_step(self, observation: dict, action: dict) -> None: method on_episode_end (line 30) | def on_episode_end(self) -> None: FILE: examples/convert_jax_model_to_pytorch.py function slice_paligemma_state_dict (line 50) | def slice_paligemma_state_dict(state_dict, config): function slice_gemma_state_dict (line 271) | def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint... function slice_initial_orbax_checkpoint (line 396) | def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precisio... function load_jax_model_and_print_keys (line 408) | def load_jax_model_and_print_keys(checkpoint_dir: str): function convert_pi0_checkpoint (line 422) | def convert_pi0_checkpoint( function main (line 558) | def main( FILE: examples/droid/convert_droid_data_to_lerobot.py function resize_image (line 32) | def resize_image(image, size): function main (line 37) | def main(data_dir: str, *, push_to_hub: bool = False): function get_camera_type (line 180) | def get_camera_type(cam_id): class MP4Reader (line 187) | class MP4Reader: method __init__ (line 188) | def __init__(self, filepath, serial_number): method set_reading_parameters (line 198) | def set_reading_parameters( method get_frame_resolution (line 214) | def get_frame_resolution(self): method get_frame_count (line 219) | def get_frame_count(self): method set_frame_index (line 224) | def set_frame_index(self, index): method _process_frame (line 235) | def _process_frame(self, frame): method read_camera (line 241) | def read_camera(self, ignore_data=False, correct_timestamp=None): # n... method disable_camera (line 269) | def disable_camera(self): class RecordedMultiCameraWrapper (line 274) | class RecordedMultiCameraWrapper: method __init__ (line 275) | def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006 method read_cameras (line 296) | def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict... function get_hdf5_length (line 329) | def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006 function load_hdf5_to_dict (line 351) | def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006 class TrajectoryReader (line 369) | class TrajectoryReader: method __init__ (line 370) | def __init__(self, filepath, read_images=True): # noqa: FBT002 method length (line 378) | def length(self): method read_timestep (line 381) | def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006 method close (line 400) | def close(self): function load_trajectory (line 404) | def load_trajectory( FILE: examples/droid/main.py class Args (line 27) | class Args: function prevent_keyboard_interrupt (line 55) | def prevent_keyboard_interrupt(): function main (line 73) | def main(args: Args): function _extract_observation (line 198) | def _extract_observation(args: Args, obs_dict, *, save_to_disk=False): FILE: examples/libero/convert_libero_data_to_lerobot.py function main (line 37) | def main(data_dir: str, *, push_to_hub: bool = False): FILE: examples/libero/main.py class Args (line 22) | class Args: function eval_libero (line 48) | def eval_libero(args: Args) -> None: function _get_libero_env (line 189) | def _get_libero_env(task, resolution, seed): function _quat2axisangle (line 199) | def _quat2axisangle(quat): FILE: examples/simple_client/main.py class EnvMode (line 17) | class EnvMode(enum.Enum): class Args (line 27) | class Args: class TimingRecorder (line 44) | class TimingRecorder: method __init__ (line 47) | def __init__(self) -> None: method record (line 50) | def record(self, key: str, time_ms: float) -> None: method get_stats (line 56) | def get_stats(self, key: str) -> dict[str, float]: method print_all_stats (line 70) | def print_all_stats(self) -> None: method write_parquet (line 109) | def write_parquet(self, path: pathlib.Path) -> None: function main (line 117) | def main(args: Args) -> None: function _random_observation_aloha (line 153) | def _random_observation_aloha() -> dict: function _random_observation_droid (line 166) | def _random_observation_droid() -> dict: function _random_observation_libero (line 176) | def _random_observation_libero() -> dict: FILE: packages/openpi-client/src/openpi_client/action_chunk_broker.py class ActionChunkBroker (line 10) | class ActionChunkBroker(_base_policy.BasePolicy): method __init__ (line 19) | def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int): method infer (line 27) | def infer(self, obs: Dict) -> Dict: # noqa: UP006 method reset (line 47) | def reset(self) -> None: FILE: packages/openpi-client/src/openpi_client/base_policy.py class BasePolicy (line 5) | class BasePolicy(abc.ABC): method infer (line 7) | def infer(self, obs: Dict) -> Dict: method reset (line 10) | def reset(self) -> None: FILE: packages/openpi-client/src/openpi_client/image_tools.py function convert_to_uint8 (line 5) | def convert_to_uint8(img: np.ndarray) -> np.ndarray: function resize_with_pad (line 15) | def resize_with_pad(images: np.ndarray, height: int, width: int, method=... function _resize_with_pad_pil (line 38) | def _resize_with_pad_pil(image: Image.Image, height: int, width: int, me... FILE: packages/openpi-client/src/openpi_client/image_tools_test.py function test_resize_with_pad_shapes (line 6) | def test_resize_with_pad_shapes(): FILE: packages/openpi-client/src/openpi_client/msgpack_numpy.py function pack_array (line 21) | def pack_array(obj): function unpack_array (line 43) | def unpack_array(obj): FILE: packages/openpi-client/src/openpi_client/msgpack_numpy_test.py function _check (line 8) | def _check(expected, actual): function test_pack_unpack (line 42) | def test_pack_unpack(data): FILE: packages/openpi-client/src/openpi_client/runtime/agent.py class Agent (line 4) | class Agent(abc.ABC): method get_action (line 12) | def get_action(self, observation: dict) -> dict: method reset (line 16) | def reset(self) -> None: FILE: packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py class PolicyAgent (line 7) | class PolicyAgent(_agent.Agent): method __init__ (line 10) | def __init__(self, policy: _base_policy.BasePolicy) -> None: method get_action (line 14) | def get_action(self, observation: dict) -> dict: method reset (line 17) | def reset(self) -> None: FILE: packages/openpi-client/src/openpi_client/runtime/environment.py class Environment (line 4) | class Environment(abc.ABC): method reset (line 12) | def reset(self) -> None: method is_episode_complete (line 19) | def is_episode_complete(self) -> bool: method get_observation (line 27) | def get_observation(self) -> dict: method apply_action (line 31) | def apply_action(self, action: dict) -> None: FILE: packages/openpi-client/src/openpi_client/runtime/runtime.py class Runtime (line 10) | class Runtime: method __init__ (line 13) | def __init__( method run (line 32) | def run(self) -> None: method run_in_new_thread (line 40) | def run_in_new_thread(self) -> threading.Thread: method mark_episode_complete (line 46) | def mark_episode_complete(self) -> None: method _run_episode (line 50) | def _run_episode(self) -> None: method _step (line 80) | def _step(self) -> None: FILE: packages/openpi-client/src/openpi_client/runtime/subscriber.py class Subscriber (line 4) | class Subscriber(abc.ABC): method on_episode_start (line 11) | def on_episode_start(self) -> None: method on_step (line 15) | def on_step(self, observation: dict, action: dict) -> None: method on_episode_end (line 19) | def on_episode_end(self) -> None: FILE: packages/openpi-client/src/openpi_client/websocket_client_policy.py class WebsocketClientPolicy (line 12) | class WebsocketClientPolicy(_base_policy.BasePolicy): method __init__ (line 18) | def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, ... method get_server_metadata (line 29) | def get_server_metadata(self) -> Dict: method _wait_for_server (line 32) | def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConne... method infer (line 47) | def infer(self, obs: Dict) -> Dict: # noqa: UP006 method reset (line 57) | def reset(self) -> None: FILE: scripts/compute_norm_stats.py class RemoveStrings (line 19) | class RemoveStrings(transforms.DataTransformFn): method __call__ (line 20) | def __call__(self, x: dict) -> dict: function create_torch_dataloader (line 24) | def create_torch_dataloader( function create_rlds_dataloader (line 60) | def create_rlds_dataloader( function main (line 89) | def main(config_name: str, max_frames: int | None = None): FILE: scripts/serve_policy.py class EnvMode (line 14) | class EnvMode(enum.Enum): class Checkpoint (line 24) | class Checkpoint: class Default (line 34) | class Default: class Args (line 39) | class Args: function create_default_policy (line 79) | def create_default_policy(env: EnvMode, *, default_prompt: str | None = ... function create_policy (line 88) | def create_policy(args: Args) -> _policy.Policy: function main (line 99) | def main(args: Args) -> None: FILE: scripts/train.py function init_logging (line 31) | def init_logging(): function init_wandb (line 50) | def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code:... function _load_weights_and_validate (line 73) | def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, par... function init_train_state (line 85) | def init_train_state( function train_step (line 137) | def train_step( function main (line 194) | def main(config: _config.TrainConfig): FILE: scripts/train_pytorch.py function init_logging (line 50) | def init_logging(): function init_wandb (line 72) | def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: ... function setup_ddp (line 94) | def setup_ddp(): function cleanup_ddp (line 112) | def cleanup_ddp(): function set_seed (line 118) | def set_seed(seed: int, local_rank: int): function build_datasets (line 125) | def build_datasets(config: _config.TrainConfig): function get_model_state_dict (line 131) | def get_model_state_dict(model): function get_model_parameters (line 140) | def get_model_parameters(model): function save_checkpoint (line 149) | def save_checkpoint(model, optimizer, global_step, config, is_main, data... function load_checkpoint (line 197) | def load_checkpoint(model, optimizer, checkpoint_dir, device): function get_latest_checkpoint_step (line 274) | def get_latest_checkpoint_step(checkpoint_dir): function log_memory_usage (line 284) | def log_memory_usage(device, step, phase="unknown"): function train_loop (line 309) | def train_loop(config: _config.TrainConfig): function main (line 625) | def main(): FILE: scripts/train_test.py function test_train (line 15) | def test_train(tmp_path: pathlib.Path, config_name: str): FILE: src/openpi/conftest.py function set_jax_cpu_backend_if_no_gpu (line 7) | def set_jax_cpu_backend_if_no_gpu() -> None: function pytest_configure (line 16) | def pytest_configure(config: pytest.Config) -> None: FILE: src/openpi/models/gemma.py class Config (line 45) | class Config: function get_config (line 58) | def get_config(variant: Variant) -> Config: class RMSNorm (line 113) | class RMSNorm(nn.Module): method __call__ (line 115) | def __call__(self, x, cond): class Embedder (line 135) | class Embedder(nn.Module): method setup (line 141) | def setup(self): method encode (line 148) | def encode(self, x): method decode (line 153) | def decode(self, x): class Attention (line 158) | class Attention(nn.Module): method __call__ (line 164) | def __call__(self, xs, positions, attn_mask, kv_cache): class FeedForward (line 253) | class FeedForward(nn.Module): method __call__ (line 260) | def __call__(self, x): class Block (line 284) | class Block(nn.Module): method __call__ (line 293) | def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, de... class Module (line 340) | class Module(nn.Module): method setup (line 350) | def setup(self): method embed (line 385) | def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array,... method __call__ (line 389) | def __call__( method init (line 413) | def init(self, use_adarms: Sequence[bool]): function _apply_rope (line 424) | def _apply_rope(x, *, positions, max_wavelength=10_000): function _name (line 443) | def _name(name, i): function _gated_residual (line 453) | def _gated_residual(x, y, gate): FILE: src/openpi/models/gemma_fast.py function get_config (line 35) | def get_config(variant): class Einsum (line 77) | class Einsum(nn.Module): method __call__ (line 81) | def __call__(self, eqn, x): class RMSNorm (line 88) | class RMSNorm(nn.Module): method __call__ (line 90) | def __call__(self, x): class Embedder (line 102) | class Embedder(nn.Module): method setup (line 108) | def setup(self): method encode (line 115) | def encode(self, x): method decode (line 120) | def decode(self, x): class Attention (line 125) | class Attention(nn.Module): method setup (line 137) | def setup(self): method _init_cache (line 165) | def _init_cache(self, k, v, cache_size): method _update_cache (line 175) | def _update_cache(self, k, v, idx, k_cache, v_cache): method __call__ (line 186) | def __call__(self, x, positions, attn_mask, kv_cache, decode, determin... class Block (line 228) | class Block(nn.Module): method setup (line 242) | def setup(self): method __call__ (line 261) | def __call__(self, x, kv_cache, positions, attn_mask, decode, determin... class Module (line 279) | class Module(nn.Module): method __call__ (line 303) | def __call__( method init (line 420) | def init(self): function _apply_rope (line 425) | def _apply_rope(x, *, positions, max_wavelength=10_000): FILE: src/openpi/models/lora.py class LoRAConfig (line 12) | class LoRAConfig: method scaling_value (line 29) | def scaling_value(self) -> float: class Einsum (line 33) | class Einsum(nn.Module): method setup (line 43) | def setup(self): method __call__ (line 55) | def __call__(self, eqn: str, x): method _make_lora_eqns (line 67) | def _make_lora_eqns(self, eqn: str) -> tuple[str, str]: class FeedForward (line 88) | class FeedForward(nn.Module): method setup (line 96) | def setup(self): method __call__ (line 124) | def __call__(self, x): method _dot (line 144) | def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array,... FILE: src/openpi/models/lora_test.py function test_lora_einsum_params_shape (line 8) | def test_lora_einsum_params_shape(): function test_lora_einsum_same_output (line 34) | def test_lora_einsum_same_output(): function test_lora_ffn_params_shape (line 53) | def test_lora_ffn_params_shape(): function test_lora_ffn_same_output (line 77) | def test_lora_ffn_same_output(): FILE: src/openpi/models/model.py class ModelType (line 30) | class ModelType(enum.Enum): class Observation (line 83) | class Observation(Generic[ArrayT]): method from_dict (line 110) | def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]": method to_dict (line 131) | def to_dict(self) -> at.PyTree[ArrayT]: function preprocess_observation (line 144) | def preprocess_observation( class BaseModelConfig (line 212) | class BaseModelConfig(abc.ABC): method model_type (line 226) | def model_type(self) -> ModelType: method create (line 230) | def create(self, rng: at.KeyArrayLike) -> "BaseModel": method load (line 233) | def load(self, params: at.Params, *, remove_extra_params: bool = True)... method load_pytorch (line 243) | def load_pytorch(self, train_config, weight_path: str): method inputs_spec (line 250) | def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Ac... method fake_obs (line 253) | def fake_obs(self, batch_size: int = 1) -> Observation: method fake_act (line 257) | def fake_act(self, batch_size: int = 1) -> Actions: class BaseModel (line 263) | class BaseModel(nnx.Module, abc.ABC): method compute_loss (line 273) | def compute_loss( method sample_actions (line 283) | def sample_actions(self, rng: at.KeyArrayLike, observation: Observatio... function restore_params (line 286) | def restore_params( FILE: src/openpi/models/model_test.py function test_pi0_model (line 12) | def test_pi0_model(): function test_pi0_lora_model (line 27) | def test_pi0_lora_model(): function test_pi0_fast_model (line 42) | def test_pi0_fast_model(): function test_pi0_fast_lora_model (line 57) | def test_pi0_fast_lora_model(): function test_model_restore (line 79) | def test_model_restore(): FILE: src/openpi/models/pi0.py function make_attn_mask (line 19) | def make_attn_mask(input_mask, mask_ar): function posemb_sincos (line 48) | def posemb_sincos( class Pi0 (line 66) | class Pi0(_model.BaseModel): method __init__ (line 67) | def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs): method embed_prefix (line 106) | def embed_prefix( method embed_suffix (line 140) | def embed_suffix( method compute_loss (line 189) | def compute_loss( method sample_actions (line 217) | def sample_actions( FILE: src/openpi/models/pi0_config.py class Pi0Config (line 19) | class Pi0Config(_model.BaseModelConfig): method __post_init__ (line 37) | def __post_init__(self): method model_type (line 52) | def model_type(self) -> _model.ModelType: method create (line 58) | def create(self, rng: at.KeyArrayLike) -> "Pi0": method inputs_spec (line 64) | def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observat... method get_freeze_filter (line 88) | def get_freeze_filter(self) -> nnx.filterlib.Filter: FILE: src/openpi/models/pi0_fast.py function make_attn_mask (line 23) | def make_attn_mask(input_mask, mask_ar): function left_to_right_align (line 52) | def left_to_right_align(x, input_mask, attn_mask): function put_along_last_axis (line 67) | def put_along_last_axis(arr, indices, values): class Pi0FASTConfig (line 77) | class Pi0FASTConfig(_model.BaseModelConfig): method model_type (line 93) | def model_type(self) -> _model.ModelType: method create (line 97) | def create(self, rng: at.KeyArrayLike) -> "Pi0FAST": method inputs_spec (line 101) | def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observat... method get_freeze_filter (line 127) | def get_freeze_filter(self) -> nnx.filterlib.Filter: class Pi0FAST (line 134) | class Pi0FAST(_model.BaseModel): method __init__ (line 135) | def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs): method embed_inputs (line 160) | def embed_inputs( method compute_loss (line 198) | def compute_loss( method sample_actions (line 236) | def sample_actions( FILE: src/openpi/models/pi0_test.py function _get_frozen_state (line 7) | def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State: function test_pi0_full_finetune (line 14) | def test_pi0_full_finetune(): function test_pi0_gemma_lora (line 20) | def test_pi0_gemma_lora(): function test_pi0_action_expert_lora (line 29) | def test_pi0_action_expert_lora(): function test_pi0_all_lora (line 40) | def test_pi0_all_lora(): FILE: src/openpi/models/siglip.py function posemb_sincos_2d (line 27) | def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32): function get_posemb (line 40) | def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32): class MlpBlock (line 53) | class MlpBlock(nn.Module): method __call__ (line 61) | def __call__(self, x, deterministic=True): # noqa: FBT002 class Encoder1DBlock (line 75) | class Encoder1DBlock(nn.Module): method __call__ (line 84) | def __call__(self, x, deterministic=True): # noqa: FBT002 class Encoder (line 111) | class Encoder(nn.Module): method __call__ (line 123) | def __call__(self, x, deterministic=True): # noqa: FBT002 class MAPHead (line 164) | class MAPHead(nn.Module): method __call__ (line 172) | def __call__(self, x): class _Module (line 188) | class _Module(nn.Module): method __call__ (line 208) | def __call__(self, image, *, train=False): function Module (line 293) | def Module(num_classes=None, *, variant=None, **kw): # pylint: disable=... function decode_variant (line 298) | def decode_variant(variant): FILE: src/openpi/models/tokenizer.py class PaligemmaTokenizer (line 14) | class PaligemmaTokenizer: method __init__ (line 15) | def __init__(self, max_len: int = 48): method tokenize (line 22) | def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tu... class FASTTokenizer (line 51) | class FASTTokenizer: method __init__ (line 52) | def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "phy... method tokenize (line 64) | def tokenize( method extract_actions (line 119) | def extract_actions(self, tokens: np.ndarray, action_horizon: int, act... method _act_tokens_to_paligemma_tokens (line 136) | def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[in... class BinningTokenizer (line 148) | class BinningTokenizer: method __init__ (line 153) | def __init__(self, max_len: int = 256, n_bins: int = 256): method tokenize (line 164) | def tokenize( method extract_actions (line 222) | def extract_actions(self, tokens: np.ndarray, action_horizon: int, act... method _act_tokens_to_paligemma_tokens (line 240) | def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[in... class FSQTokenizer (line 246) | class FSQTokenizer: method __init__ (line 251) | def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None ... method tokenize (line 300) | def tokenize( method extract_actions (line 345) | def extract_actions(self, tokens: np.ndarray, action_horizon: int, act... method _act_tokens_to_paligemma_tokens (line 368) | def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[in... FILE: src/openpi/models/tokenizer_test.py function test_tokenize (line 6) | def test_tokenize(): function test_fast_tokenizer (line 14) | def test_fast_tokenizer(): FILE: src/openpi/models/utils/fsq_tokenizer.py class FsqCodebook (line 15) | class FsqCodebook(nn.Module): method bins_per_dim (line 23) | def bins_per_dim(self) -> tuple[int]: method place_values (line 37) | def place_values(self) -> jnp.ndarray: method _get_bins_fsq (line 44) | def _get_bins_fsq(target_codebook_size: int) -> tuple[int]: method _get_bins_custom (line 62) | def _get_bins_custom(target_codebook_size: int) -> tuple[int]: method _get_bins_lfq (line 76) | def _get_bins_lfq(target_codebook_size: int) -> tuple[int]: method setup (line 84) | def setup(self): method __call__ (line 88) | def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndar... method encode (line 93) | def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: method decode (line 105) | def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None)... method undigitize (line 117) | def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray: method digitize (line 120) | def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray: method vocab_size (line 124) | def vocab_size(self) -> int: class ResNetDownBlock (line 128) | class ResNetDownBlock(nn.Module): method __call__ (line 135) | def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray: class ResNetUpBlock (line 150) | class ResNetUpBlock(nn.Module): method __call__ (line 157) | def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray: class LfqCodebookOutput (line 173) | class LfqCodebookOutput: class LookupFreeQuantization (line 181) | class LookupFreeQuantization(nn.Module): method setup (line 185) | def setup(self): method encode (line 192) | def encode(self, z: jnp.ndarray) -> jnp.ndarray: method decode (line 198) | def decode(self, tokens: jnp.ndarray) -> jnp.ndarray: method loss (line 202) | def loss(self, x: jnp.ndarray) -> LfqCodebookOutput: function make_block_causal_attention_matrix (line 238) | def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, b... class GeGLU (line 242) | class GeGLU(Module): method __call__ (line 255) | def __call__(self, inputs: Array) -> Array: class CrossAttentionLayer (line 269) | class CrossAttentionLayer(nn.Module): method __call__ (line 276) | def __call__( function sinusoidal_pe_init (line 327) | def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray: class TokenizerEncoderDecoder (line 341) | class TokenizerEncoderDecoder(nn.Module): method __call__ (line 351) | def __call__( class FsqAttentionTokenizer (line 385) | class FsqAttentionTokenizer(nn.Module): method vocab_size (line 400) | def vocab_size(self) -> int: method setup (line 403) | def setup(self): method tokenize (line 430) | def tokenize( method detokenize (line 441) | def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None =... method loss (line 446) | def loss( method __call__ (line 468) | def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, di... FILE: src/openpi/models/vit.py class IdentityLayer (line 31) | class IdentityLayer(nn.Module): method __call__ (line 35) | def __call__(self, x): class AddPositionEmbs (line 39) | class AddPositionEmbs(nn.Module): method __call__ (line 50) | def __call__(self, inputs): class MlpBlock (line 66) | class MlpBlock(nn.Module): method __call__ (line 78) | def __call__(self, inputs, *, deterministic): class Encoder1DBlock (line 104) | class Encoder1DBlock(nn.Module): method __call__ (line 124) | def __call__(self, inputs, deterministic): class Encoder (line 160) | class Encoder(nn.Module): method __call__ (line 180) | def __call__(self, x, *, train): class VisionTransformer (line 219) | class VisionTransformer(nn.Module): method __call__ (line 235) | def __call__(self, inputs, *, train): FILE: src/openpi/models_pytorch/gemma_pytorch.py class PaliGemmaWithExpertModel (line 12) | class PaliGemmaWithExpertModel(nn.Module): method __init__ (line 13) | def __init__( method to_bfloat16_for_selected_params (line 63) | def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16... method embed_image (line 85) | def embed_image(self, image: torch.Tensor): method embed_language_tokens (line 88) | def embed_language_tokens(self, tokens: torch.Tensor): method forward (line 91) | def forward( FILE: src/openpi/models_pytorch/pi0_pytorch.py function get_safe_dtype (line 14) | def get_safe_dtype(target_dtype, device_type): function create_sinusoidal_pos_embedding (line 25) | def create_sinusoidal_pos_embedding( function sample_beta (line 45) | def sample_beta(alpha, beta, bsize, device): function make_att_2d_masks (line 52) | def make_att_2d_masks(pad_masks, att_masks): class PI0Pytorch (line 84) | class PI0Pytorch(nn.Module): method __init__ (line 85) | def __init__(self, config): method gradient_checkpointing_enable (line 127) | def gradient_checkpointing_enable(self): method gradient_checkpointing_disable (line 136) | def gradient_checkpointing_disable(self): method is_gradient_checkpointing_enabled (line 145) | def is_gradient_checkpointing_enabled(self): method _apply_checkpoint (line 149) | def _apply_checkpoint(self, func, *args, **kwargs): method _prepare_attention_masks_4d (line 157) | def _prepare_attention_masks_4d(self, att_2d_masks): method _preprocess_observation (line 162) | def _preprocess_observation(self, observation, *, train=True): method sample_noise (line 173) | def sample_noise(self, shape, device): method sample_time (line 182) | def sample_time(self, bsize, device): method embed_prefix (line 187) | def embed_prefix( method embed_suffix (line 238) | def embed_suffix(self, state, noisy_actions, timestep): method forward (line 317) | def forward(self, observation, actions, noise=None, time=None) -> Tensor: method sample_actions (line 377) | def sample_actions(self, device, observation, noise=None, num_steps=10... method denoise_step (line 422) | def denoise_step( FILE: src/openpi/models_pytorch/preprocessing_pytorch.py function preprocess_observation_pytorch (line 20) | def preprocess_observation_pytorch( FILE: src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py class GemmaConfig (line 26) | class GemmaConfig(PretrainedConfig): method __init__ (line 115) | def __init__( FILE: src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py class GemmaRMSNorm (line 49) | class GemmaRMSNorm(nn.Module): method __init__ (line 50) | def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int... method _norm (line 66) | def _norm(self, x): method forward (line 73) | def forward(self, x, cond=None): method extra_repr (line 106) | def extra_repr(self): class GemmaMLP (line 113) | class GemmaMLP(nn.Module): method __init__ (line 114) | def __init__(self, config): method forward (line 124) | def forward(self, x): class GemmaRotaryEmbedding (line 129) | class GemmaRotaryEmbedding(nn.Module): method __init__ (line 130) | def __init__(self, config: GemmaConfig, device=None): method forward (line 149) | def forward(self, x, position_ids): function rotate_half (line 163) | def rotate_half(x): function apply_rotary_pos_emb (line 170) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_di... function repeat_kv (line 197) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: function _gated_residual (line 209) | def _gated_residual(x, y, gate): function eager_attention_forward (line 230) | def eager_attention_forward( class GemmaAttention (line 256) | class GemmaAttention(nn.Module): method __init__ (line 259) | def __init__(self, config: GemmaConfig, layer_idx: int): method forward (line 282) | def forward( class GemmaDecoderLayer (line 332) | class GemmaDecoderLayer(GradientCheckpointingLayer): method __init__ (line 333) | def __init__(self, config: GemmaConfig, layer_idx: int): method forward (line 344) | def forward( class GemmaPreTrainedModel (line 388) | class GemmaPreTrainedModel(PreTrainedModel): method _init_weights (line 403) | def _init_weights(self, module): class GemmaModel (line 419) | class GemmaModel(GemmaPreTrainedModel): method __init__ (line 420) | def __init__(self, config: GemmaConfig): method get_input_embeddings (line 438) | def get_input_embeddings(self): method set_input_embeddings (line 441) | def set_input_embeddings(self, value): method forward (line 446) | def forward( class KwargsForCausalLM (line 558) | class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GemmaForCausalLM (line 562) | class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): method __init__ (line 567) | def __init__(self, config): method get_input_embeddings (line 576) | def get_input_embeddings(self): method set_input_embeddings (line 579) | def set_input_embeddings(self, value): method get_output_embeddings (line 582) | def get_output_embeddings(self): method set_output_embeddings (line 585) | def set_output_embeddings(self, new_embeddings): method set_decoder (line 588) | def set_decoder(self, decoder): method get_decoder (line 591) | def get_decoder(self): method forward (line 596) | def forward( class GemmaForSequenceClassification (line 689) | class GemmaForSequenceClassification(GemmaPreTrainedModel): method __init__ (line 690) | def __init__(self, config): method get_input_embeddings (line 699) | def get_input_embeddings(self): method set_input_embeddings (line 702) | def set_input_embeddings(self, value): method forward (line 707) | def forward( class GemmaForTokenClassification (line 781) | class GemmaForTokenClassification(GemmaPreTrainedModel): method __init__ (line 782) | def __init__(self, config): method get_input_embeddings (line 798) | def get_input_embeddings(self): method set_input_embeddings (line 801) | def set_input_embeddings(self, value): method forward (line 806) | def forward( FILE: src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py class PaligemmaModelOutputWithPast (line 44) | class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): class PaliGemmaCausalLMOutputWithPast (line 66) | class PaliGemmaCausalLMOutputWithPast(ModelOutput): class PaliGemmaMultiModalProjector (line 91) | class PaliGemmaMultiModalProjector(nn.Module): method __init__ (line 92) | def __init__(self, config: PaliGemmaConfig): method forward (line 96) | def forward(self, image_features): class PaliGemmaPreTrainedModel (line 103) | class PaliGemmaPreTrainedModel(PreTrainedModel): method _init_weights (line 117) | def _init_weights(self, module): class PaliGemmaModel (line 133) | class PaliGemmaModel(PaliGemmaPreTrainedModel): method __init__ (line 138) | def __init__(self, config: PaliGemmaConfig): method get_input_embeddings (line 151) | def get_input_embeddings(self): method set_input_embeddings (line 155) | def set_input_embeddings(self, value): method set_decoder (line 158) | def set_decoder(self, decoder): method get_decoder (line 161) | def get_decoder(self): method _update_causal_mask (line 164) | def _update_causal_mask( method get_image_features (line 232) | def get_image_features(self, pixel_values: torch.FloatTensor): method forward (line 249) | def forward( class KwargsForCausalLM (line 372) | class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class PaliGemmaForConditionalGeneration (line 380) | class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, Genera... method __init__ (line 389) | def __init__(self, config: PaliGemmaConfig): method get_input_embeddings (line 395) | def get_input_embeddings(self): method set_input_embeddings (line 398) | def set_input_embeddings(self, value): method get_output_embeddings (line 401) | def get_output_embeddings(self): method set_output_embeddings (line 404) | def set_output_embeddings(self, new_embeddings): method set_decoder (line 407) | def set_decoder(self, decoder): method get_decoder (line 410) | def get_decoder(self): method get_image_features (line 413) | def get_image_features(self, pixel_values): method language_model (line 418) | def language_model(self): method vision_tower (line 422) | def vision_tower(self): method multi_modal_projector (line 426) | def multi_modal_projector(self): method forward (line 431) | def forward( method prepare_inputs_for_generation (line 519) | def prepare_inputs_for_generation( method _prepare_4d_causal_attention_mask_with_cache_position (line 567) | def _prepare_4d_causal_attention_mask_with_cache_position( FILE: src/openpi/models_pytorch/transformers_replace/models/siglip/check.py function check_whether_transformers_replace_is_installed_correctly (line 3) | def check_whether_transformers_replace_is_installed_correctly(): FILE: src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py function _trunc_normal_ (line 41) | def _trunc_normal_(tensor, mean, std, a, b): function trunc_normal_tf_ (line 77) | def trunc_normal_tf_( function variance_scaling_ (line 103) | def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="no... function lecun_normal_ (line 128) | def lecun_normal_(tensor): function default_flax_embed_init (line 132) | def default_flax_embed_init(tensor): class SiglipVisionModelOutput (line 143) | class SiglipVisionModelOutput(ModelOutput): class SiglipTextModelOutput (line 162) | class SiglipTextModelOutput(ModelOutput): class SiglipOutput (line 177) | class SiglipOutput(ModelOutput): method to_tuple (line 205) | def to_tuple(self) -> tuple[Any]: class SiglipVisionEmbeddings (line 212) | class SiglipVisionEmbeddings(nn.Module): method __init__ (line 213) | def __init__(self, config: SiglipVisionConfig): method interpolate_pos_encoding (line 233) | def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: i... method forward (line 271) | def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_enc... class SiglipTextEmbeddings (line 285) | class SiglipTextEmbeddings(nn.Module): method __init__ (line 286) | def __init__(self, config: SiglipTextConfig): method forward (line 298) | def forward( function eager_attention_forward (line 325) | def eager_attention_forward( class SiglipAttention (line 348) | class SiglipAttention(nn.Module): method __init__ (line 351) | def __init__(self, config): method forward (line 371) | def forward( class SiglipMLP (line 420) | class SiglipMLP(nn.Module): method __init__ (line 421) | def __init__(self, config): method forward (line 428) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class SiglipEncoderLayer (line 435) | class SiglipEncoderLayer(GradientCheckpointingLayer): method __init__ (line 436) | def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): method forward (line 444) | def forward( class SiglipPreTrainedModel (line 484) | class SiglipPreTrainedModel(PreTrainedModel): method _init_weights (line 501) | def _init_weights(self, module): class SiglipEncoder (line 549) | class SiglipEncoder(nn.Module): method __init__ (line 558) | def __init__(self, config: SiglipConfig): method forward (line 566) | def forward( class SiglipTextTransformer (line 629) | class SiglipTextTransformer(nn.Module): method __init__ (line 630) | def __init__(self, config: SiglipTextConfig): method forward (line 643) | def forward( class SiglipTextModel (line 697) | class SiglipTextModel(SiglipPreTrainedModel): method __init__ (line 700) | def __init__(self, config: SiglipTextConfig): method get_input_embeddings (line 706) | def get_input_embeddings(self) -> nn.Module: method set_input_embeddings (line 709) | def set_input_embeddings(self, value): method forward (line 714) | def forward( class SiglipVisionTransformer (line 748) | class SiglipVisionTransformer(nn.Module): method __init__ (line 749) | def __init__(self, config: SiglipVisionConfig): method forward (line 763) | def forward( class SiglipMultiheadAttentionPoolingHead (line 799) | class SiglipMultiheadAttentionPoolingHead(nn.Module): method __init__ (line 802) | def __init__(self, config: SiglipVisionConfig): method forward (line 810) | def forward(self, hidden_state): class SiglipVisionModel (line 828) | class SiglipVisionModel(SiglipPreTrainedModel): method __init__ (line 832) | def __init__(self, config: SiglipVisionConfig): method get_input_embeddings (line 840) | def get_input_embeddings(self) -> nn.Module: method forward (line 845) | def forward( class SiglipModel (line 882) | class SiglipModel(SiglipPreTrainedModel): method __init__ (line 885) | def __init__(self, config: SiglipConfig): method get_text_features (line 918) | def get_text_features( method get_image_features (line 964) | def get_image_features( method forward (line 1014) | def forward( class SiglipForImageClassification (line 1117) | class SiglipForImageClassification(SiglipPreTrainedModel): method __init__ (line 1120) | def __init__(self, config: SiglipConfig) -> None: method forward (line 1140) | def forward( FILE: src/openpi/policies/aloha_policy.py function make_aloha_example (line 10) | def make_aloha_example() -> dict: class AlohaInputs (line 25) | class AlohaInputs(transforms.DataTransformFn): method __call__ (line 42) | def __call__(self, data: dict) -> dict: class AlohaOutputs (line 91) | class AlohaOutputs(transforms.DataTransformFn): method __call__ (line 98) | def __call__(self, data: dict) -> dict: function _joint_flip_mask (line 104) | def _joint_flip_mask() -> np.ndarray: function _normalize (line 109) | def _normalize(x, min_val, max_val): function _unnormalize (line 113) | def _unnormalize(x, min_val, max_val): function _gripper_to_angular (line 117) | def _gripper_to_angular(value): function _gripper_from_angular (line 140) | def _gripper_from_angular(value): function _gripper_from_angular_inv (line 153) | def _gripper_from_angular_inv(value): function _decode_aloha (line 159) | def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict: function _decode_state (line 181) | def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np... function _encode_actions (line 190) | def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -... function _encode_actions_inv (line 198) | def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = Fals... FILE: src/openpi/policies/droid_policy.py function make_droid_example (line 10) | def make_droid_example() -> dict: function _parse_image (line 21) | def _parse_image(image) -> np.ndarray: class DroidInputs (line 31) | class DroidInputs(transforms.DataTransformFn): method __call__ (line 35) | def __call__(self, data: dict) -> dict: class DroidOutputs (line 78) | class DroidOutputs(transforms.DataTransformFn): method __call__ (line 79) | def __call__(self, data: dict) -> dict: FILE: src/openpi/policies/libero_policy.py function make_libero_example (line 10) | def make_libero_example() -> dict: function _parse_image (line 20) | def _parse_image(image) -> np.ndarray: class LiberoInputs (line 30) | class LiberoInputs(transforms.DataTransformFn): method __call__ (line 42) | def __call__(self, data: dict) -> dict: class LiberoOutputs (line 87) | class LiberoOutputs(transforms.DataTransformFn): method __call__ (line 95) | def __call__(self, data: dict) -> dict: FILE: src/openpi/policies/policy.py class Policy (line 24) | class Policy(BasePolicy): method __init__ (line 25) | def __init__( method infer (line 68) | def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict... method metadata (line 109) | def metadata(self) -> dict[str, Any]: class PolicyRecorder (line 113) | class PolicyRecorder(_base_policy.BasePolicy): method __init__ (line 116) | def __init__(self, policy: _base_policy.BasePolicy, record_dir: str): method infer (line 125) | def infer(self, obs: dict) -> dict: # type: ignore[misc] FILE: src/openpi/policies/policy_config.py function create_trained_policy (line 16) | def create_trained_policy( FILE: src/openpi/policies/policy_test.py function test_infer (line 10) | def test_infer(): function test_broker (line 21) | def test_broker(): FILE: src/openpi/serving/websocket_policy_server.py class WebsocketPolicyServer (line 15) | class WebsocketPolicyServer: method __init__ (line 21) | def __init__( method serve_forever (line 34) | def serve_forever(self) -> None: method run (line 37) | async def run(self): method _handler (line 48) | async def _handler(self, websocket: _server.ServerConnection): function _health_check (line 86) | def _health_check(connection: _server.ServerConnection, request: _server... FILE: src/openpi/shared/array_typing.py function _check_dataclass_annotations (line 34) | def _check_dataclass_annotations(self, typechecker): function typecheck (line 52) | def typecheck(t: T) -> T: function disable_typechecking (line 57) | def disable_typechecking(): function check_pytree_equality (line 64) | def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes... FILE: src/openpi/shared/download.py function get_cache_dir (line 25) | def get_cache_dir() -> pathlib.Path: function maybe_download (line 32) | def maybe_download(url: str, *, force_download: bool = False, **kwargs) ... function _download_gsutil (line 108) | def _download_gsutil(url: str, local_path: pathlib.Path, **kwargs) -> None: function _download_fsspec (line 123) | def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None: function _set_permission (line 142) | def _set_permission(path: pathlib.Path, target_permission: int): function _set_folder_permission (line 151) | def _set_folder_permission(folder_path: pathlib.Path) -> None: function _ensure_permissions (line 156) | def _ensure_permissions(path: pathlib.Path) -> None: function _get_mtime (line 189) | def _get_mtime(year: int, month: int, day: int) -> float: function _should_invalidate_cache (line 205) | def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathli... FILE: src/openpi/shared/download_test.py function set_openpi_data_home (line 9) | def set_openpi_data_home(tmp_path_factory): function test_download_local (line 16) | def test_download_local(tmp_path: pathlib.Path): function test_download_gs_dir (line 27) | def test_download_gs_dir(): function test_download_gs (line 37) | def test_download_gs(): function test_download_fsspec (line 47) | def test_download_fsspec(): FILE: src/openpi/shared/image_tools.py function resize_with_pad (line 13) | def resize_with_pad( function resize_with_pad_torch (line 55) | def resize_with_pad_torch( FILE: src/openpi/shared/image_tools_test.py function test_resize_with_pad_shapes (line 6) | def test_resize_with_pad_shapes(): FILE: src/openpi/shared/nnx_utils.py function module_jit (line 15) | def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callabl... class PathRegex (line 47) | class PathRegex: method __post_init__ (line 56) | def __post_init__(self): method __call__ (line 60) | def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool: function state_map (line 66) | def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callab... FILE: src/openpi/shared/normalize.py class NormStats (line 10) | class NormStats: class RunningStats (line 17) | class RunningStats: method __init__ (line 20) | def __init__(self): method update (line 30) | def update(self, batch: np.ndarray) -> None: method get_statistics (line 73) | def get_statistics(self) -> NormStats: method _adjust_histograms (line 88) | def _adjust_histograms(self): method _update_histograms (line 100) | def _update_histograms(self, batch: np.ndarray) -> None: method _compute_quantiles (line 106) | def _compute_quantiles(self, quantiles): class _NormStatsDict (line 120) | class _NormStatsDict(pydantic.BaseModel): function serialize_json (line 124) | def serialize_json(norm_stats: dict[str, NormStats]) -> str: function deserialize_json (line 129) | def deserialize_json(data: str) -> dict[str, NormStats]: function save (line 134) | def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]... function load (line 141) | def load(directory: pathlib.Path | str) -> dict[str, NormStats]: FILE: src/openpi/shared/normalize_test.py function test_normalize_update (line 6) | def test_normalize_update(): function test_serialize_deserialize (line 18) | def test_serialize_deserialize(): function test_multiple_batch_dimensions (line 28) | def test_multiple_batch_dimensions(): FILE: src/openpi/training/checkpoints.py function initialize_checkpoint_dir (line 20) | def initialize_checkpoint_dir( function save_state (line 65) | def save_state( function restore_state (line 89) | def restore_state( function load_norm_stats (line 110) | def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict... class Callback (line 117) | class Callback(Protocol): method __call__ (line 118) | def __call__(self, directory: epath.Path) -> None: ... class CallbackHandler (line 121) | class CallbackHandler(ocp.AsyncCheckpointHandler): method save (line 124) | def save(self, directory: epath.Path, args: CallbackSave): method async_save (line 128) | async def async_save(self, directory: epath.Path, args: CallbackSave) ... method restore (line 131) | def restore(self, *args, **kwargs): class CallbackSave (line 137) | class CallbackSave(ocp.args.CheckpointArgs): class CallbackRestore (line 142) | class CallbackRestore(ocp.args.CheckpointArgs): ... function _split_params (line 145) | def _split_params(state: training_utils.TrainState) -> tuple[training_ut... function _merge_params (line 155) | def _merge_params(train_state: training_utils.TrainState, params: dict[s... FILE: src/openpi/training/config.py class AssetsConfig (line 38) | class AssetsConfig: class DataConfig (line 65) | class DataConfig: class GroupFactory (line 101) | class GroupFactory(Protocol): method __call__ (line 102) | def __call__(self, model_config: _model.BaseModelConfig) -> _transform... class ModelTransformFactory (line 107) | class ModelTransformFactory(GroupFactory): method __call__ (line 113) | def __call__(self, model_config: _model.BaseModelConfig) -> _transform... class DataConfigFactory (line 167) | class DataConfigFactory(abc.ABC): method create (line 176) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM... method create_base_config (line 179) | def create_base_config(self, assets_dirs: pathlib.Path, model_config: ... method _load_norm_stats (line 190) | def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | Non... class FakeDataConfig (line 204) | class FakeDataConfig(DataConfigFactory): method create (line 208) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM... class SimpleDataConfig (line 213) | class SimpleDataConfig(DataConfigFactory): method create (line 220) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM... class LeRobotAlohaDataConfig (line 229) | class LeRobotAlohaDataConfig(DataConfigFactory): method create (line 258) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM... class LeRobotLiberoDataConfig (line 282) | class LeRobotLiberoDataConfig(DataConfigFactory): method create (line 292) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM... class RLDSDroidDataConfig (line 359) | class RLDSDroidDataConfig(DataConfigFactory): method create (line 382) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM... class LeRobotDROIDDataConfig (line 427) | class LeRobotDROIDDataConfig(DataConfigFactory): method create (line 434) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM... class TrainConfig (line 466) | class TrainConfig: method assets_dirs (line 538) | def assets_dirs(self) -> pathlib.Path: method checkpoint_dir (line 543) | def checkpoint_dir(self) -> pathlib.Path: method trainable_filter (line 550) | def trainable_filter(self) -> nnx.filterlib.Filter: method __post_init__ (line 554) | def __post_init__(self) -> None: function cli (line 978) | def cli() -> TrainConfig: function get_config (line 982) | def get_config(config_name: str) -> TrainConfig: FILE: src/openpi/training/data_loader.py class Dataset (line 22) | class Dataset(Protocol[T_co]): method __getitem__ (line 25) | def __getitem__(self, index: SupportsIndex) -> T_co: method __len__ (line 28) | def __len__(self) -> int: class IterableDataset (line 32) | class IterableDataset(Protocol[T_co]): method __iter__ (line 35) | def __iter__(self) -> Iterator[T_co]: method __len__ (line 38) | def __len__(self) -> int: class DataLoader (line 42) | class DataLoader(Protocol[T_co]): method data_config (line 45) | def data_config(self) -> _config.DataConfig: method __iter__ (line 49) | def __iter__(self) -> Iterator[T_co]: class TransformedDataset (line 53) | class TransformedDataset(Dataset[T_co]): method __init__ (line 54) | def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.... method __getitem__ (line 58) | def __getitem__(self, index: SupportsIndex) -> T_co: method __len__ (line 61) | def __len__(self) -> int: class IterableTransformedDataset (line 65) | class IterableTransformedDataset(IterableDataset[T_co]): method __init__ (line 66) | def __init__( method __iter__ (line 77) | def __iter__(self): method __len__ (line 95) | def __len__(self) -> int: class FakeDataset (line 99) | class FakeDataset(Dataset): method __init__ (line 100) | def __init__(self, model_config: _model.BaseModelConfig, num_samples: ... method __getitem__ (line 104) | def __getitem__(self, index: SupportsIndex) -> dict: method __len__ (line 126) | def __len__(self) -> int: function create_torch_dataset (line 130) | def create_torch_dataset( function create_rlds_dataset (line 154) | def create_rlds_dataset( function transform_dataset (line 172) | def transform_dataset(dataset: Dataset, data_config: _config.DataConfig,... function transform_iterable_dataset (line 194) | def transform_iterable_dataset( function create_data_loader (line 223) | def create_data_loader( function create_torch_data_loader (line 271) | def create_torch_data_loader( function create_rlds_data_loader (line 340) | def create_rlds_data_loader( class TorchDataLoader (line 381) | class TorchDataLoader: method __init__ (line 384) | def __init__( method torch_loader (line 449) | def torch_loader(self) -> torch.utils.data.DataLoader: method __iter__ (line 452) | def __iter__(self): function _collate_fn (line 471) | def _collate_fn(items): function _worker_init_fn (line 478) | def _worker_init_fn(worker_id: int) -> None: class RLDSDataLoader (line 486) | class RLDSDataLoader: method __init__ (line 492) | def __init__( method __iter__ (line 515) | def __iter__(self): class DataLoaderImpl (line 530) | class DataLoaderImpl(DataLoader): method __init__ (line 531) | def __init__(self, data_config: _config.DataConfig, data_loader: Torch... method data_config (line 535) | def data_config(self) -> _config.DataConfig: method __iter__ (line 538) | def __iter__(self): FILE: src/openpi/training/data_loader_test.py function test_torch_data_loader (line 10) | def test_torch_data_loader(): function test_torch_data_loader_infinite (line 26) | def test_torch_data_loader_infinite(): function test_torch_data_loader_parallel (line 37) | def test_torch_data_loader_parallel(): function test_with_fake_dataset (line 50) | def test_with_fake_dataset(): function test_with_real_dataset (line 65) | def test_with_real_dataset(): FILE: src/openpi/training/droid_rlds_dataset.py class DroidActionSpace (line 21) | class DroidActionSpace(Enum): class RLDSDataset (line 29) | class RLDSDataset: class DroidRldsDataset (line 36) | class DroidRldsDataset: method __init__ (line 37) | def __init__( method __iter__ (line 242) | def __iter__(self): method __len__ (line 245) | def __len__(self): FILE: src/openpi/training/misc/polaris_config.py function get_polaris_configs (line 18) | def get_polaris_configs(): FILE: src/openpi/training/misc/roboarena_config.py function get_roboarena_configs (line 15) | def get_roboarena_configs(): FILE: src/openpi/training/optimizer.py class LRScheduleConfig (line 11) | class LRScheduleConfig(Protocol): method create (line 12) | def create(self) -> optax.Schedule: ... class CosineDecaySchedule (line 16) | class CosineDecaySchedule(LRScheduleConfig): method create (line 24) | def create(self) -> optax.Schedule: class RsqrtDecaySchedule (line 35) | class RsqrtDecaySchedule(LRScheduleConfig): method create (line 42) | def create(self) -> optax.Schedule: class OptimizerConfig (line 57) | class OptimizerConfig(Protocol): method create (line 58) | def create( class AdamW (line 66) | class AdamW(OptimizerConfig): method create (line 76) | def create( class SGD (line 89) | class SGD(OptimizerConfig): method create (line 96) | def create( function create_optimizer (line 105) | def create_optimizer( FILE: src/openpi/training/sharding.py class _MeshState (line 13) | class _MeshState: function make_mesh (line 17) | def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh: function set_mesh (line 27) | def set_mesh(mesh: jax.sharding.Mesh): function activation_sharding_constraint (line 40) | def activation_sharding_constraint(pytree): function fsdp_sharding (line 48) | def fsdp_sharding( FILE: src/openpi/training/utils.py class TrainState (line 15) | class TrainState: function tree_to_info (line 27) | def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = st... function array_tree_to_info (line 36) | def array_tree_to_info(tree: at.PyTree) -> str: FILE: src/openpi/training/weight_loaders.py class WeightLoader (line 17) | class WeightLoader(Protocol): method load (line 18) | def load(self, params: at.Params) -> at.Params: class NoOpWeightLoader (line 32) | class NoOpWeightLoader(WeightLoader): method load (line 33) | def load(self, params: at.Params) -> at.Params: class CheckpointWeightLoader (line 38) | class CheckpointWeightLoader(WeightLoader): method load (line 50) | def load(self, params: at.Params) -> at.Params: class PaliGemmaWeightLoader (line 58) | class PaliGemmaWeightLoader(WeightLoader): method load (line 65) | def load(self, params: at.Params) -> at.Params: function _merge_params (line 76) | def _merge_params(loaded_params: at.Params, params: at.Params, *, missin... FILE: src/openpi/transforms.py class DataTransformFn (line 24) | class DataTransformFn(Protocol): method __call__ (line 25) | def __call__(self, data: DataDict) -> DataDict: class Group (line 40) | class Group: method push (line 49) | def push(self, *, inputs: Sequence[DataTransformFn] = (), outputs: Seq... class CompositeTransform (line 63) | class CompositeTransform(DataTransformFn): method __call__ (line 68) | def __call__(self, data: DataDict) -> DataDict: function compose (line 74) | def compose(transforms: Sequence[DataTransformFn]) -> DataTransformFn: class RepackTransform (line 80) | class RepackTransform(DataTransformFn): method __call__ (line 99) | def __call__(self, data: DataDict) -> DataDict: class InjectDefaultPrompt (line 105) | class InjectDefaultPrompt(DataTransformFn): method __call__ (line 108) | def __call__(self, data: DataDict) -> DataDict: class Normalize (line 115) | class Normalize(DataTransformFn): method __post_init__ (line 122) | def __post_init__(self): method __call__ (line 126) | def __call__(self, data: DataDict) -> DataDict: method _normalize (line 137) | def _normalize(self, x, stats: NormStats): method _normalize_quantile (line 141) | def _normalize_quantile(self, x, stats: NormStats): class Unnormalize (line 149) | class Unnormalize(DataTransformFn): method __post_init__ (line 154) | def __post_init__(self): method __call__ (line 158) | def __call__(self, data: DataDict) -> DataDict: method _unnormalize (line 170) | def _unnormalize(self, x, stats: NormStats): method _unnormalize_quantile (line 175) | def _unnormalize_quantile(self, x, stats: NormStats): class ResizeImages (line 185) | class ResizeImages(DataTransformFn): method __call__ (line 189) | def __call__(self, data: DataDict) -> DataDict: class SubsampleActions (line 195) | class SubsampleActions(DataTransformFn): method __call__ (line 198) | def __call__(self, data: DataDict) -> DataDict: class DeltaActions (line 204) | class DeltaActions(DataTransformFn): method __call__ (line 212) | def __call__(self, data: DataDict) -> DataDict: class AbsoluteActions (line 226) | class AbsoluteActions(DataTransformFn): method __call__ (line 234) | def __call__(self, data: DataDict) -> DataDict: class TokenizePrompt (line 248) | class TokenizePrompt(DataTransformFn): method __call__ (line 252) | def __call__(self, data: DataDict) -> DataDict: class TokenizeFASTInputs (line 270) | class TokenizeFASTInputs(DataTransformFn): method __call__ (line 273) | def __call__(self, data: DataDict) -> DataDict: class ExtractFASTActions (line 292) | class ExtractFASTActions(DataTransformFn): method __call__ (line 297) | def __call__(self, data: DataDict) -> DataDict: class PromptFromLeRobotTask (line 310) | class PromptFromLeRobotTask(DataTransformFn): method __call__ (line 316) | def __call__(self, data: DataDict) -> DataDict: class PadStatesAndActions (line 328) | class PadStatesAndActions(DataTransformFn): method __call__ (line 333) | def __call__(self, data: DataDict) -> DataDict: function flatten_dict (line 340) | def flatten_dict(tree: at.PyTree) -> dict: function unflatten_dict (line 345) | def unflatten_dict(tree: dict) -> at.PyTree: function transform_dict (line 350) | def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) ... function apply_tree (line 404) | def apply_tree( function pad_to_dim (line 423) | def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: fl... function make_bool_mask (line 433) | def make_bool_mask(*dims: int) -> tuple[bool, ...]: function _assert_quantile_stats (line 455) | def _assert_quantile_stats(norm_stats: at.PyTree[NormStats]) -> None: FILE: src/openpi/transforms_test.py function test_repack_transform (line 8) | def test_repack_transform(): function test_delta_actions (line 19) | def test_delta_actions(): function test_delta_actions_noop (line 29) | def test_delta_actions_noop(): function test_absolute_actions (line 42) | def test_absolute_actions(): function test_absolute_actions_noop (line 52) | def test_absolute_actions_noop(): function test_make_bool_mask (line 65) | def test_make_bool_mask(): function test_tokenize_prompt (line 70) | def test_tokenize_prompt(): function test_tokenize_no_prompt (line 81) | def test_tokenize_no_prompt(): function test_transform_dict (line 88) | def test_transform_dict(): function test_extract_prompt_from_task (line 114) | def test_extract_prompt_from_task():