SYMBOL INDEX (5058 symbols across 319 files) FILE: .github/analytics/get_repo_metrics.py function load_query_from_file (line 38) | def load_query_from_file(fname, repo_owner, repo_name) -> str: function send_query (line 47) | def send_query(query, query_type, cursor=None): function get_all_responses (line 93) | def get_all_responses(query, query_type): function parse_single_query (line 109) | def parse_single_query(data, query_type): class GithubGrabber (line 129) | class GithubGrabber: method __init__ (line 134) | def __init__(self, query_fname, query_type, repo_owner, repo_name): method load_query (line 162) | def load_query(self): method get (line 167) | def get(self): function _to_datetime (line 176) | def _to_datetime(date_str: str) -> datetime: function _get_issues_features (line 180) | def _get_issues_features(issues): function _get_pr_features (line 205) | def _get_pr_features(prs): function _start_of_month (line 251) | def _start_of_month(date: datetime) -> datetime: function _shift_n_months (line 255) | def _shift_n_months(date: datetime, n: int) -> datetime: function _rolling_window (line 267) | def _rolling_window( function _process_prs (line 297) | def _process_prs(df: pd.DataFrame) -> pd.Series: function _process_issues (line 306) | def _process_issues(df: pd.DataFrame) -> pd.Series: function main (line 323) | def main(_): FILE: benchmarks/nnx_graph_overhead.py class Linear (line 35) | class Linear(nnx.Module): method __init__ (line 36) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 40) | def __call__(self, x): class Block (line 44) | class Block(nnx.Module): method __init__ (line 45) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 49) | def __call__(self, x): class Count (line 53) | class Count(nnx.Variable): class MLP (line 57) | class MLP(nnx.Module): method __init__ (line 58) | def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): method __call__ (line 66) | def __call__(self, x): function main (line 75) | def main(argv): FILE: benchmarks/nnx_mlpmixer_training.py class MlpBlock (line 41) | class MlpBlock(nnx.Module): method __init__ (line 42) | def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs): method __call__ (line 47) | def __call__(self, x): class MixerBlock (line 51) | class MixerBlock(nnx.Module): method __init__ (line 52) | def __init__( method __call__ (line 67) | def __call__(self, x): class MlpMixer (line 77) | class MlpMixer(nnx.Module): method __init__ (line 78) | def __init__( method __call__ (line 111) | def __call__(self, *, x, t): function main (line 129) | def main(argv): FILE: benchmarks/nnx_simple_training.py function dataset (line 38) | def dataset(X, Y, batch_size): class Linear (line 44) | class Linear(nnx.Module): method __init__ (line 45) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 49) | def __call__(self, x): class Block (line 53) | class Block(nnx.Module): method __init__ (line 54) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 58) | def __call__(self, x): class Count (line 62) | class Count(nnx.Variable): class MLP (line 66) | class MLP(nnx.Module): method __init__ (line 67) | def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): method __call__ (line 75) | def __call__(self, x): function main (line 84) | def main(argv): FILE: benchmarks/nnx_state_traversal.py class NestedClass (line 34) | class NestedClass(nnx.Module): method __init__ (line 35) | def __init__(self, width, depth): function main (line 42) | def main(argv): FILE: benchmarks/tracing/gemma.py function rsqrt_schedule (line 32) | def rsqrt_schedule(init_value: float, shift: int = 0): function create_learning_rate_schedule (line 38) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): function compute_weighted_cross_entropy (line 52) | def compute_weighted_cross_entropy( function compute_weighted_accuracy (line 82) | def compute_weighted_accuracy(logits, targets, weights=None): function compute_metrics (line 97) | def compute_metrics(logits, labels, weights, label_smoothing=0.0): function train_step (line 110) | def train_step( function get_fake_batch (line 157) | def get_fake_batch(batch_size: int) -> Any: function get_apply_fn_and_args (line 172) | def get_apply_fn_and_args( function test_flax_gemma_trace (line 243) | def test_flax_gemma_trace(state): function test_flax_gemma_lower (line 251) | def test_flax_gemma_lower(state): FILE: benchmarks/tracing/imagenet.py class TrainState (line 35) | class TrainState(train_state.TrainState): function create_model (line 40) | def create_model(*, model_cls, half_precision, **kwargs): function initialized (line 52) | def initialized(key, image_size, model): function cross_entropy_loss (line 63) | def cross_entropy_loss(logits, labels): function create_train_state (line 69) | def create_train_state( function get_fake_batch (line 93) | def get_fake_batch(batch_size: int = 128) -> dict[str, jnp.ndarray]: class BenchmarkResNet (line 103) | class BenchmarkResNet(models.ResNet): method __call__ (line 106) | def __call__(self, x, train: bool = True): function get_apply_fn_and_args (line 143) | def get_apply_fn_and_args( function bench_train_step (line 177) | def bench_train_step(state, batch, learning_rate_fn): function test_flax_imagenet_trace (line 240) | def test_flax_imagenet_trace(state): function test_flax_imagenet_lower (line 248) | def test_flax_imagenet_lower(state): FILE: benchmarks/tracing/lm1b.py function rsqrt_schedule (line 33) | def rsqrt_schedule(init_value: float, shift: int = 0): function create_learning_rate_schedule (line 39) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): function compute_weighted_cross_entropy (line 53) | def compute_weighted_cross_entropy( function compute_weighted_accuracy (line 83) | def compute_weighted_accuracy(logits, targets, weights=None): function get_fake_batch (line 98) | def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, jnp.n... function bench_train_step (line 122) | def bench_train_step(state, batch, config, learning_rate_fn): function get_apply_fn_and_args (line 170) | def get_apply_fn_and_args( function test_flax_lm1b_trace (line 239) | def test_flax_lm1b_trace(state): function test_flax_lm1b_lower (line 247) | def test_flax_lm1b_lower(state): FILE: benchmarks/tracing/mnist.py class CNN (line 29) | class CNN(nnx.Module): method __init__ (line 31) | def __init__(self, rngs: nnx.Rngs): method __call__ (line 42) | def __call__(self, x, rngs: nnx.Rngs): function loss_fn (line 51) | def loss_fn(model: CNN, batch, rngs): function get_fake_batch (line 59) | def get_fake_batch(batch_size: int) -> dict[str, Any]: function get_apply_fn_and_args (line 66) | def get_apply_fn_and_args( function test_flax_mnist_trace (line 82) | def test_flax_mnist_trace(state): function test_flax_mnist_lower (line 90) | def test_flax_mnist_lower(state): FILE: benchmarks/tracing/nlp_seq.py function create_learning_rate_scheduler (line 33) | def create_learning_rate_scheduler( function compute_weighted_cross_entropy (line 71) | def compute_weighted_cross_entropy(logits, targets, weights=None): function compute_weighted_accuracy (line 87) | def compute_weighted_accuracy(logits, targets, weights=None): function get_fake_batch (line 102) | def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, jnp.n... function bench_train_step (line 129) | def bench_train_step(state, batch, config, learning_rate_fn): function get_apply_fn_and_args (line 176) | def get_apply_fn_and_args( function test_flax_nlp_seq_trace (line 227) | def test_flax_nlp_seq_trace(state): function test_flax_nlp_seq_lower (line 235) | def test_flax_nlp_seq_lower(state): FILE: benchmarks/tracing/ogbg_molpcba.py function create_model (line 33) | def create_model( function create_optimizer (line 62) | def create_optimizer( function binary_cross_entropy_with_mask (line 74) | def binary_cross_entropy_with_mask( function predictions_match_labels (line 88) | def predictions_match_labels( function replace_globals (line 96) | def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: function get_predicted_logits (line 100) | def get_predicted_logits( function get_valid_mask (line 110) | def get_valid_mask( class TrainMetrics (line 119) | class TrainMetrics(metrics.Collection): function ogbg_train_step (line 125) | def ogbg_train_step( function get_fake_graphs (line 157) | def get_fake_graphs(config: ml_collections.ConfigDict) -> jraph.GraphsTu... function get_apply_fn_and_args (line 184) | def get_apply_fn_and_args( function test_flax_ogbg_molpcba_trace (line 212) | def test_flax_ogbg_molpcba_trace(state): function test_flax_ogbg_molpcba_lower (line 220) | def test_flax_ogbg_molpcba_lower(state): FILE: benchmarks/tracing/ppo.py function get_fake_batch (line 28) | def get_fake_batch(batch_size: int = 256) -> tuple[jnp.ndarray, ...]: function get_apply_fn_and_args (line 69) | def get_apply_fn_and_args( function test_flax_ppo_trace (line 116) | def test_flax_ppo_trace(state): function test_flax_ppo_lower (line 124) | def test_flax_ppo_lower(state): FILE: benchmarks/tracing/seq2seq.py function cross_entropy_loss (line 33) | def cross_entropy_loss(logits, labels, lengths): function compute_metrics (line 39) | def compute_metrics(logits, labels, eos_id): function seq2seq_train_step (line 55) | def seq2seq_train_step(state, batch, lstm_rng, eos_id): function get_fake_batch (line 79) | def get_fake_batch(batch_size: int, ctable: CharacterTable) -> dict[str,... function get_apply_fn_and_args (line 83) | def get_apply_fn_and_args( function test_flax_seq2seq_trace (line 114) | def test_flax_seq2seq_trace(state): function test_flax_seq2seq_lower (line 122) | def test_flax_seq2seq_lower(state): FILE: benchmarks/tracing/sst2.py function sigmoid_cross_entropy_with_logits (line 33) | def sigmoid_cross_entropy_with_logits(*, labels: Array, logits: Array) -... function model_from_config (line 41) | def model_from_config(config: ml_collections.ConfigDict): function get_initial_params (line 54) | def get_initial_params(rng, model): function create_train_state (line 61) | def create_train_state(rng, config: ml_collections.ConfigDict, model): function compute_metrics (line 71) | def compute_metrics(*, labels: Array, logits: Array): function train_step (line 84) | def train_step( function get_fake_batch (line 119) | def get_fake_batch(batch_size: int) -> dict[str, Any]: function get_apply_fn_and_args (line 134) | def get_apply_fn_and_args( function test_flax_sst2_trace (line 152) | def test_flax_sst2_trace(state): function test_flax_sst2_lower (line 160) | def test_flax_sst2_lower(state): FILE: benchmarks/tracing/tracing_benchmark.py function clear_caches (line 34) | def clear_caches(state): function benchmark_tracing (line 40) | def benchmark_tracing( function benchmark_lowering (line 54) | def benchmark_lowering( function run_single_example (line 70) | def run_single_example( function run_benchmarks (line 91) | def run_benchmarks() -> None: FILE: benchmarks/tracing/vae.py function binary_cross_entropy_with_logits (line 31) | def binary_cross_entropy_with_logits(logits, labels): function kl_divergence (line 39) | def kl_divergence(mean, logvar): function train_step (line 43) | def train_step(state, batch, z_rng, latents): function get_fake_batch (line 57) | def get_fake_batch(batch_size: int) -> Any: function get_apply_fn_and_args (line 61) | def get_apply_fn_and_args( function test_flax_vae_trace (line 83) | def test_flax_vae_trace(state): function test_flax_vae_lower (line 91) | def test_flax_vae_lower(state): FILE: benchmarks/tracing/wmt.py class TrainState (line 34) | class TrainState(train_state.TrainState): function rsqrt_schedule (line 38) | def rsqrt_schedule(init_value: float, shift: int = 0): function create_learning_rate_schedule (line 44) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): function preferred_dtype (line 58) | def preferred_dtype(config): function compute_weighted_cross_entropy (line 68) | def compute_weighted_cross_entropy( function compute_weighted_accuracy (line 98) | def compute_weighted_accuracy(logits, targets, weights=None): function compute_metrics (line 113) | def compute_metrics(logits, labels, weights, label_smoothing=0.0): function wmt_train_step (line 126) | def wmt_train_step( function get_fake_batch (line 202) | def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, Any]: function get_apply_fn_and_args (line 224) | def get_apply_fn_and_args( function test_flax_wmt_trace (line 298) | def test_flax_wmt_trace(state): function test_flax_wmt_lower (line 306) | def test_flax_wmt_lower(state): FILE: docs/_ext/codediff.py class CodeDiffParser (line 39) | class CodeDiffParser: method parse (line 40) | def parse( method _code_block (line 140) | def _code_block(self, lines): method _tabs (line 156) | def _tabs(self, *contents: tuple[str, list[str]], sync): class CodeDiffDirective (line 172) | class CodeDiffDirective(SphinxDirective): method run (line 182) | def run(self): function setup (line 214) | def setup(app): FILE: docs/_ext/codediff_test.py class CodeDiffTest (line 21) | class CodeDiffTest(parameterized.TestCase): method test_parse (line 22) | def test_parse(self): method test_parse_errors (line 114) | def test_parse_errors(self, input_text, title, groups, error_msg): FILE: docs/_ext/flax_module.py function render_module (line 36) | def render_module(modname: str, qualname: str, app): class FlaxModuleDirective (line 59) | class FlaxModuleDirective(SphinxDirective): method run (line 66) | def run(self): function setup (line 80) | def setup(app): FILE: docs/conf_sphinx_patch.py function generate_autosummary_content (line 32) | def generate_autosummary_content( FILE: docs_nnx/_ext/codediff.py class CodeDiffParser (line 39) | class CodeDiffParser: method parse (line 40) | def parse( method _code_block (line 140) | def _code_block(self, lines): method _tabs (line 156) | def _tabs(self, *contents: tuple[str, list[str]], sync): class CodeDiffDirective (line 172) | class CodeDiffDirective(SphinxDirective): method run (line 182) | def run(self): function setup (line 214) | def setup(app): FILE: docs_nnx/_ext/codediff_test.py class CodeDiffTest (line 21) | class CodeDiffTest(parameterized.TestCase): method test_parse (line 22) | def test_parse(self): method test_parse_errors (line 114) | def test_parse_errors(self, input_text, title, groups, error_msg): FILE: docs_nnx/_ext/flax_module.py function render_module (line 36) | def render_module(modname: str, qualname: str, app): class FlaxModuleDirective (line 59) | class FlaxModuleDirective(SphinxDirective): method run (line 66) | def run(self): function setup (line 80) | def setup(app): FILE: docs_nnx/conf_sphinx_patch.py function generate_autosummary_content (line 32) | def generate_autosummary_content( FILE: examples/cloud/launch_gce.py function generate_startup_file (line 137) | def generate_startup_file(vm_name: str) -> str: function launch_gce (line 162) | def launch_gce(*, vm_name: str, startup_script: str): function print_howto (line 204) | def print_howto(login_args: Sequence[str]): function main (line 232) | def main(_): FILE: examples/gemma/configs/default.py class Config (line 23) | class Config: function get_config (line 132) | def get_config() -> TrainConfig: FILE: examples/gemma/configs/gemma3_4b.py class Config (line 23) | class Config: method replace (line 131) | def replace(self, **kwargs): function get_config (line 135) | def get_config() -> TrainConfig: FILE: examples/gemma/configs/small.py class Config (line 23) | class Config: method replace (line 152) | def replace(self, **kwargs): function get_config (line 156) | def get_config() -> TrainConfig: FILE: examples/gemma/configs/tiny.py class Config (line 23) | class Config: method replace (line 143) | def replace(self, **kwargs): function get_config (line 147) | def get_config() -> TrainConfig: FILE: examples/gemma/helpers.py function _flatten_path (line 29) | def _flatten_path(path: tuple[str | int, ...]) -> str: function module_from_linen_variables (line 41) | def module_from_linen_variables( FILE: examples/gemma/helpers_test.py class ModuleFromLinenVariablesTest (line 30) | class ModuleFromLinenVariablesTest(parameterized.TestCase): method test_same_structure (line 44) | def test_same_structure(self, inputs_shape, num_features, use_bias): method test_different_structure (line 79) | def test_different_structure(self, inputs_shape, num_features, use_bias): FILE: examples/gemma/input_pipeline.py class NormalizeFeatureNamesOp (line 28) | class NormalizeFeatureNamesOp: method __call__ (line 31) | def __call__(self, features: Features) -> Features: function get_raw_dataset (line 38) | def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset: function pack_dataset (line 56) | def pack_dataset( function _pack_with_tf_ops (line 139) | def _pack_with_tf_ops( function shift_data_by_truncation (line 260) | def shift_data_by_truncation(x): function preprocess_data (line 270) | def preprocess_data( function get_datasets (line 324) | def get_datasets( FILE: examples/gemma/input_pipeline_test.py class InputPipelineTest (line 32) | class InputPipelineTest(absltest.TestCase): method setUp (line 34) | def setUp(self): method _get_datasets (line 40) | def _get_datasets(self): method test_train_ds (line 62) | def test_train_ds(self): method test_eval_ds (line 79) | def test_eval_ds(self): FILE: examples/gemma/layers.py class Einsum (line 31) | class Einsum(nnx.Module): method __init__ (line 34) | def __init__( method __call__ (line 46) | def __call__(self, x: ArrayLike) -> Array: method shape (line 50) | def shape(self) -> Shape: class RMSNorm (line 54) | class RMSNorm(nnx.Module): method __init__ (line 57) | def __init__( method __call__ (line 67) | def __call__(self, x: Array) -> Array: FILE: examples/gemma/layers_test.py class EinsumTest (line 25) | class EinsumTest(parameterized.TestCase): method test_einsum (line 40) | def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape): method test_shape (line 55) | def test_shape(self, shape): class RMSNormTest (line 60) | class RMSNormTest(parameterized.TestCase): method test_rmsnorm (line 62) | def test_rmsnorm(self, x, expected): FILE: examples/gemma/main.py function main (line 43) | def main(argv): FILE: examples/gemma/modules.py class AttentionType (line 39) | class AttentionType(enum.Enum): class Embedder (line 44) | class Embedder(nnx.Module): method __init__ (line 47) | def __init__( method encode (line 60) | def encode(self, x: ArrayLike) -> Array: method decode (line 65) | def decode(self, x: ArrayLike) -> Array: method embed_dim (line 69) | def embed_dim(self): method num_embed (line 73) | def num_embed(self): class Attention (line 77) | class Attention(nnx.Module): method __init__ (line 80) | def __init__( method __call__ (line 173) | def __call__( method head_dim (line 282) | def head_dim(self): method num_heads (line 286) | def num_heads(self): method num_kv_heads (line 294) | def num_kv_heads(self): method use_qkv_einsum (line 302) | def use_qkv_einsum(self): method init_cache (line 305) | def init_cache( class FeedForward (line 324) | class FeedForward(nnx.Module): method __init__ (line 327) | def __init__( method __call__ (line 363) | def __call__(self, x: ArrayLike) -> Array: class Block (line 375) | class Block(nnx.Module): method __init__ (line 378) | def __init__( method __call__ (line 507) | def __call__( method init_cache (line 538) | def init_cache( function maybe_with_partitioning (line 551) | def maybe_with_partitioning(fn, axis_rules, axis_rules_args=()): FILE: examples/gemma/modules_test.py class EmbedderTest (line 27) | class EmbedderTest(parameterized.TestCase): method test_encode (line 37) | def test_encode(self, vocab_size, embed_dim, inputs, expected): method test_decode (line 55) | def test_decode(self, vocab_size, embed_dim, inputs, expected): class AttentionTest (line 66) | class AttentionTest(parameterized.TestCase): method test_head_dim (line 76) | def test_head_dim(self, head_dim): method test_use_qkv_einsum (line 101) | def test_use_qkv_einsum( method test_attention (line 133) | def test_attention( method test_sliding_window (line 170) | def test_sliding_window(self, sliding_window_size): class FeedForwardTest (line 211) | class FeedForwardTest(parameterized.TestCase): method test_ffw (line 222) | def test_ffw( class BlockTest (line 243) | class BlockTest(parameterized.TestCase): method test_block (line 258) | def test_block( method test_post_attention_norm (line 313) | def test_post_attention_norm( method test_post_ffw_norm (line 386) | def test_post_ffw_norm( FILE: examples/gemma/params.py function load_and_format_params (line 28) | def load_and_format_params(path: str) -> Params: function load_metadata (line 37) | def load_metadata(path: str) -> Any | None: function load_params (line 45) | def load_params(path: str) -> Params: function param_remapper (line 52) | def param_remapper(orig_params: Params) -> Params: function nest_params (line 77) | def nest_params(params: Params) -> Params: FILE: examples/gemma/positional_embeddings.py function add_positional_embedding (line 23) | def add_positional_embedding( function apply_rope (line 45) | def apply_rope( FILE: examples/gemma/positional_embeddings_test.py class PositionalEmbeddingsTest (line 29) | class PositionalEmbeddingsTest(parameterized.TestCase): method test_adds_positional_embeddings (line 40) | def test_adds_positional_embeddings( method test_rope_positional_embeddings (line 66) | def test_rope_positional_embeddings( FILE: examples/gemma/sampler.py function _sample_top_p (line 38) | def _sample_top_p(probs: jnp.ndarray, p: float, key: jax.Array) -> jnp.n... function _compute_attention_masks (line 53) | def _compute_attention_masks( class _SamplingState (line 81) | class _SamplingState: class SamplerOutput (line 125) | class SamplerOutput: class Sampler (line 141) | class Sampler: method __init__ (line 144) | def __init__( method transformer (line 168) | def transformer(self) -> transformer_lib.Transformer: method transformer_state (line 172) | def transformer_state(self) -> statelib.State: method transformer_state (line 176) | def transformer_state(self, state: statelib.State) -> statelib.State: method dtype (line 201) | def dtype(self) -> jnp.dtype: method _sample_step (line 206) | def _sample_step( method init_sample_state (line 289) | def init_sample_state( method tokenize (line 354) | def tokenize(self, input_string: str) -> jax.Array: method mask_tokens_after_eos_ids (line 362) | def mask_tokens_after_eos_ids(self, token_buffer): method _sample_fn (line 378) | def _sample_fn( method __call__ (line 397) | def __call__( FILE: examples/gemma/sampler_test.py class MockVocab (line 35) | class MockVocab(spm.SentencePieceProcessor): method __init__ (line 37) | def __init__(self): method pad_id (line 58) | def pad_id(self) -> int: method bos_id (line 61) | def bos_id(self) -> int: method eos_id (line 64) | def eos_id(self) -> int: method GetPieceSize (line 67) | def GetPieceSize(self) -> int: # pylint: disable=invalid-name method DecodeIds (line 70) | def DecodeIds(self, ids: Iterable[int]) -> str: # pylint: disable=inv... method EncodeAsIds (line 74) | def EncodeAsIds(self, text: str) -> list[int]: # pylint: disable=inva... class SamplerTest (line 79) | class SamplerTest(parameterized.TestCase): method assertReasonableTensor (line 81) | def assertReasonableTensor(self, array, expected_shape=None): method test_samples (line 86) | def test_samples(self): method test_state_update (line 134) | def test_state_update(self): method test_invalid_state_update (line 172) | def test_invalid_state_update(self): method test_forbidden_tokens (line 216) | def test_forbidden_tokens(self): method test_forward_equivalence (line 264) | def test_forward_equivalence(self): method test_sampler_init_sample_state (line 324) | def test_sampler_init_sample_state(self): method test_sampler_mask_tokens_after_eos_ids (line 365) | def test_sampler_mask_tokens_after_eos_ids(self): method test_sampler_sows_intermediates (line 409) | def test_sampler_sows_intermediates(self): method test_compute_attention_mask (line 497) | def test_compute_attention_mask(self): method test_models_from_kaggle (line 530) | def test_models_from_kaggle(self, url): FILE: examples/gemma/sow_lib.py class LayerIntermediates (line 25) | class LayerIntermediates: method merge (line 38) | def merge(self, decoding_step, layer: nnx.Module): method trim (line 70) | def trim(self, max_length: int): class TransformerIntermediates (line 80) | class TransformerIntermediates: method merge (line 89) | def merge(self, decoding_step, transformer: nnx.Module): method trim (line 109) | def trim(self, max_length: int): class SowConfig (line 118) | class SowConfig: method maybe_sow_embeddings (line 139) | def maybe_sow_embeddings( method maybe_sow_rs_after_attention (line 148) | def maybe_sow_rs_after_attention( method maybe_sow_rs_after_ffw (line 157) | def maybe_sow_rs_after_ffw( method maybe_sow_mlp_hidden_topk (line 166) | def maybe_sow_mlp_hidden_topk( method maybe_sow_attn_logits_topk (line 178) | def maybe_sow_attn_logits_topk( FILE: examples/gemma/tokenizer.py function _dump_chars_to_textfile (line 38) | def _dump_chars_to_textfile( function _train_sentencepiece (line 67) | def _train_sentencepiece( function _load_sentencepiece_tokenizer (line 145) | def _load_sentencepiece_tokenizer( function load_or_train_tokenizer (line 160) | def load_or_train_tokenizer( class TokenizeOp (line 184) | class TokenizeOp: method __call__ (line 188) | def __call__(self, features: Features) -> Features: function load_sentencepiece_processor (line 194) | def load_sentencepiece_processor(vocab_path: str): FILE: examples/gemma/train.py class MeshRules (line 47) | class MeshRules: method __call__ (line 53) | def __call__(self, *keys: str) -> tuple[str, ...]: class TrainConfig (line 60) | class TrainConfig: method replace (line 152) | def replace(self, **kwargs): method __post_init__ (line 155) | def __post_init__(self): function rsqrt_schedule (line 160) | def rsqrt_schedule( function create_learning_rate_schedule (line 183) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): function compute_weighted_cross_entropy (line 198) | def compute_weighted_cross_entropy( function compute_weighted_accuracy (line 240) | def compute_weighted_accuracy(logits, targets, weights=None): function compute_metrics (line 265) | def compute_metrics(logits, labels, weights, label_smoothing=0.0): function train_step (line 283) | def train_step( function eval_step (line 341) | def eval_step( function evaluate (line 368) | def evaluate( function train_and_evaluate (line 393) | def train_and_evaluate(config: TrainConfig, workdir: str): FILE: examples/gemma/transformer.py function make_attention_layers_types (line 37) | def make_attention_layers_types( class QueryPreAttentionNormalisation (line 50) | class QueryPreAttentionNormalisation(enum.Enum): class TransformerConfig (line 83) | class TransformerConfig: method query_pre_attn_scalar (line 111) | def query_pre_attn_scalar(self) -> float: method from_path (line 122) | def from_path(cls, path: str) -> TransformerConfig: method from_params (line 129) | def from_params(cls, params: params_lib.Params) -> TransformerConfig: method from_version_name (line 185) | def from_version_name(cls, name: str, **override) -> TransformerConfig: method from_dict (line 206) | def from_dict(cls, **config: Any) -> TransformerConfig: method gemma_2b (line 215) | def gemma_2b(cls, **override) -> TransformerConfig: method gemma_7b (line 235) | def gemma_7b(cls, **override): method gemma2_2b (line 255) | def gemma2_2b(cls, **override): method gemma2_9b (line 282) | def gemma2_9b(cls, **override): method gemma2_27b (line 307) | def gemma2_27b(cls, **override): method gemma3_1b (line 332) | def gemma3_1b(cls, **override): method gemma3_4b (line 361) | def gemma3_4b(cls, **override): method gemma3_12b (line 391) | def gemma3_12b(cls, **override): method gemma3_27b (line 421) | def gemma3_27b(cls, **override): method __post_init__ (line 450) | def __post_init__(self): function _map_linen_var_names (line 459) | def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]: function _assign_linen_params_to_nnx_state (line 480) | def _assign_linen_params_to_nnx_state( class Transformer (line 497) | class Transformer(nnx.Module): method from_params (line 501) | def from_params( method __init__ (line 522) | def __init__( method __call__ (line 563) | def __call__( method embed_dim (line 612) | def embed_dim(self) -> int: method num_embed (line 616) | def num_embed(self) -> int: method num_layers (line 620) | def num_layers(self) -> int: method init_cache (line 623) | def init_cache( method init_intermediates (line 639) | def init_intermediates( function make_causal_attn_mask (line 693) | def make_causal_attn_mask( function build_positions_from_mask (line 713) | def build_positions_from_mask(input_mask: Array) -> Array: FILE: examples/gemma/transformer_test.py function create_fake_params (line 28) | def create_fake_params(config: transformer_lib.TransformerConfig): class TransformerTest (line 86) | class TransformerTest(parameterized.TestCase): method test_transformer (line 133) | def test_transformer( method test_logit_softcap (line 186) | def test_logit_softcap( method test_creates_cache (line 274) | def test_creates_cache(self, config, cache_size, keys, k_shape, v_shape): method test_forward_no_cache (line 306) | def test_forward_no_cache( method test_attention_types (line 340) | def test_attention_types( method test_load_from_params (line 414) | def test_load_from_params(self, config): method test_sow_intermediates (line 432) | def test_sow_intermediates(self, sow_config): FILE: examples/gemma/utils.py class TrainState (line 36) | class TrainState(train_state.TrainState): function create_device_mesh (line 44) | def create_device_mesh(config: Any): function fill_unspecified_mesh_axes (line 98) | def fill_unspecified_mesh_axes( function _to_array (line 131) | def _to_array(x): function setup_initial_state (line 137) | def setup_initial_state( FILE: examples/imagenet/configs/default.py function get_config (line 19) | def get_config(): function metrics (line 52) | def metrics(): FILE: examples/imagenet/configs/fake_data_benchmark.py function get_config (line 22) | def get_config(): FILE: examples/imagenet/configs/tpu.py function get_config (line 19) | def get_config(): FILE: examples/imagenet/configs/v100_x8.py function get_config (line 20) | def get_config(): FILE: examples/imagenet/configs/v100_x8_mixed_precision.py function get_config (line 20) | def get_config(): FILE: examples/imagenet/imagenet_benchmark.py class ImagenetBenchmark (line 37) | class ImagenetBenchmark(Benchmark): method _test_8x_v100_half_precision (line 41) | def _test_8x_v100_half_precision( method test_8x_v100_half_precision_short (line 76) | def test_8x_v100_half_precision_short(self): method test_8x_v100_half_precision_full (line 88) | def test_8x_v100_half_precision_full(self): FILE: examples/imagenet/imagenet_fake_data_benchmark.py class ImagenetBenchmarkFakeData (line 37) | class ImagenetBenchmarkFakeData(Benchmark): method test_fake_data (line 40) | def test_fake_data(self): FILE: examples/imagenet/input_pipeline.py function distorted_bounding_box_crop (line 28) | def distorted_bounding_box_crop( function _resize (line 80) | def _resize(image, image_size): function _at_least_x_are_equal (line 86) | def _at_least_x_are_equal(a, b, x): function _decode_and_random_crop (line 93) | def _decode_and_random_crop(image_bytes, image_size): function _decode_and_center_crop (line 116) | def _decode_and_center_crop(image_bytes, image_size): function normalize_image (line 144) | def normalize_image(image): function preprocess_for_train (line 150) | def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE... function preprocess_for_eval (line 169) | def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_... function create_split (line 187) | def create_split( FILE: examples/imagenet/main.py function main (line 43) | def main(argv): FILE: examples/imagenet/models.py class ResNetBlock (line 30) | class ResNetBlock(nn.Module): method __call__ (line 40) | def __call__( class BottleneckResNetBlock (line 60) | class BottleneckResNetBlock(nn.Module): method __call__ (line 70) | def __call__(self, x): class ResNet (line 90) | class ResNet(nn.Module): method __call__ (line 102) | def __call__(self, x, train: bool = True): FILE: examples/imagenet/models_test.py class ResNetTest (line 29) | class ResNetTest(parameterized.TestCase): method test_resnet_model (line 32) | def test_resnet_model(self): method test_resnet_18_model (line 46) | def test_resnet_18_model(self, model): FILE: examples/imagenet/train.py function create_model (line 50) | def create_model(*, model_cls, half_precision, **kwargs): function initialized (line 62) | def initialized(key, image_size, model): function cross_entropy_loss (line 73) | def cross_entropy_loss(logits, labels): function compute_metrics (line 79) | def compute_metrics(logits, labels): function create_learning_rate_fn (line 90) | def create_learning_rate_fn( function train_step (line 112) | def train_step(state, batch, learning_rate_fn): function eval_step (line 174) | def eval_step(state, batch): function prepare_tf_data (line 180) | def prepare_tf_data(xs): function create_input_iter (line 195) | def create_input_iter( class TrainState (line 220) | class TrainState(train_state.TrainState): function restore_checkpoint (line 225) | def restore_checkpoint(state, workdir): function save_checkpoint (line 229) | def save_checkpoint(state, workdir): function create_train_state (line 243) | def create_train_state( function train_and_evaluate (line 268) | def train_and_evaluate( FILE: examples/imagenet/train_test.py class TrainTest (line 37) | class TrainTest(parameterized.TestCase): method setUp (line 39) | def setUp(self): method test_create_model (line 44) | def test_create_model(self): method test_create_model_local (line 53) | def test_create_model_local(self): method test_train_and_evaluate (line 68) | def test_train_and_evaluate(self, model): FILE: examples/linen_design_test/attention_simple.py class Dense (line 27) | class Dense(Module): method __call__ (line 36) | def __call__(self, inputs): class SoftmaxAttn (line 55) | class SoftmaxAttn(Module): method __call__ (line 58) | def __call__(self, weights): class Dropout (line 63) | class Dropout(Module): method __call__ (line 67) | def __call__(self, x, deterministic=False, rng=None): class SoftmaxAttnWDropout (line 81) | class SoftmaxAttnWDropout(Module): method __call__ (line 86) | def __call__(self, x): class RawDotProductAttention (line 92) | class RawDotProductAttention(Module): method __call__ (line 96) | def __call__(self, query, key, value, bias=None, dtype=jnp.float32): class DotProductAttention (line 115) | class DotProductAttention(Module): method __call__ (line 121) | def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): function concise_vmap (line 142) | def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs): class MultiHeadDotProductAttention (line 157) | class MultiHeadDotProductAttention(Module): method __call__ (line 166) | def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): FILE: examples/linen_design_test/autoencoder.py class MLP (line 26) | class MLP(Module): method __call__ (line 30) | def __call__(self, x): class AutoEncoder (line 38) | class AutoEncoder(Module): method setup (line 43) | def setup(self): method __call__ (line 48) | def __call__(self, x): method encode (line 51) | def encode(self, x): method decode (line 55) | def decode(self, z): FILE: examples/linen_design_test/dense.py class Dense (line 21) | class Dense(Module): method __call__ (line 28) | def __call__(self, inputs): FILE: examples/linen_design_test/linear_regression.py function predict (line 27) | def predict(params): function loss_fn (line 32) | def loss_fn(params): function init_params (line 37) | def init_params(rng): FILE: examples/linen_design_test/mlp_explicit.py class DenseExplicit (line 26) | class DenseExplicit(Dense): method setup (line 29) | def setup(self): class MLP (line 40) | class MLP(Module): method setup (line 42) | def setup(self): method __call__ (line 52) | def __call__(self, x): FILE: examples/linen_design_test/mlp_inline.py class MLP (line 25) | class MLP(Module): method __call__ (line 29) | def __call__(self, x): FILE: examples/linen_design_test/mlp_lazy.py class MLP (line 25) | class MLP(Module): method setup (line 27) | def setup(self): method __call__ (line 35) | def __call__(self, x): FILE: examples/lm1b/configs/default.py function get_config (line 20) | def get_config(): FILE: examples/lm1b/input_pipeline.py class NormalizeFeatureNamesOp (line 31) | class NormalizeFeatureNamesOp: method __init__ (line 34) | def __init__(self, ds_info: tfds.core.DatasetInfo): method __call__ (line 37) | def __call__(self, features: Features) -> Features: function get_raw_dataset (line 44) | def get_raw_dataset( function pack_dataset (line 69) | def pack_dataset( function _pack_with_tf_ops (line 152) | def _pack_with_tf_ops( function preprocess_data (line 276) | def preprocess_data( function get_datasets (line 321) | def get_datasets( FILE: examples/lm1b/input_pipeline_test.py class InputPipelineTest (line 33) | class InputPipelineTest(absltest.TestCase): method setUp (line 35) | def setUp(self): method _get_datasets (line 41) | def _get_datasets(self): method test_train_ds (line 63) | def test_train_ds(self): method test_eval_ds (line 80) | def test_eval_ds(self): method test_predict_ds (line 91) | def test_predict_ds(self): FILE: examples/lm1b/main.py function main (line 43) | def main(argv): FILE: examples/lm1b/models.py class TransformerConfig (line 37) | class TransformerConfig: function shift_right (line 60) | def shift_right(x, axis=1): function shift_inputs (line 70) | def shift_inputs(x, segment_ids=None, axis=1): function sinusoidal_init (line 80) | def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): class AddPositionEmbs (line 108) | class AddPositionEmbs(nn.Module): method __call__ (line 120) | def __call__(self, inputs, inputs_positions=None): class MlpBlock (line 171) | class MlpBlock(nn.Module): method __call__ (line 183) | def __call__(self, inputs): class EncoderDecoder1DBlock (line 213) | class EncoderDecoder1DBlock(nn.Module): method __call__ (line 223) | def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None): class Decoder (line 281) | class Decoder(nn.Module): method __call__ (line 293) | def __call__( class TransformerLM (line 377) | class TransformerLM(nn.Module): method __call__ (line 387) | def __call__(self, inputs, inputs_positions=None, inputs_segmentation=... FILE: examples/lm1b/temperature_sampler.py function temperature_sample (line 27) | def temperature_sample( FILE: examples/lm1b/temperature_sampler_test.py class TestTemperatureSampler (line 26) | class TestTemperatureSampler(absltest.TestCase): method test_temperature_sampler (line 28) | def test_temperature_sampler(self): FILE: examples/lm1b/tokenizer.py function _dump_chars_to_textfile (line 35) | def _dump_chars_to_textfile( function _train_sentencepiece (line 64) | def _train_sentencepiece( function _load_sentencepiece_tokenizer (line 123) | def _load_sentencepiece_tokenizer( function load_or_train_tokenizer (line 138) | def load_or_train_tokenizer( class TokenizeOp (line 162) | class TokenizeOp: method __call__ (line 166) | def __call__(self, features: Features) -> Features: FILE: examples/lm1b/train.py function rsqrt_schedule (line 47) | def rsqrt_schedule( function create_learning_rate_schedule (line 70) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): function compute_weighted_cross_entropy (line 85) | def compute_weighted_cross_entropy( function compute_weighted_accuracy (line 127) | def compute_weighted_accuracy(logits, targets, weights=None): function compute_metrics (line 152) | def compute_metrics(logits, labels, weights, label_smoothing=0.0): function train_step (line 170) | def train_step( function eval_step (line 220) | def eval_step(params, batch, config, label_smoothing=0.0): function predict_step (line 229) | def predict_step( function pad_examples (line 270) | def pad_examples(x, desired_batch_size): function tohost (line 276) | def tohost(x): function evaluate (line 282) | def evaluate( function generate_prediction (line 308) | def generate_prediction( function train_and_evaluate (line 362) | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): FILE: examples/lm1b/train_test.py class TrainTest (line 32) | class TrainTest(absltest.TestCase): method setUp (line 35) | def setUp(self): method test_train_and_evaluate (line 41) | def test_train_and_evaluate(self): FILE: examples/lm1b/utils.py function create_device_mesh (line 33) | def create_device_mesh(config): function fill_unspecified_mesh_axes (line 79) | def fill_unspecified_mesh_axes( function unbox_logicallypartioned_trainstate (line 112) | def unbox_logicallypartioned_trainstate( function init_train_state (line 130) | def init_train_state(model, tx, config, key): function setup_initial_state (line 149) | def setup_initial_state(model, tx, config, rng, mesh): FILE: examples/mnist/configs/default.py function get_config (line 20) | def get_config(): function metrics (line 31) | def metrics(): FILE: examples/mnist/main.py function main (line 43) | def main(argv): FILE: examples/mnist/mnist_benchmark.py class MnistBenchmark (line 35) | class MnistBenchmark(Benchmark): method test_cpu (line 39) | def test_cpu(self): FILE: examples/mnist/train.py class CNN (line 39) | class CNN(nnx.Module): method __init__ (line 42) | def __init__(self, rngs: nnx.Rngs): method __call__ (line 53) | def __call__(self, x, rngs: nnx.Rngs): function loss_fn (line 63) | def loss_fn(model: CNN, batch, rngs): function train_step (line 72) | def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiM... function eval_step (line 83) | def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): function get_datasets (line 88) | def get_datasets( function train_and_evaluate (line 122) | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) ... FILE: examples/mnist/train_test.py class TrainTest (line 36) | class TrainTest(absltest.TestCase): method setUp (line 39) | def setUp(self): method test_cnn (line 46) | def test_cnn(self): method test_train_and_evaluate (line 55) | def test_train_and_evaluate(self): FILE: examples/nlp_seq/configs/default.py function get_config (line 20) | def get_config(): FILE: examples/nlp_seq/input_pipeline.py class CoNLLAttributes (line 35) | class CoNLLAttributes(enum.Enum): function create_vocabs (line 57) | def create_vocabs(filename, max_num_forms=100000): function create_token (line 106) | def create_token(token, attributes, vocabs): function create_sentence_with_root (line 142) | def create_sentence_with_root(attributes, vocabs): function sentences_from_conll_data (line 163) | def sentences_from_conll_data( function sentence_dataset_dict (line 199) | def sentence_dataset_dict( FILE: examples/nlp_seq/input_pipeline_test.py class InputPipelineTest (line 44) | class InputPipelineTest(absltest.TestCase): method setUp (line 46) | def setUp(self): method test_vocab_creation (line 57) | def test_vocab_creation(self): method testInputBatch (line 74) | def testInputBatch(self): method testInputTargetBatch (line 103) | def testInputTargetBatch(self): FILE: examples/nlp_seq/main.py function main (line 36) | def main(argv): FILE: examples/nlp_seq/models.py class TransformerConfig (line 27) | class TransformerConfig: function sinusoidal_init (line 46) | def sinusoidal_init(max_len=2048): class AddPositionEmbs (line 73) | class AddPositionEmbs(nn.Module): method __call__ (line 83) | def __call__(self, inputs): class MlpBlock (line 116) | class MlpBlock(nn.Module): method __call__ (line 128) | def __call__(self, inputs, deterministic=True): class Encoder1DBlock (line 152) | class Encoder1DBlock(nn.Module): method __call__ (line 162) | def __call__(self, inputs, deterministic): class Transformer (line 198) | class Transformer(nn.Module): method __call__ (line 204) | def __call__(self, *, inputs, train): FILE: examples/nlp_seq/train.py function create_learning_rate_scheduler (line 82) | def create_learning_rate_scheduler( function compute_weighted_cross_entropy (line 142) | def compute_weighted_cross_entropy(logits, targets, weights=None): function compute_weighted_accuracy (line 168) | def compute_weighted_accuracy(logits, targets, weights=None): function compute_metrics (line 193) | def compute_metrics(logits, labels, weights): function train_step (line 206) | def train_step(state, batch, model, learning_rate_fn, dropout_rng=None): function pad_examples (line 239) | def pad_examples(x, desired_batch_size): function main (line 246) | def main(argv): FILE: examples/nnx_toy_examples/01_functional_api.py function dataset (line 27) | def dataset(batch_size): class Linear (line 33) | class Linear(nnx.Module): method __init__ (line 34) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 38) | def __call__(self, x): class Count (line 42) | class Count(nnx.Variable[nnx.A]): class MLP (line 46) | class MLP(nnx.Module): method __init__ (line 47) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): method __call__ (line 52) | def __call__(self, x): function train_step (line 63) | def train_step(params, counts, batch): function test_step (line 81) | def test_step(params: nnx.State, counts: nnx.State, batch): FILE: examples/nnx_toy_examples/02_lifted_transforms.py function dataset (line 28) | def dataset(batch_size): class Linear (line 34) | class Linear(nnx.Module): method __init__ (line 35) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 39) | def __call__(self, x): class Count (line 43) | class Count(nnx.Variable): class MLP (line 47) | class MLP(nnx.Module): method __init__ (line 48) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): method __call__ (line 53) | def __call__(self, x): function train_step (line 67) | def train_step(model: MLP, optimizer: nnx.Optimizer, batch): function test_step (line 79) | def test_step(model: MLP, batch): FILE: examples/nnx_toy_examples/03_train_state.py function dataset (line 29) | def dataset(batch_size): class Linear (line 35) | class Linear(nnx.Module): method __init__ (line 36) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 40) | def __call__(self, x): class Count (line 44) | class Count(nnx.Variable[nnx.A]): class MLP (line 48) | class MLP(nnx.Module): method __init__ (line 49) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): method __call__ (line 54) | def __call__(self, x): class TrainState (line 61) | class TrainState(train_state.TrainState): function train_step (line 79) | def train_step(state: TrainState, batch): function test_step (line 97) | def test_step(state: nnx.TrainState[MLP], batch): FILE: examples/nnx_toy_examples/04_data_parallel_with_jit.py class MLP (line 36) | class MLP(nnx.Module): method __init__ (line 37) | def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs): method __call__ (line 41) | def __call__(self, x): function train_step (line 59) | def train_step(model: MLP, optimizer: nnx.Optimizer, x, y): function dataset (line 69) | def dataset(steps, batch_size): FILE: examples/nnx_toy_examples/05_vae.py class Loss (line 46) | class Loss(nnx.Variable): class Encoder (line 51) | class Encoder(nnx.Module): method __init__ (line 52) | def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 58) | def __call__(self, x: jax.Array) -> jax.Array: class Decoder (line 76) | class Decoder(nnx.Module): method __init__ (line 77) | def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 81) | def __call__(self, z: jax.Array) -> jax.Array: class VAE (line 88) | class VAE(nnx.Module): method __init__ (line 89) | def __init__( method __call__ (line 104) | def __call__(self, x: jax.Array) -> jax.Array: method generate (line 110) | def generate(self, z): function train_step (line 129) | def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array): function forward (line 147) | def forward(model: VAE, x: jax.Array) -> jax.Array: function sample (line 153) | def sample(model: VAE, z: jax.Array) -> jax.Array: FILE: examples/nnx_toy_examples/06_scan_over_layers.py class Block (line 22) | class Block(nnx.Module): method __init__ (line 23) | def __init__(self, dim: int, *, rngs: nnx.Rngs): method __call__ (line 28) | def __call__(self, x: jax.Array): class ScanMLP (line 32) | class ScanMLP(nnx.Module): method __init__ (line 39) | def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs): method __call__ (line 49) | def __call__(self, x: jax.Array) -> jax.Array: FILE: examples/nnx_toy_examples/07_array_leaves.py function dataset (line 28) | def dataset(batch_size): class Linear (line 33) | class Linear(nnx.Module): method __init__ (line 34) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 38) | def __call__(self, x): class MLP (line 42) | class MLP(nnx.Module): method __init__ (line 43) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): method __call__ (line 48) | def __call__(self, x): function is_param (line 52) | def is_param(path, value): function train_step (line 62) | def train_step(model: MLP, optimizer: nnx.Optimizer, batch): function test_step (line 75) | def test_step(model: MLP, batch): FILE: examples/nnx_toy_examples/08_save_load_checkpoints.py class MLP (line 24) | class MLP(nnx.Module): method __init__ (line 25) | def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 29) | def __call__(self, x: jax.Array) -> jax.Array: function create_model (line 36) | def create_model(seed: int): function create_and_save (line 40) | def create_and_save(seed: int, path: str): function load_model (line 48) | def load_model(path: str) -> MLP: FILE: examples/nnx_toy_examples/09_parameter_surgery.py function load_pretrained (line 22) | def load_pretrained(): class Classifier (line 27) | class Classifier(nnx.Module): method __init__ (line 28) | def __init__(self, *, rngs: nnx.Rngs): method __call__ (line 32) | def __call__(self, x): FILE: examples/nnx_toy_examples/10_fsdp_and_optimizer.py function named_sharding (line 34) | def named_sharding(*names: str | None) -> NamedSharding: class MeshRules (line 39) | class MeshRules: method __call__ (line 44) | def __call__(self, *keys: str) -> tuple[str, ...]: class MLP (line 55) | class MLP(nnx.Module): method __init__ (line 56) | def __init__(self, din, dmid, dout, rngs: nnx.Rngs): method __call__ (line 70) | def __call__(self, x: jax.Array): class SGDState (line 74) | class SGDState(nnx.Variable): class SGD (line 78) | class SGD(nnx.Pytree): method __init__ (line 79) | def __init__(self, params: nnx.State, lr, decay=0.9): method update (line 94) | def update(self, grads: nnx.State): function create_model (line 113) | def create_model(): function train_step (line 128) | def train_step(model: MLP, optimizer: SGD, x, y): function dataset (line 143) | def dataset(batch_size, num_steps): FILE: examples/nnx_toy_examples/hijax_basic.py function dataset (line 28) | def dataset(batch_size): class Linear (line 34) | class Linear(nnx.Module): method __init__ (line 35) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 39) | def __call__(self, x): class Count (line 43) | class Count(nnx.Variable[nnx.A]): class MLP (line 47) | class MLP(nnx.Module): method __init__ (line 48) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): method __call__ (line 53) | def __call__(self, x): function train_step (line 64) | def train_step(model, optimizer, x, y): function test_step (line 75) | def test_step(model: MLP, x, y): FILE: examples/nnx_toy_examples/hijax_demo.py function dataset (line 29) | def dataset(batch_size): class Linear (line 43) | class Linear(nnx.Module): method __init__ (line 44) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): method __call__ (line 52) | def __call__(self, x: jax.Array): class Block (line 58) | class Block(nnx.Module): method __init__ (line 59) | def __init__( method __call__ (line 86) | def __call__( class Model (line 116) | class Model(nnx.Module): method __init__ (line 117) | def __init__( method __call__ (line 148) | def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None): class OptState (line 171) | class OptState(nnx.Variable): ... class SGD (line 179) | class SGD(nnx.Pytree): method __init__ (line 180) | def __init__(self, params, lr: float, decay: float = 0.9): method update (line 195) | def update(self, params, grads): function train_step (line 231) | def train_step(model: Model, optimizer: SGD, rngs: nnx.Rngs, x, y): function test_step (line 249) | def test_step(model: Model, x, y): FILE: examples/ogbg_molpcba/configs/default.py function get_config (line 23) | def get_config(): FILE: examples/ogbg_molpcba/configs/default_graph_net.py function get_config (line 23) | def get_config(): FILE: examples/ogbg_molpcba/configs/hparam_sweep.py function get_config (line 20) | def get_config(): function sweep (line 51) | def sweep(add): FILE: examples/ogbg_molpcba/configs/test.py function get_config (line 20) | def get_config(): FILE: examples/ogbg_molpcba/input_pipeline.py class GraphsTupleSize (line 26) | class GraphsTupleSize(NamedTuple): function get_raw_datasets (line 34) | def get_raw_datasets() -> dict[str, tf.data.Dataset]: function get_datasets (line 43) | def get_datasets( function convert_to_graphs_tuple (line 111) | def convert_to_graphs_tuple( function estimate_padding_budget_for_batch_size (line 168) | def estimate_padding_budget_for_batch_size( function specs_from_graphs_tuple (line 213) | def specs_from_graphs_tuple(graph: jraph.GraphsTuple): function get_graphs_tuple_size (line 236) | def get_graphs_tuple_size(graph: jraph.GraphsTuple): FILE: examples/ogbg_molpcba/input_pipeline_test.py function get_dummy_datasets (line 24) | def get_dummy_datasets(dataset_length: int): class InputPipelineTest (line 53) | class InputPipelineTest(parameterized.TestCase): method setUp (line 55) | def setUp(self): method test_estimate_padding_budget_valid (line 63) | def test_estimate_padding_budget_valid(self, valid_batch_size): method test_estimate_padding_budget_invalid (line 72) | def test_estimate_padding_budget_invalid(self, invalid_batch_size): FILE: examples/ogbg_molpcba/main.py function main (line 43) | def main(argv): FILE: examples/ogbg_molpcba/models.py function add_graphs_tuples (line 24) | def add_graphs_tuples( class MLP (line 35) | class MLP(nn.Module): method __call__ (line 44) | def __call__(self, inputs): class GraphNet (line 55) | class GraphNet(nn.Module): method __call__ (line 69) | def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: class GraphConvNet (line 137) | class GraphConvNet(nn.Module): method pool (line 153) | def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: method __call__ (line 173) | def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: FILE: examples/ogbg_molpcba/models_test.py class ModelsTest (line 26) | class ModelsTest(parameterized.TestCase): method setUp (line 28) | def setUp(self): method test_mlp (line 53) | def test_mlp(self, dropout_rate, output_size, num_layers): method test_graph_net (line 91) | def test_graph_net( method test_graph_conv_net (line 125) | def test_graph_conv_net(self, latent_size: int, output_globals_size: i... FILE: examples/ogbg_molpcba/ogbg_molpcba_benchmark.py class OgbgMolpcbaBenchmark (line 36) | class OgbgMolpcbaBenchmark(Benchmark): method test_1x_v100 (line 39) | def test_1x_v100(self): method test_cpu (line 92) | def test_cpu(self): FILE: examples/ogbg_molpcba/train.py function create_model (line 44) | def create_model( function create_optimizer (line 74) | def create_optimizer( function binary_cross_entropy_with_mask (line 87) | def binary_cross_entropy_with_mask( function predictions_match_labels (line 106) | def predictions_match_labels( function add_prefix_to_keys (line 115) | def add_prefix_to_keys(result: dict[str, Any], prefix: str) -> dict[str,... class MeanAveragePrecision (line 121) | class MeanAveragePrecision( method compute (line 126) | def compute(self): class EvalMetrics (line 156) | class EvalMetrics(metrics.Collection): class TrainMetrics (line 163) | class TrainMetrics(metrics.Collection): function replace_globals (line 168) | def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: function get_predicted_logits (line 173) | def get_predicted_logits( function get_valid_mask (line 184) | def get_valid_mask( function train_step (line 203) | def train_step( function evaluate_step (line 240) | def evaluate_step( function evaluate_model (line 266) | def evaluate_model( function train_and_evaluate (line 292) | def train_and_evaluate( FILE: examples/ogbg_molpcba/train_test.py function average_with_mask (line 40) | def average_with_mask(arr: jnp.ndarray, mask: jnp.ndarray): function get_dummy_raw_datasets (line 46) | def get_dummy_raw_datasets(dataset_length) -> dict[str, tf.data.Dataset]: function get_dummy_datasets (line 81) | def get_dummy_datasets( class OgbgMolpcbaTrainTest (line 134) | class OgbgMolpcbaTrainTest(parameterized.TestCase): method setUp (line 136) | def setUp(self): method test_binary_cross_entropy_loss (line 161) | def test_binary_cross_entropy_loss(self, probs, labels): method test_mean_average_precision (line 199) | def test_mean_average_precision(self, logits, labels, expected_result): method test_eval_metrics (line 226) | def test_eval_metrics(self, loss, logits, labels, mask, expected_resul... method test_train_metrics (line 251) | def test_train_metrics(self, loss, logits, labels, mask, expected_resu... method test_train_step (line 263) | def test_train_step(self): method test_evaluate_step (line 304) | def test_evaluate_step(self): method test_train_and_evaluate (line 337) | def test_train_and_evaluate(self): FILE: examples/ppo/agent.py function policy_action (line 31) | def policy_action( class RemoteSimulator (line 55) | class RemoteSimulator: method __init__ (line 61) | def __init__(self, game: str): function rcv_action_send_exp (line 72) | def rcv_action_send_exp(conn, game: str): FILE: examples/ppo/configs/default.py function get_config (line 20) | def get_config(): FILE: examples/ppo/env_utils.py class ClipRewardEnv (line 25) | class ClipRewardEnv(gym.RewardWrapper): method __init__ (line 31) | def __init__(self, env): method reward (line 34) | def reward(self, reward): class FrameStack (line 39) | class FrameStack: method __init__ (line 45) | def __init__( method reset (line 54) | def reset(self): method step (line 60) | def step(self, action: int): method _get_array (line 65) | def _get_array(self): function create_env (line 70) | def create_env(game: str, clip_rewards: bool): function get_num_actions (line 80) | def get_num_actions(game: str): FILE: examples/ppo/models.py class ActorCritic (line 21) | class ActorCritic(nn.Module): method __call__ (line 27) | def __call__(self, x): FILE: examples/ppo/ppo_lib.py function gae_advantages (line 39) | def gae_advantages( function loss_fn (line 81) | def loss_fn( function train_step (line 134) | def train_step( function get_experience (line 180) | def get_experience( function process_experience (line 217) | def process_experience( function get_initial_params (line 271) | def get_initial_params(key: jax.Array, model: nn.Module): function create_train_state (line 278) | def create_train_state( function train (line 299) | def train( FILE: examples/ppo/ppo_lib_test.py class TestGAE (line 35) | class TestGAE(absltest.TestCase): method test_gae_shape_on_random (line 37) | def test_gae_shape_on_random(self): method test_gae_hardcoded (line 52) | def test_gae_hardcoded(self): class TestEnvironmentPreprocessing (line 68) | class TestEnvironmentPreprocessing(absltest.TestCase): method choose_random_game (line 70) | def choose_random_game(self): method test_creation (line 82) | def test_creation(self): method test_step (line 89) | def test_step(self): class TestModel (line 104) | class TestModel(absltest.TestCase): method choose_random_outputs (line 106) | def choose_random_outputs(self): method test_model (line 109) | def test_model(self): class TestOptimizationStep (line 125) | class TestOptimizationStep(absltest.TestCase): method generate_random_data (line 127) | def generate_random_data(self, num_actions): method test_optimization_step (line 137) | def test_optimization_step(self): FILE: examples/ppo/ppo_main.py function main (line 48) | def main(argv): FILE: examples/ppo/seed_rl_atari_preprocessing.py class AtariPreprocessing (line 39) | class AtariPreprocessing: method __init__ (line 54) | def __init__( method observation_space (line 104) | def observation_space(self): method action_space (line 115) | def action_space(self): method reward_range (line 119) | def reward_range(self): method metadata (line 123) | def metadata(self): method close (line 126) | def close(self): method apply_random_noops (line 129) | def apply_random_noops(self): method reset (line 141) | def reset(self): method render (line 155) | def render(self, mode): method step (line 169) | def step(self, action): method _fetch_grayscale_observation (line 214) | def _fetch_grayscale_observation(self, output): method _pool_and_resize (line 225) | def _pool_and_resize(self): FILE: examples/ppo/test_episodes.py function policy_test (line 28) | def policy_test( FILE: examples/seq2seq/configs/default.py function get_config (line 20) | def get_config(): FILE: examples/seq2seq/input_pipeline.py class CharacterTable (line 27) | class CharacterTable: method __init__ (line 30) | def __init__(self, chars: str, max_len_query_digit: int = 3) -> None: method pad_id (line 39) | def pad_id(self) -> int: method eos_id (line 43) | def eos_id(self) -> int: method vocab_size (line 47) | def vocab_size(self) -> int: method max_input_len (line 52) | def max_input_len(self) -> int: method max_output_len (line 59) | def max_output_len(self) -> int: method encoder_input_shape (line 67) | def encoder_input_shape(self) -> tuple[int, int, int]: method decoder_input_shape (line 71) | def decoder_input_shape(self) -> tuple[int, int, int]: method encode (line 74) | def encode(self, inputs: str) -> np.ndarray: method decode (line 80) | def decode(self, inputs: Array) -> str: method one_hot (line 89) | def one_hot(self, tokens: np.ndarray) -> np.ndarray: method encode_onehot (line 94) | def encode_onehot( method decode_onehot (line 112) | def decode_onehot(self, batch_inputs: Array) -> np.ndarray: method generate_examples (line 117) | def generate_examples( method get_batch (line 130) | def get_batch(self, batch_size: int) -> dict[str, np.ndarray]: function mask_sequences (line 139) | def mask_sequences(sequence_batch: Array, lengths: Array) -> Array: function get_sequence_lengths (line 146) | def get_sequence_lengths(sequence_batch: Array, eos_id: int) -> Array: FILE: examples/seq2seq/main.py function main (line 33) | def main(argv): FILE: examples/seq2seq/models.py class DecoderLSTMCell (line 31) | class DecoderLSTMCell(nn.RNNCellBase): method __call__ (line 44) | def __call__( method num_feature_axes (line 65) | def num_feature_axes(self) -> int: class Seq2seq (line 69) | class Seq2seq(nn.Module): method __call__ (line 88) | def __call__( method get_seq_lengths (line 131) | def get_seq_lengths(self, inputs: Array) -> Array: FILE: examples/seq2seq/train.py function get_model (line 74) | def get_model(ctable: CTable, *, teacher_force: bool = False) -> models.... function get_initial_params (line 83) | def get_initial_params( function get_train_state (line 96) | def get_train_state(rng: PRNGKey, ctable: CTable) -> train_state.TrainSt... function cross_entropy_loss (line 107) | def cross_entropy_loss( function compute_metrics (line 116) | def compute_metrics( function train_step (line 137) | def train_step( function log_decode (line 164) | def log_decode(question: str, inferred: str, golden: str): function decode (line 173) | def decode( function decode_batch (line 188) | def decode_batch( function train_and_evaluate (line 206) | def train_and_evaluate(workdir: str) -> train_state.TrainState: function main (line 226) | def main(_): FILE: examples/seq2seq/train_test.py function create_ctable (line 33) | def create_ctable(chars='0123456789+= '): function create_train_state (line 37) | def create_train_state(ctable): class TrainTest (line 51) | class TrainTest(absltest.TestCase): method test_character_table (line 53) | def test_character_table(self): method test_mask_sequences (line 62) | def test_mask_sequences(self): method test_get_sequence_lengths (line 70) | def test_get_sequence_lengths(self): method test_train_one_step (line 87) | def test_train_one_step(self): method test_decode_batch (line 98) | def test_decode_batch(self): FILE: examples/sst2/build_vocabulary.py function get_tokenized_sequences (line 28) | def get_tokenized_sequences( FILE: examples/sst2/configs/default.py function get_config (line 20) | def get_config(): FILE: examples/sst2/input_pipeline.py function get_bucket_boundaries (line 34) | def get_bucket_boundaries(bucket_size: int, max_size: int) -> np.ndarray: function get_num_examples (line 56) | def get_num_examples(dataset: tf.data.Dataset) -> int: function get_bucketed_batches (line 61) | def get_bucketed_batches( function vocab_to_hashtable (line 136) | def vocab_to_hashtable( function vocab_to_inverse_hashtable (line 148) | def vocab_to_inverse_hashtable( function _is_text_field (line 163) | def _is_text_field(feature_name_and_type): function _is_class_label (line 169) | def _is_class_label(feature_name_and_type): class TextDataset (line 175) | class TextDataset: method __init__ (line 178) | def __init__( method padded_shapes (line 208) | def padded_shapes(self): method example_length_fn (line 213) | def example_length_fn(self, example: Example) -> tf.Tensor: method add_bos_eos (line 217) | def add_bos_eos(self, sequence: tf.Tensor) -> tf.Tensor: method prepare_example (line 221) | def prepare_example(self, example: Example) -> Example: method get_batches (line 232) | def get_batches( method get_bucketed_batches (line 256) | def get_bucketed_batches( FILE: examples/sst2/input_pipeline_test.py class InputPipelineTest (line 27) | class InputPipelineTest(absltest.TestCase): method setUp (line 29) | def setUp(self): method _get_vocab_path (line 36) | def _get_vocab_path(self): method _get_dataset (line 48) | def _get_dataset(self, vocab_path): method test_bucketed_dataset (line 56) | def test_bucketed_dataset(self): method test_batched_dataset (line 72) | def test_batched_dataset(self): method test_batched_dataset_fixed_length (line 85) | def test_batched_dataset_fixed_length(self): FILE: examples/sst2/main.py function main (line 43) | def main(argv): FILE: examples/sst2/models.py function sequence_mask (line 28) | def sequence_mask(lengths: Array, max_length: int) -> Array: function flip_sequences (line 50) | def flip_sequences(inputs: Array, lengths: Array) -> Array: class WordDropout (line 83) | class WordDropout(nn.Module): method __call__ (line 95) | def __call__(self, inputs: Array, deterministic: bool | None = None): class Embedder (line 106) | class Embedder(nn.Module): method setup (line 129) | def setup(self): method __call__ (line 141) | def __call__( class SimpleLSTM (line 167) | class SimpleLSTM(nn.Module): method __call__ (line 180) | def __call__(self, carry, x): method initialize_carry (line 183) | def initialize_carry(self, input_shape): class SimpleBiLSTM (line 190) | class SimpleBiLSTM(nn.Module): method setup (line 195) | def setup(self): method __call__ (line 199) | def __call__(self, embedded_inputs, lengths): class MLP (line 219) | class MLP(nn.Module): method setup (line 238) | def setup(self): method __call__ (line 243) | def __call__(self, inputs: Array, deterministic: bool | None = None): class KeysOnlyMlpAttention (line 263) | class KeysOnlyMlpAttention(nn.Module): method __call__ (line 283) | def __call__(self, keys: Array, mask: Array) -> Array: class AttentionClassifier (line 309) | class AttentionClassifier(nn.Module): method setup (line 325) | def setup(self): method __call__ (line 337) | def __call__( class TextClassifier (line 376) | class TextClassifier(nn.Module): method setup (line 389) | def setup(self): method embed_token_ids (line 404) | def embed_token_ids( method logits_from_embedded_inputs (line 412) | def logits_from_embedded_inputs( method __call__ (line 424) | def __call__( FILE: examples/sst2/models_test.py class ModelTest (line 28) | class ModelTest(parameterized.TestCase): method test_embedder_returns_correct_output_shape (line 30) | def test_embedder_returns_correct_output_shape(self): method test_lstm_returns_correct_output_shape (line 42) | def test_lstm_returns_correct_output_shape(self): method test_bilstm_returns_correct_output_shape (line 57) | def test_bilstm_returns_correct_output_shape(self): method test_text_classifier_returns_correct_output_shape (line 73) | def test_text_classifier_returns_correct_output_shape(self): FILE: examples/sst2/train.py class Metrics (line 39) | class Metrics(struct.PyTreeNode): function sigmoid_cross_entropy_with_logits (line 48) | def sigmoid_cross_entropy_with_logits(*, labels: Array, logits: Array) -... function get_initial_params (line 57) | def get_initial_params(rng, model): function create_train_state (line 65) | def create_train_state(rng, config: ml_collections.ConfigDict, model): function compute_metrics (line 76) | def compute_metrics(*, labels: Array, logits: Array) -> Metrics: function model_from_config (line 90) | def model_from_config(config: ml_collections.ConfigDict): function train_step (line 104) | def train_step( function eval_step (line 141) | def eval_step( function normalize_batch_metrics (line 157) | def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics: function batch_to_numpy (line 169) | def batch_to_numpy(batch: dict[str, Array]) -> dict[str, Array]: function evaluate_model (line 176) | def evaluate_model( function train_epoch (line 204) | def train_epoch( function train_and_evaluate (line 232) | def train_and_evaluate( FILE: examples/sst2/train_test.py class TrainTest (line 32) | class TrainTest(parameterized.TestCase): method test_train_step_updates_parameters (line 34) | def test_train_step_updates_parameters(self): FILE: examples/sst2/vocabulary.py class Vocabulary (line 24) | class Vocabulary: method __init__ (line 27) | def __init__( method build (line 57) | def build( method _getitem__ (line 90) | def _getitem__(self, key: str): method keys (line 93) | def keys(self): method values (line 96) | def values(self): method __len__ (line 99) | def __len__(self): method pad_idx (line 103) | def pad_idx(self): method unk_idx (line 108) | def unk_idx(self): method bos_idx (line 113) | def bos_idx(self): method eos_idx (line 118) | def eos_idx(self): method load (line 122) | def load(self, path: str) -> None: method save (line 132) | def save(self, path: str) -> None: FILE: examples/vae/configs/default.py function get_config (line 20) | def get_config(): FILE: examples/vae/input_pipeline.py function build_train_set (line 23) | def build_train_set(batch_size, ds_builder): function build_test_set (line 36) | def build_test_set(ds_builder): function prepare_image (line 45) | def prepare_image(x): FILE: examples/vae/main.py function main (line 44) | def main(argv): FILE: examples/vae/models.py class Encoder (line 22) | class Encoder(nn.Module): method __call__ (line 28) | def __call__(self, x): class Decoder (line 36) | class Decoder(nn.Module): method __call__ (line 40) | def __call__(self, z): class VAE (line 47) | class VAE(nn.Module): method setup (line 52) | def setup(self): method __call__ (line 56) | def __call__(self, x, z_rng): method generate (line 62) | def generate(self, z): function reparameterize (line 66) | def reparameterize(rng, mean, logvar): function model (line 72) | def model(latents): FILE: examples/vae/train.py function kl_divergence (line 33) | def kl_divergence(mean, logvar): function binary_cross_entropy_with_logits (line 38) | def binary_cross_entropy_with_logits(logits, labels): function compute_metrics (line 45) | def compute_metrics(recon_x, x, mean, logvar): function train_step (line 51) | def train_step(state, batch, z_rng, latents): function eval_f (line 67) | def eval_f(params, images, z, z_rng, latents): function train_and_evaluate (line 84) | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): FILE: examples/vae/utils.py function save_image (line 28) | def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format_img... FILE: examples/wmt/bleu.py class UnicodeRegex (line 49) | class UnicodeRegex: method __init__ (line 52) | def __init__(self): method property_chars (line 58) | def property_chars(self, prefix): function bleu_tokenize (line 69) | def bleu_tokenize(string): function _get_ngrams (line 98) | def _get_ngrams(segment, max_order): function compute_bleu_matches (line 117) | def compute_bleu_matches(reference_corpus, translation_corpus, max_order... function bleu_partial (line 165) | def bleu_partial(ref_lines, hyp_lines, case_sensitive=False): function complete_bleu (line 179) | def complete_bleu( function bleu_local (line 221) | def bleu_local(ref_lines, hyp_lines, case_sensitive=False): FILE: examples/wmt/configs/default.py function get_config (line 20) | def get_config(): function metrics (line 119) | def metrics(): FILE: examples/wmt/decode.py function brevity_penalty (line 33) | def brevity_penalty(alpha, length): function add_beam_dim (line 49) | def add_beam_dim(x, beam_size): function flatten_beam_dim (line 59) | def flatten_beam_dim(x): function unflatten_beam_dim (line 66) | def unflatten_beam_dim(x, batch_size, beam_size): function flat_batch_beam_expand (line 74) | def flat_batch_beam_expand(x, beam_size): function gather_beams (line 79) | def gather_beams(nested, beam_indices, batch_size, new_beam_size): function gather_topk_beams (line 106) | def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_si... class BeamState (line 129) | class BeamState: function beam_init (line 147) | def beam_init(batch_size, beam_size, max_decode_len, cache): function beam_search (line 175) | def beam_search( FILE: examples/wmt/input_pipeline.py class NormalizeFeatureNamesOp (line 32) | class NormalizeFeatureNamesOp: method __init__ (line 35) | def __init__(self, ds_info: tfds.core.DatasetInfo, reverse_translation... method __call__ (line 40) | def __call__(self, features: Features) -> Features: function get_raw_dataset (line 46) | def get_raw_dataset( function pack_dataset (line 79) | def pack_dataset( function _pack_with_tf_ops (line 162) | def _pack_with_tf_ops( function preprocess_wmt_data (line 286) | def preprocess_wmt_data( function get_wmt_datasets (line 331) | def get_wmt_datasets( FILE: examples/wmt/input_pipeline_test.py class InputPipelineTest (line 34) | class InputPipelineTest(absltest.TestCase): method setUp (line 36) | def setUp(self): method _get_datasets (line 42) | def _get_datasets(self): method test_train_ds (line 63) | def test_train_ds(self): method test_eval_ds (line 80) | def test_eval_ds(self): method test_predict_ds (line 91) | def test_predict_ds(self): FILE: examples/wmt/main.py function main (line 44) | def main(argv): FILE: examples/wmt/models.py class TransformerConfig (line 34) | class TransformerConfig: function shift_right (line 57) | def shift_right(x, axis=1): function sinusoidal_init (line 67) | def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): class AddPositionEmbs (line 95) | class AddPositionEmbs(nn.Module): method __call__ (line 107) | def __call__(self, inputs, inputs_positions=None): class MlpBlock (line 158) | class MlpBlock(nn.Module): method __call__ (line 170) | def __call__(self, inputs): class Encoder1DBlock (line 196) | class Encoder1DBlock(nn.Module): method __call__ (line 206) | def __call__(self, inputs, encoder_mask=None): class EncoderDecoder1DBlock (line 245) | class EncoderDecoder1DBlock(nn.Module): method __call__ (line 255) | def __call__( class Encoder (line 317) | class Encoder(nn.Module): method __call__ (line 329) | def __call__(self, inputs, inputs_positions=None, encoder_mask=None): class Decoder (line 374) | class Decoder(nn.Module): method __call__ (line 386) | def __call__( class Transformer (line 463) | class Transformer(nn.Module): method setup (line 472) | def setup(self): method encode (line 495) | def encode(self, inputs, inputs_positions=None, inputs_segmentation=No... method decode (line 526) | def decode( method __call__ (line 596) | def __call__( FILE: examples/wmt/tokenizer.py function _dump_chars_to_textfile (line 35) | def _dump_chars_to_textfile( function _train_sentencepiece (line 64) | def _train_sentencepiece( function _load_sentencepiece_tokenizer (line 123) | def _load_sentencepiece_tokenizer( function load_or_train_tokenizer (line 138) | def load_or_train_tokenizer( class TokenizeOp (line 162) | class TokenizeOp: method __call__ (line 166) | def __call__(self, features: Features) -> Features: FILE: examples/wmt/train.py class TrainState (line 50) | class TrainState(train_state.TrainState): function rsqrt_schedule (line 54) | def rsqrt_schedule( function create_learning_rate_schedule (line 77) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): function compute_weighted_cross_entropy (line 92) | def compute_weighted_cross_entropy( function compute_weighted_accuracy (line 134) | def compute_weighted_accuracy(logits, targets, weights=None): function compute_metrics (line 159) | def compute_metrics(logits, labels, weights, label_smoothing=0.0): function train_step (line 178) | def train_step( function eval_step (line 267) | def eval_step(params, batch, config, label_smoothing=0.0): function initialize_cache (line 276) | def initialize_cache(inputs, max_decode_len, config): function predict_step (line 287) | def predict_step( function pad_examples (line 344) | def pad_examples(x, desired_batch_size): function per_host_sum_pmap (line 350) | def per_host_sum_pmap(in_tree): function tohost (line 373) | def tohost(x): function evaluate (line 379) | def evaluate( function translate_and_calculate_bleu (line 401) | def translate_and_calculate_bleu( function preferred_dtype (line 455) | def preferred_dtype(config): function train_and_evaluate (line 465) | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): FILE: examples/wmt/train_test.py class TrainTest (line 32) | class TrainTest(absltest.TestCase): method setUp (line 35) | def setUp(self): method test_train_and_evaluate (line 41) | def test_train_and_evaluate(self): FILE: flax/configurations.py class Config (line 24) | class Config: method __init__ (line 36) | def __init__(self): method _add_option (line 39) | def _add_option(self, name, default): method _read (line 44) | def _read(self, name): method update (line 51) | def update(self, name: str, value: Any, /) -> None: method update (line 55) | def update(self, holder: 'FlagHolder[_T]', value: _T, /) -> None: method update (line 58) | def update(self, name_or_holder, value, /): method __repr__ (line 73) | def __repr__(self): method temp_flip_flag (line 78) | def temp_flip_flag(self, var_name: str, var_value: bool): class FlagHolder (line 98) | class FlagHolder(Generic[_T]): method __init__ (line 99) | def __init__(self, name, help): method __bool__ (line 104) | def __bool__(self) -> NoReturn: method value (line 111) | def value(self) -> _T: function bool_flag (line 115) | def bool_flag(name: str, *, default: bool, help: str) -> FlagHolder[bool]: function int_flag (line 147) | def int_flag(name: str, *, default: int | None, help: str) -> FlagHolder... function static_bool_env (line 179) | def static_bool_env(varname: str, default: bool) -> bool: function static_int_env (line 206) | def static_int_env(varname: str, default: int | None) -> int | None: FILE: flax/core/axes_scan.py class _Broadcast (line 31) | class _Broadcast: function build_shaped_array (line 38) | def build_shaped_array(x, batch_dim: bool = False) -> core.ShapedArray: function scan (line 60) | def scan( FILE: flax/core/frozen_dict.py class FrozenKeysView (line 27) | class FrozenKeysView(collections.abc.KeysView): method __repr__ (line 30) | def __repr__(self): class FrozenValuesView (line 34) | class FrozenValuesView(collections.abc.ValuesView): method __repr__ (line 37) | def __repr__(self): function _indent (line 45) | def _indent(x, num_spaces): class FrozenDict (line 54) | class FrozenDict(Mapping[K, V]): method __init__ (line 59) | def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # py... method __getitem__ (line 69) | def __getitem__(self, key): method __setitem__ (line 75) | def __setitem__(self, key, value): method __contains__ (line 78) | def __contains__(self, key): method __iter__ (line 81) | def __iter__(self): method __len__ (line 84) | def __len__(self): method __repr__ (line 87) | def __repr__(self): method __reduce__ (line 90) | def __reduce__(self): method get (line 93) | def get(self, key, default=None): method pretty_repr (line 99) | def pretty_repr(self, num_spaces=4): method __hash__ (line 115) | def __hash__(self): method copy (line 123) | def copy( method keys (line 129) | def keys(self): method values (line 132) | def values(self): method items (line 135) | def items(self): method pop (line 139) | def pop(self, key: K) -> tuple['FrozenDict[K, V]', V]: method unfreeze (line 159) | def unfreeze(self) -> dict[K, V]: method tree_flatten_with_keys (line 167) | def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]: method tree_unflatten (line 179) | def tree_unflatten(cls, keys, values): function _prepare_freeze (line 185) | def _prepare_freeze(xs: Any) -> Any: function freeze (line 198) | def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]: function unfreeze (line 211) | def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]: function copy (line 237) | def copy( function pop (line 269) | def pop( function pretty_repr (line 300) | def pretty_repr(x: Any, num_spaces: int = 4) -> str: function _frozen_dict_state_dict (line 331) | def _frozen_dict_state_dict(xs): function _restore_frozen_dict (line 341) | def _restore_frozen_dict(xs, states): FILE: flax/core/lift.py class TransformContext (line 60) | class TransformContext(Generic[A], threading.local): method push (line 66) | def push(self, a: A): method get (line 73) | def get(self) -> A: function tree_map_rngs (line 77) | def tree_map_rngs(fn, tree): function _dedup_scopes (line 87) | def _dedup_scopes(scopes): function _dup_scopes (line 109) | def _dup_scopes(orig_scopes, scopes, paths): function _transpose (line 121) | def _transpose(xs): function _partial_pack (line 125) | def _partial_pack( function pack (line 281) | def pack( function map_variables (line 340) | def map_variables( function swap_collection (line 409) | def swap_collection(fn: Callable[..., Any], col_a: str, col_b: str): function _split_in_out_axes (line 421) | def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]): function _bwd_wrapper (line 428) | def _bwd_wrapper(treedef, bwd_fn, tangent): function vjp (line 434) | def vjp( function value_and_grad (line 530) | def value_and_grad( function jvp (line 619) | def jvp( function vmap (line 710) | def vmap( function scan (line 872) | def scan( function while_loop (line 1072) | def while_loop( function cond (line 1168) | def cond( function switch (line 1232) | def switch( function custom_vjp (line 1318) | def custom_vjp( function checkpoint (line 1425) | def checkpoint( function _hashable_filter (line 1510) | def _hashable_filter(x): class CountsHolder (line 1523) | class CountsHolder: method __init__ (line 1525) | def __init__(self, flat_d): method make (line 1529) | def make(cls, d): method sub (line 1534) | def sub(self, other): method add (line 1542) | def add(self, other): method unflat (line 1550) | def unflat(self): function set_from_dict (line 1554) | def set_from_dict(original, updates): class _SideEffectCache (line 1565) | class _SideEffectCache(threading.local): method __init__ (line 1567) | def __init__(self): function _restore_rng_counters (line 1574) | def _restore_rng_counters(scopes, fingerprint, capture_old_counts): function jit (line 1598) | def jit( function remat_scan (line 1716) | def remat_scan( function _unzip2 (line 1793) | def _unzip2(xs): function _broadcast_prefix_tree (line 1798) | def _broadcast_prefix_tree(prefix_tree: Any, full_tree: Any) -> list[Any]: function fold_rngs (line 1810) | def fold_rngs( FILE: flax/core/meta.py class AxisMetadata (line 39) | class AxisMetadata(Generic[A], metaclass=abc.ABCMeta): method unbox (line 58) | def unbox(self) -> A: method replace_boxed (line 75) | def replace_boxed(self, val: B) -> 'AxisMetadata[B]': method add_axis (line 88) | def add_axis( method remove_axis (line 110) | def remove_axis( function is_axis_metadata (line 132) | def is_axis_metadata(val: Any) -> bool: function map_axis_meta (line 137) | def map_axis_meta(fn: Callable[[AxisMetadata[Any]], Any], tree: Any) -> ... function add_axis (line 149) | def add_axis(tree: Any, index: int, params: dict[Any, Any]) -> Any: function remove_axis (line 154) | def remove_axis(tree: Any, index: int, params: dict[Any, Any]) -> Any: function unbox (line 159) | def unbox(tree: Any) -> Any: function replace_boxed (line 164) | def replace_boxed(tree: Any, updates: Any) -> Any: function get_global_mesh (line 181) | def get_global_mesh() -> jax.sharding.AbstractMesh | jax.sharding.Mesh |... function global_mesh_defined (line 188) | def global_mesh_defined() -> bool: class Partitioned (line 194) | class Partitioned(struct.PyTreeNode, AxisMetadata[A]): method unbox (line 256) | def unbox(self, apply_constraint=True) -> A: method replace_boxed (line 267) | def replace_boxed(self, val: B) -> 'Partitioned[B]': method _get_partition_name (line 270) | def _get_partition_name(self, params: dict[Any, Any]) -> str: method add_axis (line 275) | def add_axis(self, index: int, params: dict[Any, Any]) -> 'Partitioned... method remove_axis (line 283) | def remove_axis(self, index: int, params: dict[Any, Any]) -> 'Partitio... method get_partition_spec (line 289) | def get_partition_spec(self) -> jax.sharding.PartitionSpec: method get_sharding (line 293) | def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding: method to_nnx_metadata (line 297) | def to_nnx_metadata(self) -> dict[str, Any]: method from_nnx_metadata (line 304) | def from_nnx_metadata(cls, metadata: dict[str, Any]): function with_partitioning (line 311) | def with_partitioning( function _get_leaf_pspec (line 342) | def _get_leaf_pspec(x: Any) -> jax.sharding.PartitionSpec | None: function get_partition_spec (line 352) | def get_partition_spec(tree: Any) -> Any: function get_sharding (line 359) | def get_sharding(tree: Any, mesh: jax.sharding.Mesh) -> Any: FILE: flax/core/nn/attention.py function dot_product_attention (line 34) | def dot_product_attention( function _invert_perm (line 156) | def _invert_perm(perm): class CacheEntry (line 163) | class CacheEntry(struct.PyTreeNode): function multi_head_dot_product_attention (line 169) | def multi_head_dot_product_attention( function make_padding_mask (line 420) | def make_padding_mask( function _make_causal_mask (line 476) | def _make_causal_mask(key, attention_axis=None, self_mask=False): FILE: flax/core/nn/linear.py function _normalize_axes (line 30) | def _normalize_axes(axes, ndim): function dense_general (line 35) | def dense_general( function dense (line 134) | def dense( function _conv_dimension_numbers (line 174) | def _conv_dimension_numbers(input_shape): function conv (line 183) | def conv( function conv_transpose (line 261) | def conv_transpose( class Embedding (line 331) | class Embedding: method lookup (line 334) | def lookup(self, indices): method attend (line 348) | def attend(self, query): function embedding (line 364) | def embedding( FILE: flax/core/nn/normalization.py function _absolute_dims (line 24) | def _absolute_dims(ndim, dims): function batch_norm (line 28) | def batch_norm( function layer_norm (line 86) | def layer_norm( function group_norm (line 129) | def group_norm( FILE: flax/core/nn/stochastic.py function dropout (line 21) | def dropout(scope, inputs, rate, deterministic=False, rng=None): FILE: flax/core/partial_eval.py function _maybe_unknown (line 26) | def _maybe_unknown(x: Any) -> pe.PartialVal: function lazy_init (line 33) | def lazy_init(fn): FILE: flax/core/scope.py class DenyList (line 70) | class DenyList: method __lt__ (line 82) | def __lt__(self, other): method __gt__ (line 89) | def __gt__(self, other): class LazyRng (line 101) | class LazyRng(struct.PyTreeNode): method as_jax_rng (line 107) | def as_jax_rng(self) -> PRNGKey: method create (line 111) | def create( method clear_suffix (line 119) | def clear_suffix(self): function _fold_in_static (line 124) | def _fold_in_static( function is_filter_empty (line 157) | def is_filter_empty(filter_like: Filter) -> bool: function in_filter (line 181) | def in_filter(filter_like: Filter, col: str) -> bool: function filter_to_set (line 207) | def filter_to_set(x: Filter) -> set[str]: function union_filters (line 226) | def union_filters(a: Filter, b: Filter) -> Filter: function subtract_filters (line 251) | def subtract_filters(a: Filter, b: Filter) -> Filter: function intersect_filters (line 276) | def intersect_filters(a: Filter, b: Filter) -> Filter: function group_collections (line 302) | def group_collections( class Variable (line 335) | class Variable(Generic[T]): method __init__ (line 343) | def __init__(self, scope: 'Scope', collection: str, name: str, unbox: ... method value (line 359) | def value(self) -> T: method value (line 365) | def value(self, value: T): method is_mutable (line 379) | def is_mutable(self) -> bool: class _ChildRNGSentinel (line 384) | class _ChildRNGSentinel: class _DefaultSentinel (line 392) | class _DefaultSentinel: function _put_variable (line 402) | def _put_variable(target, key, val): class Scope (line 414) | class Scope: method __init__ (line 428) | def __init__( method __eq__ (line 470) | def __eq__(self, other: Any) -> bool: method __hash__ (line 484) | def __hash__(self) -> int: method root (line 489) | def root(self) -> 'Scope': method path_text (line 493) | def path_text(self) -> str: method invalid (line 498) | def invalid(self) -> bool: method _check_valid (line 502) | def _check_valid(self): method temporary (line 507) | def temporary(self): method invalidate (line 514) | def invalidate(self): method mutable_variables (line 518) | def mutable_variables(self) -> VariableDict | dict[str, Any]: method variables (line 528) | def variables(self) -> VariableDict | dict[str, Any]: method _validate_trace_level (line 535) | def _validate_trace_level(self): method rewound (line 538) | def rewound(self, rewind_rngs: bool = False) -> 'Scope': method name_reserved (line 563) | def name_reserved(self, name: str, col: str | None = None) -> bool: method reserve (line 581) | def reserve(self, name: str, col: str | None = None): method default_name (line 598) | def default_name(self, prefix: str) -> str: method push (line 614) | def push( method child (line 654) | def child( method is_mutable_collection (line 695) | def is_mutable_collection(self, col: str) -> bool: method is_collection_empty (line 699) | def is_collection_empty(self, col: str) -> bool: method _mutable_collection (line 705) | def _mutable_collection(self, col: str) -> MutableCollection: method _collection (line 733) | def _collection(self, col: str) -> Collection: method has_rng (line 746) | def has_rng(self, name: str) -> bool: method make_rng (line 750) | def make_rng(self, name: str = 'params') -> PRNGKey: method get_variable (line 762) | def get_variable(self, col: str, name: str, default: Any = None) -> Any: method has_variable (line 781) | def has_variable(self, col: str, name: str) -> bool: method put_variable (line 791) | def put_variable(self, col: str, name: str, value: Any): method variable (line 808) | def variable( method variable (line 818) | def variable( method variable (line 830) | def variable( method variable (line 842) | def variable( method variable (line 853) | def variable( method param (line 893) | def param( method param (line 899) | def param( method param (line 910) | def param( method param (line 921) | def param( method param (line 931) | def param( method _populate_collections (line 989) | def _populate_collections(self): method has_flag (line 994) | def has_flag(self, key) -> bool: method get_flag (line 997) | def get_flag(self, key, default=no_flag) -> Any: function _unfreeze_variables (line 1003) | def _unfreeze_variables(variables, mutable): function bind (line 1013) | def bind( function apply (line 1050) | def apply( function init (line 1103) | def init( function lazy_init (line 1137) | def lazy_init( function _is_valid_collection (line 1173) | def _is_valid_collection(col: VariableDict): function _is_valid_variables (line 1183) | def _is_valid_variables(variables: VariableDict) -> bool: function _is_valid_rng (line 1200) | def _is_valid_rng(rng: Array): function _is_valid_rngs (line 1223) | def _is_valid_rngs(rngs: PRNGKey | RNGSequences): FILE: flax/core/spmd.py function get_pspec (line 28) | def get_pspec(sharding, sharding_rules = None) -> PartitionSpec: function map_sharding (line 32) | def map_sharding(f, sharding): function get_mesh (line 40) | def get_mesh(sharding): function apply_rules (line 48) | def apply_rules(sharding, sharding_rules): function _apply_sharding (line 57) | def _apply_sharding(value, sharding, mesh): function shard_value (line 69) | def shard_value(value, out_sharding, sharding_rules, mesh): class _AxisRules (line 104) | class _AxisRules(threading.local): function set_logical_axis_rules (line 114) | def set_logical_axis_rules(rules: LogicalRules): function get_logical_axis_rules (line 119) | def get_logical_axis_rules() -> LogicalRules: function logical_axis_rules (line 125) | def logical_axis_rules(rules: LogicalRules): function composite_rules (line 135) | def composite_rules(rule1, rule2): function from_sharding_rules (line 153) | def from_sharding_rules( FILE: flax/core/tracers.py function current_trace (line 21) | def current_trace(): function check_trace_level (line 32) | def check_trace_level(base_level): FILE: flax/cursor.py class Indexable (line 34) | class Indexable(Protocol): method __getitem__ (line 35) | def __getitem__(self, key) -> Any: class AccessType (line 39) | class AccessType(enum.Enum): class ParentKey (line 45) | class ParentKey(Generic[A]): function is_named_tuple (line 51) | def is_named_tuple(obj): function _traverse_tree (line 60) | def _traverse_tree(path, obj, *, update_fn=None, cond_fn=None): class Cursor (line 120) | class Cursor(Generic[A]): method __init__ (line 125) | def __init__(self, obj: A, parent_key: ParentKey[A] | None): method _root (line 133) | def _root(self) -> 'Cursor[A]': method _path (line 140) | def _path(self) -> str: method __getitem__ (line 152) | def __getitem__(self, key) -> 'Cursor[A]': method __getattr__ (line 169) | def __getattr__(self, name) -> 'Cursor[A]': method __setitem__ (line 182) | def __setitem__(self, key, value): method __setattr__ (line 187) | def __setattr__(self, name, value): method set (line 190) | def set(self, value) -> A: method build (line 224) | def build(self) -> A: method apply_update (line 284) | def apply_update( method find (line 382) | def find(self, cond_fn: Callable[[str, Any], bool]) -> 'Cursor[A]': method find_all (line 473) | def find_all( method __str__ (line 545) | def __str__(self): method __repr__ (line 548) | def __repr__(self): method _pretty_repr (line 551) | def _pretty_repr(self, indent=2, _prefix_indent=0): method __len__ (line 579) | def __len__(self): method __iter__ (line 582) | def __iter__(self): method __reversed__ (line 591) | def __reversed__(self): method __add__ (line 600) | def __add__(self, other): method __sub__ (line 603) | def __sub__(self, other): method __mul__ (line 606) | def __mul__(self, other): method __matmul__ (line 609) | def __matmul__(self, other): method __truediv__ (line 612) | def __truediv__(self, other): method __floordiv__ (line 615) | def __floordiv__(self, other): method __mod__ (line 618) | def __mod__(self, other): method __divmod__ (line 621) | def __divmod__(self, other): method __pow__ (line 624) | def __pow__(self, other): method __lshift__ (line 627) | def __lshift__(self, other): method __rshift__ (line 630) | def __rshift__(self, other): method __and__ (line 633) | def __and__(self, other): method __xor__ (line 636) | def __xor__(self, other): method __or__ (line 639) | def __or__(self, other): method __radd__ (line 642) | def __radd__(self, other): method __rsub__ (line 645) | def __rsub__(self, other): method __rmul__ (line 648) | def __rmul__(self, other): method __rmatmul__ (line 651) | def __rmatmul__(self, other): method __rtruediv__ (line 654) | def __rtruediv__(self, other): method __rfloordiv__ (line 657) | def __rfloordiv__(self, other): method __rmod__ (line 660) | def __rmod__(self, other): method __rdivmod__ (line 663) | def __rdivmod__(self, other): method __rpow__ (line 666) | def __rpow__(self, other): method __rlshift__ (line 669) | def __rlshift__(self, other): method __rrshift__ (line 672) | def __rrshift__(self, other): method __rand__ (line 675) | def __rand__(self, other): method __rxor__ (line 678) | def __rxor__(self, other): method __ror__ (line 681) | def __ror__(self, other): method __neg__ (line 684) | def __neg__(self): method __pos__ (line 687) | def __pos__(self): method __abs__ (line 690) | def __abs__(self): method __invert__ (line 693) | def __invert__(self): method __round__ (line 696) | def __round__(self, ndigits=None): method __lt__ (line 699) | def __lt__(self, other): method __le__ (line 702) | def __le__(self, other): method __eq__ (line 705) | def __eq__(self, other): method __ne__ (line 708) | def __ne__(self, other): method __gt__ (line 711) | def __gt__(self, other): method __ge__ (line 714) | def __ge__(self, other): function cursor (line 718) | def cursor(obj: A) -> Cursor[A]: FILE: flax/errors.py class FlaxError (line 52) | class FlaxError(Exception): method __init__ (line 53) | def __init__(self, message): method __reduce__ (line 63) | def __reduce__(self): class TraceContextError (line 72) | class TraceContextError(FlaxError): class LazyInitError (line 81) | class LazyInitError(FlaxError): method __init__ (line 101) | def __init__(self, partial_val): class InvalidRngError (line 113) | class InvalidRngError(FlaxError): method __init__ (line 167) | def __init__(self, msg): class ApplyScopeInvalidVariablesTypeError (line 174) | class ApplyScopeInvalidVariablesTypeError(FlaxError): method __init__ (line 181) | def __init__(self): class ApplyScopeInvalidVariablesStructureError (line 189) | class ApplyScopeInvalidVariablesStructureError(FlaxError): method __init__ (line 196) | def __init__(self, variables): class ScopeParamNotFoundError (line 205) | class ScopeParamNotFoundError(FlaxError): method __init__ (line 228) | def __init__(self, param_name, scope_path): class ScopeCollectionNotFound (line 235) | class ScopeCollectionNotFound(FlaxError): method __init__ (line 249) | def __init__(self, col_name, var_name, scope_path): class ScopeParamShapeError (line 256) | class ScopeParamShapeError(FlaxError): method __init__ (line 284) | def __init__(self, param_name, scope_path, value_shape, init_shape): class ScopeVariableNotFoundError (line 292) | class ScopeVariableNotFoundError(FlaxError): method __init__ (line 300) | def __init__(self, name, col, scope_path): class InvalidFilterError (line 307) | class InvalidFilterError(FlaxError): method __init__ (line 310) | def __init__(self, filter_like): class InvalidScopeError (line 314) | class InvalidScopeError(FlaxError): method __init__ (line 323) | def __init__(self, scope_name): class ModifyScopeVariableError (line 327) | class ModifyScopeVariableError(FlaxError): method __init__ (line 347) | def __init__(self, col, variable_name, scope_path): class ImmutableVariableError (line 354) | class ImmutableVariableError(FlaxError): method __init__ (line 366) | def __init__(self, message): class JaxTransformError (line 370) | class JaxTransformError(FlaxError): method __init__ (line 379) | def __init__(self): class PartitioningUnspecifiedError (line 388) | class PartitioningUnspecifiedError(FlaxError): method __init__ (line 395) | def __init__(self, target): class NameInUseError (line 407) | class NameInUseError(FlaxError): method __init__ (line 456) | def __init__(self, key_type, value, module_name): class AssignSubModuleError (line 464) | class AssignSubModuleError(FlaxError): method __init__ (line 502) | def __init__(self, cls): class SetAttributeInModuleSetupError (line 509) | class SetAttributeInModuleSetupError(FlaxError): method __init__ (line 541) | def __init__(self): class SetAttributeFrozenModuleError (line 545) | class SetAttributeFrozenModuleError(FlaxError): method __init__ (line 576) | def __init__(self, module_cls, attr_name, attr_val): class MultipleMethodsCompactError (line 584) | class MultipleMethodsCompactError(FlaxError): method __init__ (line 598) | def __init__(self): class ReservedModuleAttributeError (line 602) | class ReservedModuleAttributeError(FlaxError): method __init__ (line 611) | def __init__(self, annotations): class ApplyModuleInvalidMethodError (line 617) | class ApplyModuleInvalidMethodError(FlaxError): method __init__ (line 628) | def __init__(self, method): class CallCompactUnboundModuleError (line 634) | class CallCompactUnboundModuleError(FlaxError): method __init__ (line 656) | def __init__(self): class CallSetupUnboundModuleError (line 660) | class CallSetupUnboundModuleError(FlaxError): method __init__ (line 689) | def __init__(self): class CallUnbindOnUnboundModuleError (line 693) | class CallUnbindOnUnboundModuleError(FlaxError): method __init__ (line 716) | def __init__(self): class CallShareScopeOnUnboundModuleError (line 719) | class CallShareScopeOnUnboundModuleError(FlaxError): method __init__ (line 735) | def __init__(self): class InvalidInstanceModuleError (line 738) | class InvalidInstanceModuleError(FlaxError): method __init__ (line 756) | def __init__(self): class IncorrectPostInitOverrideError (line 763) | class IncorrectPostInitOverrideError(FlaxError): method __init__ (line 784) | def __init__(self): class DescriptorAttributeError (line 790) | class DescriptorAttributeError(FlaxError): method __init__ (line 806) | def __init__(self): class InvalidCheckpointError (line 813) | class InvalidCheckpointError(FlaxError): method __init__ (line 822) | def __init__(self, path, step): class MPACheckpointingRequiredError (line 829) | class MPACheckpointingRequiredError(FlaxError): method __init__ (line 840) | def __init__(self, path, step): class MPARestoreTargetRequiredError (line 848) | class MPARestoreTargetRequiredError(FlaxError): method __init__ (line 858) | def __init__(self, path, step, key=None): class MPARestoreDataCorruptedError (line 870) | class MPARestoreDataCorruptedError(FlaxError): method __init__ (line 876) | def __init__(self, step, path): class TransformedMethodReturnValueError (line 889) | class TransformedMethodReturnValueError(FlaxError): method __init__ (line 892) | def __init__(self, name): class TransformTargetError (line 898) | class TransformTargetError(FlaxError): method __init__ (line 923) | def __init__(self, target): class AlreadyExistsError (line 936) | class AlreadyExistsError(FlaxError): method __init__ (line 943) | def __init__(self, path): class CursorFindError (line 952) | class CursorFindError(FlaxError): method __init__ (line 959) | def __init__(self, cursor=None, cursor2=None): class TraverseTreeError (line 970) | class TraverseTreeError(FlaxError): method __init__ (line 984) | def __init__(self, update_fn, cond_fn): FILE: flax/ids.py class UUIDManager (line 20) | class UUIDManager: method __init__ (line 32) | def __init__(self): method __call__ (line 36) | def __call__(self): class FlaxId (line 45) | class FlaxId: method __init__ (line 48) | def __init__(self, rawid): method __eq__ (line 51) | def __eq__(self, other): method __hash__ (line 54) | def __hash__(self): method __repr__ (line 57) | def __repr__(self): method __deepcopy__ (line 60) | def __deepcopy__(self, memo): method __copy__ (line 64) | def __copy__(self): FILE: flax/io.py class BackendMode (line 33) | class BackendMode(Enum): function override_mode (line 69) | def override_mode(override: BackendMode): function set_mode (line 85) | def set_mode(override: BackendMode): function GFile (line 97) | def GFile(name, mode): # pylint: disable=invalid-name function listdir (line 109) | def listdir(path): function isdir (line 118) | def isdir(path): function copy (line 127) | def copy(src, dst, overwrite=False): function rename (line 139) | def rename(src, dst, overwrite=False): function exists (line 150) | def exists(path): function makedirs (line 159) | def makedirs(path): function glob (line 168) | def glob(pattern): function remove (line 179) | def remove(path): function rmtree (line 189) | def rmtree(path): function getsize (line 199) | def getsize(path): FILE: flax/jax_utils.py function _pmap_device_order (line 30) | def _pmap_device_order(): function replicate (line 34) | def replicate(tree, devices=None): function unreplicate (line 48) | def unreplicate(tree): function pmean (line 68) | def pmean(xs, axis_name): function partial_eval_by_shape (line 73) | def partial_eval_by_shape(fn, input_spec, *args, **kwargs): function _parse_spec (line 114) | def _parse_spec(spec): function prefetch_to_device (line 123) | def prefetch_to_device(iterator, size, devices=None): function _scan_nd (line 168) | def _scan_nd(body_fn, init, xs, n=1, unroll=(1,)): function _invert_perm (line 190) | def _invert_perm(perm): function scan_in_dim (line 197) | def scan_in_dim(body_fn, init, xs, axis=(0,), unroll=(1,), keepdims=False): function pad_shard_unpad (line 256) | def pad_shard_unpad( FILE: flax/linen/activation.py class PReLU (line 59) | class PReLU(Module): method __call__ (line 85) | def __call__(self, inputs: Array) -> Array: FILE: flax/linen/attention.py function dot_product_attention_weights (line 47) | def dot_product_attention_weights( function dot_product_attention (line 166) | def dot_product_attention( class MultiHeadDotProductAttention (line 298) | class MultiHeadDotProductAttention(Module): method __call__ (line 409) | def __call__( method __call__ (line 423) | def __call__( method __call__ (line 436) | def __call__( class MultiHeadAttention (line 692) | class MultiHeadAttention(MultiHeadDotProductAttention): class SelfAttention (line 775) | class SelfAttention(MultiHeadDotProductAttention): method __call__ (line 787) | def __call__( # type: ignore function make_attention_mask (line 830) | def make_attention_mask( function make_causal_mask (line 862) | def make_causal_mask( function combine_masks (line 890) | def combine_masks( FILE: flax/linen/batch_apply.py function ndim_at_least (line 21) | def ndim_at_least(x, num_dims): function arbitrary_mergeable_leaf (line 26) | def arbitrary_mergeable_leaf(min_num_dims, args, kwargs): function merge_leading_dims (line 36) | def merge_leading_dims(x, num_dims): function split_leading_dim (line 45) | def split_leading_dim(x, to_dim): class BatchApply (line 49) | class BatchApply: method __init__ (line 86) | def __init__(self, f, num_dims=2): method __call__ (line 96) | def __call__(self, *args, **kwargs): FILE: flax/linen/combinators.py class Sequential (line 23) | class Sequential(Module): method __post_init__ (line 94) | def __post_init__(self): method __call__ (line 102) | def __call__(self, *args, **kwargs): FILE: flax/linen/dtypes.py function canonicalize_dtype (line 22) | def canonicalize_dtype( function promote_dtype (line 54) | def promote_dtype(*args, dtype=None, inexact=True) -> list[Any]: FILE: flax/linen/experimental/layers_with_named_axes.py class Dense (line 45) | class Dense(nn.Module): method __call__ (line 76) | def __call__(self, inputs: Array) -> Array: class Embed (line 120) | class Embed(nn.Module): method setup (line 147) | def setup(self): method __call__ (line 156) | def __call__(self, inputs: Array) -> Array: method attend (line 179) | def attend(self, query: Array) -> Array: function _canonicalize_axes (line 196) | def _canonicalize_axes(rank: int, axes: Axes) -> Sequence[int]: function _abs_sq (line 203) | def _abs_sq(x): function _compute_stats (line 211) | def _compute_stats(x: Array, axes: Axes): function _normalize (line 234) | def _normalize( class LayerNorm (line 282) | class LayerNorm(nn.Module): method __call__ (line 317) | def __call__(self, x): FILE: flax/linen/fp8_ops.py class Fp8MetaTyRules (line 43) | class Fp8MetaTyRules: method physical_element_aval (line 46) | def physical_element_aval(dtype) -> core.ShapedArray: method replicate_trailing_dims (line 51) | def replicate_trailing_dims(ctx, val, aval): method logical_sharding (line 56) | def logical_sharding(aval, phys_sharding): method physical_sharding (line 60) | def physical_sharding(aval, sharding): method convert_from (line 65) | def convert_from(fp8_meta_dtype, other_dtype) -> bool: method convert_to (line 69) | def convert_to(other_dtype, fp8_meta_dtype) -> bool: method add (line 74) | def add(dt, x, y): method zero (line 80) | def zero(dt): method tangent_dtype (line 86) | def tangent_dtype(dtype): method full (line 90) | def full(shape, fill_value, dtype): method global_sharded_result_handler (line 96) | def global_sharded_result_handler(aval, out_sharding, committed): class fp8_meta_dtype (line 108) | class fp8_meta_dtype(dtypes.extended): pass class fp8_meta_dtype_wrapper (line 112) | class fp8_meta_dtype_wrapper(dtypes.ExtendedDType): method __repr__ (line 117) | def __repr__(self) -> str: function get_fp8_max (line 125) | def get_fp8_max(fp8_dtype, out_dtype): function quantize (line 130) | def quantize(x, q_dtype, scale, compute_dtype): function dequantize (line 139) | def dequantize(x, dq_dtype, scale): function qdq (line 142) | def qdq(x, q_dtype, scale, compute_dtype): function compute_scale (line 147) | def compute_scale(amax, scale, fp8_max, margin=0): function compute_amax_history (line 161) | def compute_amax_history(x, amax_history): function update_fp8_meta (line 167) | def update_fp8_meta( function quantize_dequantize_update (line 188) | def quantize_dequantize_update(x, q_dtype, scale, amax_history, compute_... function _fm32_to_float32 (line 193) | def _fm32_to_float32(value): function dot_general_transpose_lhs (line 198) | def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, function dot_general_transpose_rhs (line 232) | def dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, function in_qdq (line 243) | def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history): function in_qdq_fwd (line 250) | def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history): function in_qdq_bwd (line 257) | def in_qdq_bwd(compute_dtype, q_dtype, res, g): function out_qdq (line 267) | def out_qdq(compute_dtype, q_dtype, out, scale, amax_history): function out_qdq_fwd (line 271) | def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history): function out_qdq_bwd (line 275) | def out_qdq_bwd(compute_dtype, q_dtype, res, g): function in_q (line 287) | def in_q(compute_dtype, q_dtype, inp, scale, amax_history): function in_q_fwd (line 292) | def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history): function in_q_bwd (line 297) | def in_q_bwd(compute_dtype, q_dtype, res, _): function out_dq (line 306) | def out_dq(dq_type, lhs_scale, rhs_scale, out): function out_dq_fwd (line 314) | def out_dq_fwd(dq_type, lhs_scale, rhs_scale, out): function out_dq_bwd (line 317) | def out_dq_bwd(dq_type, _, g): function quantized_dot (line 323) | def quantized_dot( function quantized_dot_fwd (line 343) | def quantized_dot_fwd( function quantized_dot_bwd (line 374) | def quantized_dot_bwd( function fp8_scaled_dot_general (line 450) | def fp8_scaled_dot_general( function dot_general_with_precision (line 497) | def dot_general_with_precision( function dot_general_with_precision_jvp (line 512) | def dot_general_with_precision_jvp( function _parse_dot_inputs (line 529) | def _parse_dot_inputs(*args, **kwargs): class Fp8DotGeneralBase (line 542) | class Fp8DotGeneralBase(module.Module): method setup (line 547) | def setup(self) -> None: class Fp8DotGeneralOp (line 582) | class Fp8DotGeneralOp(Fp8DotGeneralBase): method __post_init__ (line 583) | def __post_init__(self): method __call__ (line 592) | def __call__(self, *args, **kwargs): class Fp8DirectDotGeneralOp (line 614) | class Fp8DirectDotGeneralOp(Fp8DotGeneralBase): method __call__ (line 615) | def __call__(self, *args, **kwargs): class NANOOFp8DotGeneralOp (line 637) | class NANOOFp8DotGeneralOp(Fp8DotGeneralOp): class Fp8Einsum (line 641) | class Fp8Einsum(Fp8DotGeneralBase): method __call__ (line 643) | def __call__(self, eqn, lhs: jnp.ndarray, rhs: jnp.ndarray, FILE: flax/linen/initializers.py function zeros_init (line 43) | def zeros_init() -> Initializer: function ones_init (line 56) | def ones_init() -> Initializer: FILE: flax/linen/kw_only_dataclasses.py class _KwOnlyType (line 70) | class _KwOnlyType: method __repr__ (line 73) | def __repr__(self): function field (line 80) | def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs): function dataclass (line 106) | def dataclass(cls=None, extra_fields=None, **kwargs): function _process_class (line 129) | def _process_class(cls: type[M], extra_fields=None, **kwargs): FILE: flax/linen/linear.py class PromoteDtypeFn (line 44) | class PromoteDtypeFn(Protocol): method __call__ (line 45) | def __call__( function _normalize_axes (line 52) | def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]: function _canonicalize_tuple (line 57) | def _canonicalize_tuple(x: Sequence[int] | int) -> tuple[int, ...]: class DenseGeneral (line 64) | class DenseGeneral(Module): method __call__ (line 118) | def __call__(self, inputs: Array) -> Array: class Dense (line 214) | class Dense(Module): method __call__ (line 254) | def __call__(self, inputs: Array) -> Array: class Einsum (line 298) | class Einsum(Module): method __call__ (line 343) | def __call__(self, inputs: Array, einsum_str: str | None = None) -> Ar... method _get_bias_shape (line 402) | def _get_bias_shape(self, einsum_str: str, lhs: Array, rhs: Array): function _conv_dimension_numbers (line 434) | def _conv_dimension_numbers(input_shape): function canonicalize_padding (line 443) | def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: class _Conv (line 467) | class _Conv(Module): method shared_weights (line 529) | def shared_weights(self) -> bool: # type: ignore method __call__ (line 541) | def __call__(self, inputs: Array) -> Array: class Conv (line 734) | class Conv(_Conv): method shared_weights (line 801) | def shared_weights(self) -> bool: class ConvLocal (line 805) | class ConvLocal(_Conv): method shared_weights (line 872) | def shared_weights(self) -> bool: class ConvTranspose (line 876) | class ConvTranspose(Module): method __call__ (line 955) | def __call__(self, inputs: Array) -> Array: class Embed (line 1110) | class Embed(Module): method setup (line 1165) | def setup(self): method __call__ (line 1173) | def __call__(self, inputs: Array) -> Array: method attend (line 1196) | def attend(self, query: Array) -> Array: FILE: flax/linen/module.py function _get_fn_name (line 86) | def _get_fn_name(fn): function _indent (line 92) | def _indent(x: str, num_spaces: int): function _attr_repr (line 100) | def _attr_repr(value: Any): function _module_repr (line 111) | def _module_repr(module: 'Module', num_spaces: int = 4): class _CallInfo (line 153) | class _CallInfo: class _CallInfoContext (line 166) | class _CallInfoContext(threading.local): method get_call_index (line 170) | def get_call_index(self) -> int: function _tabulate_context (line 177) | def _tabulate_context(): class _DynamicContext (line 187) | class _DynamicContext(threading.local): method __init__ (line 193) | def __init__(self): class _Sentinel (line 205) | class _Sentinel: method __copy__ (line 206) | def __copy__(self): method __deepcopy__ (line 209) | def __deepcopy__(self, memo): method __reduce__ (line 213) | def __reduce__(self): function _get_unspecified_parent (line 217) | def _get_unspecified_parent(): function _derive_profiling_name (line 229) | def _derive_profiling_name(module, fn): function enable_named_call (line 236) | def enable_named_call(): function disable_named_call (line 251) | def disable_named_call(): function override_named_call (line 261) | def override_named_call(enable: bool = True): class InterceptorContext (line 282) | class InterceptorContext: class ThreadLocalStack (line 297) | class ThreadLocalStack(threading.local): method __init__ (line 300) | def __init__(self): method push (line 303) | def push(self, elem: Any) -> None: method pop (line 306) | def pop(self) -> Any: method __iter__ (line 309) | def __iter__(self) -> Iterator[Any]: method __len__ (line 312) | def __len__(self) -> int: method __repr__ (line 315) | def __repr__(self) -> str: function intercept_methods (line 327) | def intercept_methods(interceptor: Interceptor): function run_interceptors (line 395) | def run_interceptors( function _sorted_items (line 426) | def _sorted_items(x): function _get_suffix_value_pairs (line 431) | def _get_suffix_value_pairs( function _map_over_modules_in_tree (line 443) | def _map_over_modules_in_tree(fn, tree_or_leaf): function _freeze_attr (line 458) | def _freeze_attr(val: Any) -> Any: function compact (line 477) | def compact(fun: _CallableT) -> _CallableT: function nowrap (line 505) | def nowrap(fun: _CallableT) -> _CallableT: function compact_name_scope (line 548) | def compact_name_scope(fun: _CallableT) -> _CallableT: function _get_local_method_names (line 629) | def _get_local_method_names( function _get_local_descriptor_names (line 652) | def _get_local_descriptor_names( function wrap_method_once (line 677) | def wrap_method_once(fun: Callable[..., Any]) -> Callable[..., Any]: function wrap_descriptor_once (line 707) | def wrap_descriptor_once(descriptor) -> 'DescriptorWrapper': function _wrap_hash (line 723) | def _wrap_hash(hash_fn: Callable[..., Any]) -> Callable[..., Any]: function _get_unbound_fn (line 743) | def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., A... function _map_submodules (line 772) | def _map_submodules(fn: Callable[['Module'], Any], tree): class SetupState (line 778) | class SetupState(enum.IntEnum): class _ModuleInternalState (line 788) | class _ModuleInternalState: method reset (line 805) | def reset(self) -> None: method export (line 815) | def export(self) -> '_ModuleInternalState': method reimport (line 829) | def reimport(self, other: '_ModuleInternalState') -> None: class ParentDescriptor (line 861) | class ParentDescriptor: method __get__ (line 874) | def __get__(self, obj, objtype=None): method __set__ (line 881) | def __set__(self, obj, value): class Descriptor (line 886) | class Descriptor(tpe.Protocol): method __get__ (line 889) | def __get__(self, obj, objtype=None) -> Any: method __set__ (line 892) | def __set__(self, obj, value) -> None: method __delete__ (line 895) | def __delete__(self, obj) -> None: method __set_name__ (line 898) | def __set_name__(self, owner, name) -> None: class DescriptorWrapper (line 902) | class DescriptorWrapper: function create_descriptor_wrapper (line 906) | def create_descriptor_wrapper(descriptor: Descriptor): function module_field (line 956) | def module_field(*, kw_only: bool = False, default: Any | None = ...) ->... class ModuleBase (line 975) | class ModuleBase: class Module (line 983) | class Module(ModuleBase): method __init__ (line 1025) | def __init__(self, *args, **kwargs): method __call__ (line 1029) | def __call__(self, *args, **kwargs) -> Any: method __init_subclass__ (line 1034) | def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None: method _customized_dataclass_transform (line 1055) | def _customized_dataclass_transform(cls, kw_only: bool): method _find_compact_name_scope_methods (line 1127) | def _find_compact_name_scope_methods(cls): method _wrap_module_attributes (line 1138) | def _wrap_module_attributes(cls): method _call_wrapped_method (line 1170) | def _call_wrapped_method(self, fun, args, kwargs): method __setattr__ (line 1255) | def __setattr__(self, name: str, val: Any): method __getattr__ (line 1302) | def __getattr__(self, name: str) -> Any: method __dir__ (line 1319) | def __dir__(self) -> list[str]: method __post_init__ (line 1324) | def __post_init__(self) -> None: method __repr__ (line 1392) | def __repr__(self) -> str: method setup (line 1395) | def setup(self) -> None: method _register_submodules (line 1430) | def _register_submodules(self, name, val): method _try_setup (line 1481) | def _try_setup(self, shallow: bool = False) -> None: method _validate_setup (line 1515) | def _validate_setup(self) -> None: method _name_taken (line 1525) | def _name_taken( method _initialization_allowed (line 1537) | def _initialization_allowed(self): method path (line 1545) | def path(self): method clone (line 1577) | def clone( method copy (line 1653) | def copy( method variable (line 1677) | def variable( method variable (line 1687) | def variable( method variable (line 1699) | def variable( method variable (line 1711) | def variable( method variable (line 1722) | def variable( method param (line 1787) | def param( method param (line 1793) | def param( method param (line 1804) | def param( method param (line 1815) | def param( method param (line 1824) | def param( method has_variable (line 1885) | def has_variable(self, col: str, name: str) -> bool: method is_mutable_collection (line 1902) | def is_mutable_collection(self, col: str) -> bool: method has_rng (line 1908) | def has_rng(self, name: str) -> bool: method make_rng (line 1914) | def make_rng(self, name: str = 'params') -> PRNGKey: method is_initializing (line 1957) | def is_initializing(self) -> bool: method _module_checks (line 1971) | def _module_checks(self): method bind (line 1982) | def bind( method unbind (line 2043) | def unbind(self: M) -> tuple[M, VariableDict]: method apply (line 2092) | def apply( method init_with_output (line 2252) | def init_with_output( method init (line 2316) | def init( method lazy_init (line 2467) | def lazy_init( method variables (line 2518) | def variables(self) -> VariableDict: method get_variable (line 2524) | def get_variable(self, col: str, name: str, default: T | None = None) ... method put_variable (line 2541) | def put_variable(self, col: str, name: str, value: Any): method sow (line 2554) | def sow(self, col: str, name: str, value: Any) -> bool: method sow (line 2558) | def sow( method sow (line 2568) | def sow( method perturb (line 2655) | def perturb( method tabulate (line 2727) | def tabulate( method module_paths (line 2857) | def module_paths( function merge_param (line 2923) | def merge_param(name: str, a: T | None, b: T | None) -> T: function apply (line 2968) | def apply( function init_with_output (line 3038) | def init_with_output( function init (line 3109) | def init( class CompactNameScope (line 3176) | class CompactNameScope(Module): method __call__ (line 3181) | def __call__(self, *args, **kwargs) -> Any: method __call__ (line 3191) | def __call__(self, *args, **kwargs) -> Any: class CompactNameScope (line 3186) | class CompactNameScope: method __call__ (line 3181) | def __call__(self, *args, **kwargs) -> Any: method __call__ (line 3191) | def __call__(self, *args, **kwargs) -> Any: function share_scope (line 3195) | def share_scope(module: Module, other: Module, /): FILE: flax/linen/normalization.py function _canonicalize_axes (line 45) | def _canonicalize_axes(rank: int, axes: Axes) -> tuple[int, ...]: function _abs_sq (line 52) | def _abs_sq(x): function _compute_stats (line 60) | def _compute_stats( function _normalize (line 154) | def _normalize( function _l2_normalize (line 229) | def _l2_normalize(x, axis=None, eps=1e-12): class BatchNorm (line 247) | class BatchNorm(Module): method __call__ (line 324) | def __call__( class LayerNorm (line 424) | class LayerNorm(Module): method __call__ (line 501) | def __call__(self, x, *, mask: jax.Array | None = None): class RMSNorm (line 541) | class RMSNorm(Module): method __call__ (line 601) | def __call__(self, x, *, mask: jax.Array | None = None): class GroupNorm (line 642) | class GroupNorm(Module): method __call__ (line 727) | def __call__(self, x, *, mask: jax.Array | None = None): class InstanceNorm (line 822) | class InstanceNorm(Module): method __call__ (line 901) | def __call__(self, x, *, mask: jax.Array | None = None): class SpectralNorm (line 947) | class SpectralNorm(Module): method __call__ (line 1069) | def __call__(self, *args, update_stats: bool, **kwargs): method _spectral_normalize (line 1104) | def _spectral_normalize(self, path, vs, update_stats): class WeightNorm (line 1184) | class WeightNorm(Module): method __call__ (line 1312) | def __call__(self, *args, **kwargs): method _l2_normalize (line 1339) | def _l2_normalize(self, path, vs): FILE: flax/linen/partitioning.py class AxisMetadata (line 85) | class AxisMetadata: function _param_with_axes_sow_reduce_fn (line 91) | def _param_with_axes_sow_reduce_fn(x, y): function param_with_axes (line 123) | def param_with_axes( class PartitionedVariable (line 174) | class PartitionedVariable(flax.core.scope.Variable): method __init__ (line 184) | def __init__( method value (line 208) | def value(self): method value (line 216) | def value(self, value): function _core_variable_with_axes (line 223) | def _core_variable_with_axes( function variable_with_axes (line 245) | def variable_with_axes( function get_axis_names (line 305) | def get_axis_names(axes_metadata): function _tree_map_axes (line 338) | def _tree_map_axes(fn, tree): function _is_mutable (line 346) | def _is_mutable(axis_col: str) -> bool: function _add_axis_to_metadata (line 369) | def _add_axis_to_metadata(fn, axis_pos, axis_name, axis_col='params_axes'): function scan_with_axes (line 416) | def scan_with_axes( function vmap_with_axes (line 472) | def vmap_with_axes( function core_remat_static (line 526) | def core_remat_static( function remat (line 583) | def remat( FILE: flax/linen/pooling.py function pool (line 22) | def pool(inputs, init, reduce_fn, window_shape, strides, padding): function avg_pool (line 79) | def avg_pool( function max_pool (line 110) | def max_pool(inputs, window_shape, strides=None, padding='VALID'): function min_pool (line 128) | def min_pool(inputs, window_shape, strides=None, padding='VALID'): FILE: flax/linen/recurrent.py class RNNCellBase (line 57) | class RNNCellBase(Module): method initialize_carry (line 61) | def initialize_carry( method num_feature_axes (line 76) | def num_feature_axes(self) -> int: class LSTMCell (line 81) | class LSTMCell(RNNCellBase): method __call__ (line 135) | def __call__(self, carry, inputs): method initialize_carry (line 176) | def initialize_carry( method num_feature_axes (line 195) | def num_feature_axes(self) -> int: class DenseParams (line 199) | class DenseParams(Module): method __call__ (line 210) | def __call__(self, inputs: Array) -> tuple[Array, Array | None]: class OptimizedLSTMCell (line 224) | class OptimizedLSTMCell(RNNCellBase): method __call__ (line 282) | def __call__( method initialize_carry (line 365) | def initialize_carry( method num_feature_axes (line 385) | def num_feature_axes(self) -> int: class SimpleCell (line 389) | class SimpleCell(RNNCellBase): method __call__ (line 446) | def __call__(self, carry, inputs): method initialize_carry (line 484) | def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): method num_feature_axes (line 499) | def num_feature_axes(self) -> int: class GRUCell (line 503) | class GRUCell(RNNCellBase): method __call__ (line 555) | def __call__(self, carry, inputs): method initialize_carry (line 598) | def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): method num_feature_axes (line 613) | def num_feature_axes(self) -> int: class MGUCell (line 617) | class MGUCell(RNNCellBase): method __call__ (line 686) | def __call__(self, carry, inputs): method initialize_carry (line 733) | def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): method num_feature_axes (line 748) | def num_feature_axes(self) -> int: class ConvLSTMCell (line 752) | class ConvLSTMCell(RNNCellBase): method __call__ (line 814) | def __call__(self, carry, inputs): method initialize_carry (line 858) | def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): method num_feature_axes (line 878) | def num_feature_axes(self) -> int: class RNN (line 882) | class RNN(Module): method __call__ (line 1016) | def __call__( function _select_last_carry (line 1166) | def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A: function _expand_dims_like (line 1175) | def _expand_dims_like(x, target): function flip_sequences (line 1180) | def flip_sequences( function _concatenate (line 1241) | def _concatenate(a: Array, b: Array) -> Array: class RNNBase (line 1246) | class RNNBase(Protocol): method __call__ (line 1247) | def __call__( class Bidirectional (line 1262) | class Bidirectional(Module): method __call__ (line 1282) | def __call__( FILE: flax/linen/spmd.py class _UnassignedAxis (line 52) | class _UnassignedAxis: method __repr__ (line 55) | def __repr__(self): method __bool__ (line 58) | def __bool__(self): function _mesh_assignment_free (line 65) | def _mesh_assignment_free(new_assignment, existing_assignments): function _logical_to_mesh_axes (line 76) | def _logical_to_mesh_axes( function logical_to_mesh_axes (line 114) | def logical_to_mesh_axes( function logical_to_mesh (line 161) | def logical_to_mesh(tree: Any, rules: LogicalRules | None = None) -> Any: function logical_to_mesh_sharding (line 170) | def logical_to_mesh_sharding( class RulesFallback (line 183) | class RulesFallback(enum.Enum): function _with_sharding_constraint (line 191) | def _with_sharding_constraint( function _with_sharding_constraint_one_fallback (line 206) | def _with_sharding_constraint_one_fallback( function _is_axis_spec (line 231) | def _is_axis_spec(x): function _is_logical_spec (line 239) | def _is_logical_spec(x): function with_logical_constraint (line 245) | def with_logical_constraint( class LogicallyPartitioned (line 276) | class LogicallyPartitioned(meta.Partitioned): method __eq__ (line 283) | def __eq__(self, other): method unbox (line 290) | def unbox(self, apply_constraint=True) -> Any: method to_nnx_metadata (line 302) | def to_nnx_metadata(self) -> dict[str, Any]: method from_nnx_metadata (line 312) | def from_nnx_metadata(cls, metadata: dict[str, Any]): function with_logical_partitioning (line 320) | def with_logical_partitioning( FILE: flax/linen/stochastic.py class Dropout (line 26) | class Dropout(Module): method __call__ (line 69) | def __call__( FILE: flax/linen/summary.py class _ValueRepresentation (line 52) | class _ValueRepresentation(ABC): method render (line 56) | def render(self) -> str: class _ArrayRepresentation (line 61) | class _ArrayRepresentation(_ValueRepresentation): method from_array (line 66) | def from_array(cls, x: Array) -> '_ArrayRepresentation': method render_array (line 70) | def render_array(cls, x) -> str: method render (line 73) | def render(self): class _PartitionedArrayRepresentation (line 79) | class _PartitionedArrayRepresentation(_ValueRepresentation): method from_partitioned (line 84) | def from_partitioned( method render (line 91) | def render(self): class _ObjectRepresentation (line 96) | class _ObjectRepresentation(_ValueRepresentation): method render (line 99) | def render(self): class Row (line 104) | class Row: method __post_init__ (line 134) | def __post_init__(self): method size_and_bytes (line 140) | def size_and_bytes( class Table (line 153) | class Table(list[Row]): method __init__ (line 163) | def __init__( function tabulate (line 174) | def tabulate( function _get_flops (line 327) | def _get_flops(fn, *args, **kwargs): function _get_call_flops (line 336) | def _get_call_flops( function _get_module_table (line 425) | def _get_module_table( function _get_module_variables (line 491) | def _get_module_variables( function _get_path_variables (line 521) | def _get_path_variables( function _process_inputs (line 543) | def _process_inputs(args, kwargs) -> Any: function _render_table (line 559) | def _render_table( function _summary_tree_map (line 659) | def _summary_tree_map(f, tree, *rest): function _size_and_bytes_repr (line 663) | def _size_and_bytes_repr(size: int, num_bytes: int) -> str: function _size_and_bytes (line 670) | def _size_and_bytes(pytree: Any) -> tuple[int, int]: function _get_rich_repr (line 679) | def _get_rich_repr(obj, console_kwargs): function _as_yaml_str (line 686) | def _as_yaml_str(value) -> str: function _normalize_structure (line 702) | def _normalize_structure(obj): function _bytes_repr (line 723) | def _bytes_repr(num_bytes): function _get_value_representation (line 737) | def _get_value_representation(x: Any) -> _ValueRepresentation: function _from_value_representation (line 750) | def _from_value_representation(x: _ValueRepresentation) -> Any: function _represent_tree (line 765) | def _represent_tree(x): function _maybe_render (line 775) | def _maybe_render(x): FILE: flax/linen/transforms.py function clean_clone (line 70) | def clean_clone(x): class VariablePlaceholder (line 81) | class VariablePlaceholder: class InstancePlaceholder (line 91) | class InstancePlaceholder: function _memoize_by_id (line 99) | def _memoize_by_id(fn, refs): function get_module_scopes (line 118) | def get_module_scopes(module, args=None, kwargs=None): function set_module_scopes (line 195) | def set_module_scopes(module, args, kwargs, scopes): function _test_transformed_return_values (line 285) | def _test_transformed_return_values(tree, method_name): function module_class_lift_transform (line 299) | def module_class_lift_transform( function decorator_lift_transform (line 385) | def decorator_lift_transform( class _HashableProxy (line 439) | class _HashableProxy: method from_module (line 450) | def from_module(cls, module: Module) -> '_HashableProxy': method __hash__ (line 455) | def __hash__(self): method __eq__ (line 458) | def __eq__(self, other): method module (line 462) | def module(self): function _module_fingerprint (line 466) | def _module_fingerprint(module: Module) -> tuple[type[Any], Any]: function _fingerprint_recursive (line 470) | def _fingerprint_recursive( function _check_field_is_hashable (line 551) | def _check_field_is_hashable(path: tuple[str, ...], x: Any): function decorator_lift_transform_cached (line 560) | def decorator_lift_transform_cached(transform, class_fn, **trafo_kwargs): function fork_rngs (line 643) | def fork_rngs(module: Module): function module_class_lift_transform_cached (line 660) | def module_class_lift_transform_cached( function _is_module_class (line 762) | def _is_module_class(target: TransformTarget) -> bool: function lift_transform (line 771) | def lift_transform( function lift_transform_cached (line 789) | def lift_transform_cached( function lift_direct_transform (line 807) | def lift_direct_transform( function vmap (line 834) | def vmap( function jit (line 927) | def jit( function checkpoint (line 997) | def checkpoint( function remat_scan (line 1087) | def remat_scan( function scan (line 1153) | def scan( function map_variables (line 1339) | def map_variables( function vjp (line 1417) | def vjp( function value_and_grad (line 1502) | def value_and_grad( function grad (line 1591) | def grad( function jvp (line 1668) | def jvp( function while_loop (line 1763) | def while_loop( function _cond_wrapper (line 1833) | def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs): function cond (line 1839) | def cond( function _switch_wrapper (line 1901) | def _switch_wrapper(*args, variables, rngs, n_branches): function switch (line 1911) | def switch( function _custom_vjp_single_scope_fn (line 2002) | def _custom_vjp_single_scope_fn( function custom_vjp (line 2015) | def custom_vjp( function named_call (line 2099) | def named_call(class_fn, force=True): function add_metadata_axis (line 2124) | def add_metadata_axis( function fold_rngs (line 2166) | def fold_rngs( FILE: flax/metrics/tensorboard.py function _flatten_dict (line 29) | def _flatten_dict(input_dict, parent_key='', sep='.'): function _as_default (line 68) | def _as_default(summary_writer: tf.summary.SummaryWriter, auto_flush: bo... class SummaryWriter (line 82) | class SummaryWriter: method __init__ (line 85) | def __init__(self, log_dir, auto_flush=True): method close (line 102) | def close(self): method flush (line 109) | def flush(self): method scalar (line 112) | def scalar(self, tag, value, step): method image (line 124) | def image(self, tag, image, step, max_outputs=3): method audio (line 156) | def audio(self, tag, audiodata, step, sample_rate=44100, max_outputs=3): method histogram (line 186) | def histogram(self, tag, values, step, bins=None): method text (line 200) | def text(self, tag, textdata, step): method write (line 214) | def write(self, tag, tensor, step, metadata=None): method hparams (line 229) | def hparams(self, hparams): FILE: flax/nnx/__init__.py function __getattr__ (line 230) | def __getattr__(name): FILE: flax/nnx/bridge/interop.py function nnx_in_bridge_mdl (line 26) | def nnx_in_bridge_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Mod... function linen_in_bridge_mdl (line 69) | def linen_in_bridge_mdl(linen_module: nn_module.Module, FILE: flax/nnx/bridge/module.py class ModuleStackEntry (line 45) | class ModuleStackEntry: class ModuleContext (line 54) | class ModuleContext(threading.local): class ModuleState (line 63) | class ModuleState(statelib.State): class Scope (line 70) | class Scope(Pytree): method __init__ (line 71) | def __init__(self, rngs: rnglib.Rngs, mutable: CollectionFilter): method copy (line 75) | def copy(self): class _HasSetup (line 79) | class _HasSetup(tp.Protocol): method setup (line 80) | def setup(self) -> None: ... function has_setup (line 83) | def has_setup(x: tp.Any) -> tp.TypeGuard[_HasSetup]: function _maybe_call_setup (line 87) | def _maybe_call_setup(module: Module): function _bind_module (line 104) | def _bind_module(parent: Module, module: Module) -> Module: function current_context (line 115) | def current_context() -> ModuleStackEntry | None: function current_module (line 119) | def current_module() -> Module | None: function _auto_submodule_name (line 127) | def _auto_submodule_name(parent_ctx, cls): class ModuleMeta (line 134) | class ModuleMeta(nnx_module.ModuleMeta): method _pytree_meta_construct (line 136) | def _pytree_meta_construct(cls, self, *args, **kwargs): function _module_meta_call (line 141) | def _module_meta_call(cls: type[M], *args, **kwargs) -> M: class AttrPriority (line 184) | class AttrPriority(enum.IntEnum): class PriorityStr (line 191) | class PriorityStr(str): method __new__ (line 194) | def __new__(cls, priority: AttrPriority, value: str): method _check_and_get_priority (line 199) | def _check_and_get_priority(self, other) -> AttrPriority: method __lt__ (line 208) | def __lt__(self, other) -> bool: method __gt__ (line 214) | def __gt__(self, other) -> bool: class ModuleBase (line 220) | class ModuleBase: class Module (line 227) | class Module(nnx_module.Module, ModuleBase, metaclass=ModuleMeta): method __init_subclass__ (line 228) | def __init_subclass__(cls) -> None: method __getattribute__ (line 234) | def __getattribute__(self, name: str): method _getattr (line 237) | def _getattr(self, name: str) -> tp.Any: method _setattr (line 243) | def _setattr(self, name: str, value: tp.Any) -> None: method _graph_node_flatten (line 255) | def _graph_node_flatten(self): method set_attr_priority (line 264) | def set_attr_priority(self, name: str, value: AttrPriority): method make_rng (line 267) | def make_rng(self, name: str = 'default') -> jax.Array: method param (line 272) | def param( # type: ignore[invalid-annotation] method variable (line 322) | def variable( # type: ignore[invalid-annotation] method _get_variables (line 376) | def _get_variables(self) -> tp.Mapping: method variables (line 411) | def variables(self): method apply (line 415) | def apply( method init (line 502) | def init( method init_with_output (line 520) | def init_with_output( method is_initializing (line 539) | def is_initializing(self) -> bool: function compact (line 543) | def compact(f: F) -> F: function _get_unbound_fn (line 561) | def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable: FILE: flax/nnx/bridge/variables.py function sort_variable_types (line 31) | def sort_variable_types(types: tp.Iterable[type]): class NNXMeta (line 43) | class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]): method unbox (line 50) | def unbox(self) -> A: method replace_boxed (line 53) | def replace_boxed(self, val: B) -> 'NNXMeta[B]': method add_axis (line 56) | def add_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': method remove_axis (line 60) | def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[... method get_partition_spec (line 64) | def get_partition_spec(self) -> jax.sharding.PartitionSpec: method to_nnx_variable (line 71) | def to_nnx_variable(self) -> variablelib.Variable: function is_vanilla_variable (line 75) | def is_vanilla_variable(vs: variablelib.Variable) -> bool: function to_linen_var (line 89) | def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata: function get_col_name (line 101) | def get_col_name(keypath: tp.Sequence[Any]) -> str: function to_nnx_var (line 108) | def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variablelib.Vari... function _recursive_merge (line 123) | def _recursive_merge(dict1, dict2): function linen_vars_to_nnx_attrs (line 130) | def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str... function nnx_attrs_to_linen_vars (line 152) | def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: function with_partitioning (line 167) | def with_partitioning( FILE: flax/nnx/bridge/wrappers.py class Functional (line 41) | class Functional(tp.Generic[M]): method init (line 47) | def init(self, *, rngs: tp.Optional[Rngs] = None) -> State: method apply (line 56) | def apply(self, *states: tp.Any): function functional (line 61) | def functional(cls: tp.Type[M]) -> tp.Callable[..., Functional[M]]: function _set_initializing (line 68) | def _set_initializing(module: Module, initializing: bool): function lazy_init (line 74) | def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs): function current_linen_module (line 92) | def current_linen_module() -> linen.Module | None: class ToNNX (line 98) | class ToNNX(Module): method __init__ (line 128) | def __init__( method rngs (line 144) | def rngs(self) -> Rngs | None: method module (line 152) | def module(self) -> linen.Module: method _setattr (line 159) | def _setattr(self, name, value): method lazy_init (line 164) | def lazy_init(self, *args, **kwargs): method __getattr__ (line 168) | def __getattr__(self, name: str): method __call__ (line 178) | def __call__( function linen_rngs_dict (line 248) | def linen_rngs_dict(linen_module: linen.Module, add_default: bool = False): function _get_module_method (line 259) | def _get_module_method(module, method: tp.Callable[..., Any] | str | None): class ToLinen (line 281) | class ToLinen(linen.Module): method __call__ (line 327) | def __call__( method __getattr__ (line 384) | def __getattr__(self, name: str): method _update_variables (line 396) | def _update_variables(self, module): class _Missing (line 435) | class _Missing: function to_linen (line 442) | def to_linen( function to_linen_class (line 463) | def to_linen_class( FILE: flax/nnx/extract.py class PrefixMapping (line 36) | class PrefixMapping(abc.ABC): method map_prefix (line 38) | def map_prefix( function check_consistent_aliasing (line 45) | def check_consistent_aliasing( function check_consistent_aliasing2 (line 110) | def check_consistent_aliasing2( function broadcast_prefix (line 159) | def broadcast_prefix( function broadcast_prefix2 (line 183) | def broadcast_prefix2( function broadcast_prefix_map (line 198) | def broadcast_prefix_map( class GraphDefState (line 211) | class GraphDefState(struct.PyTreeNode): class NodeStates (line 219) | class NodeStates(struct.PyTreeNode): method graphdef (line 225) | def graphdef(self) -> graphlib.GraphDef[tp.Any]: method state (line 231) | def state(self) -> tp.Any: method from_split (line 239) | def from_split( method from_states (line 250) | def from_states( method from_prefixes (line 258) | def from_prefixes( function default_split_fn (line 268) | def default_split_fn( function to_tree (line 274) | def to_tree( function to_tree2 (line 336) | def to_tree2( function from_tree2 (line 398) | def from_tree2(tree: tp.Any, /) -> tp.Any: function merge_tree_node (line 420) | def merge_tree_node( function is_tree_node (line 428) | def is_tree_node(x): function from_tree (line 432) | def from_tree( function clear_non_graph_nodes (line 485) | def clear_non_graph_nodes(tree): class Mask (line 495) | class Mask(tp.NamedTuple): function mask_at (line 498) | def mask_at(t: tuple, index: int | None) -> tuple: function replace_at (line 506) | def replace_at(t: tuple, index: int, value: tp.Any) -> tuple: function updates_and_snapshot (line 512) | def updates_and_snapshot(args: A) -> tuple[A, A]: function check_no_aliases (line 529) | def check_no_aliases(fn_name: str, /, **kwargs): function check_prefix (line 562) | def check_prefix(prefix: tp.Any, prefix_name: str, fn_name: str): function variable_changed (line 578) | def variable_changed(post: variablelib.Variable, pre: variablelib.Variab... function mask_variable_updates (line 589) | def mask_variable_updates( function apply_variable_updates (line 617) | def apply_variable_updates(args_tree: A, updates_tree: A): function treemap_copy_args (line 628) | def treemap_copy_args(f): function check_same_variables (line 636) | def check_same_variables(inputs, outputs, transform_name: str = ''): function update_carry_variables (line 650) | def update_carry_variables(init_val, val_out): FILE: flax/nnx/filterlib.py function to_predicate (line 32) | def to_predicate(filter: Filter) -> Predicate: function filters_to_predicates (line 57) | def filters_to_predicates( class HasTag (line 71) | class HasTag(tp.Protocol): function _has_tag (line 75) | def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]: class WithTag (line 80) | class WithTag: method __call__ (line 83) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 86) | def __repr__(self): class PathContains (line 91) | class PathContains: method __call__ (line 95) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 100) | def __repr__(self): class PathIn (line 104) | class PathIn: method __init__ (line 105) | def __init__(self, *paths: PathParts): method __call__ (line 108) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 111) | def __repr__(self): method __eq__ (line 115) | def __eq__(self, other): method __hash__ (line 118) | def __hash__(self): class OfType (line 123) | class OfType: method __call__ (line 126) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 129) | def __repr__(self): class Any (line 133) | class Any: method __init__ (line 134) | def __init__(self, *filters: Filter): method __call__ (line 139) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 142) | def __repr__(self): method __eq__ (line 145) | def __eq__(self, other): method __hash__ (line 148) | def __hash__(self): class All (line 152) | class All: method __init__ (line 153) | def __init__(self, *filters: Filter): method __call__ (line 158) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 161) | def __repr__(self): method __eq__ (line 164) | def __eq__(self, other): method __hash__ (line 167) | def __hash__(self): class Not (line 171) | class Not: method __init__ (line 172) | def __init__(self, collection_filter: Filter, /): method __call__ (line 175) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 178) | def __repr__(self): method __eq__ (line 181) | def __eq__(self, other): method __hash__ (line 184) | def __hash__(self): class Everything (line 188) | class Everything: method __call__ (line 189) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 192) | def __repr__(self): method __eq__ (line 195) | def __eq__(self, other): method __hash__ (line 198) | def __hash__(self): class Nothing (line 202) | class Nothing: method __call__ (line 203) | def __call__(self, path: PathParts, x: tp.Any): method __repr__ (line 206) | def __repr__(self): method __eq__ (line 209) | def __eq__(self, other): method __hash__ (line 212) | def __hash__(self): FILE: flax/nnx/graphlib.py function _tree_mode_suggestion_api (line 53) | def _tree_mode_suggestion_api(fn_name: str) -> str: function _tree_mode_suggestion_transform (line 63) | def _tree_mode_suggestion_transform(fn_name: str) -> str: function _check_valid_pytree (line 74) | def _check_valid_pytree( class NoUpdate (line 103) | class NoUpdate: ... class Repeated (line 111) | class Repeated: ... class ArrayRefOutput (line 119) | class ArrayRefOutput(reprlib.Representable): method __nnx_repr__ (line 122) | def __nnx_repr__(self): method __treescope_repr__ (line 126) | def __treescope_repr__(self, path, subtree_renderer): function is_node_leaf (line 149) | def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[LeafType]: class IndexMap (line 153) | class IndexMap(dict[Index, tp.Any]): method from_refmap (line 155) | def from_refmap(refmap: RefMap) -> IndexMap: class RefMap (line 166) | class RefMap(tp.MutableMapping[tp.Any, int], reprlib.MappingReprMixin): method __init__ (line 169) | def __init__( method from_indexmap (line 181) | def from_indexmap(indexmap: IndexMap) -> RefMap: method get (line 186) | def get(self, key: tp.Any, default: int | None = None) -> int | None: ... method __getitem__ (line 189) | def __getitem__(self, key: tp.Any) -> int: method __setitem__ (line 192) | def __setitem__(self, key: tp.Any, value: int): method __delitem__ (line 195) | def __delitem__(self, key: tp.Any): method __len__ (line 198) | def __len__(self) -> int: method __contains__ (line 201) | def __contains__(self, key: tp.Any) -> bool: method __iter__ (line 204) | def __iter__(self) -> tp.Iterator[tp.Any]: method items (line 208) | def items(self) -> tp.ItemsView[tp.Any, int]: class NodeImplBase (line 222) | class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): method node_dict (line 226) | def node_dict(self, node: Node) -> dict[Key, tp.Any]: class GraphNodeImpl (line 235) | class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]): class PytreeNodeImpl (line 244) | class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): function register_graph_node_type (line 259) | def register_graph_node_type( function register_pytree_node_type (line 282) | def register_pytree_node_type( function is_node (line 302) | def is_node(x: tp.Any) -> bool: function is_graph_node (line 310) | def is_graph_node(x: tp.Any) -> bool: function is_node_type (line 318) | def is_node_type(x: type[tp.Any]) -> bool: function get_node_impl (line 322) | def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None: function get_node_impl_for_type (line 338) | def get_node_impl_for_type( function _type_aware_sort (line 351) | def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]: class NodeRef (line 363) | class NodeRef(tp.Generic[Node], reprlib.Representable): method __nnx_repr__ (line 366) | def __nnx_repr__(self): method __treescope_repr__ (line 370) | def __treescope_repr__(self, path, subtree_renderer): class VariableDef (line 387) | class VariableDef(reprlib.Representable, tp.Generic[Node]): method with_no_outer_index (line 394) | def with_no_outer_index(self) -> VariableDef: method with_same_outer_index (line 405) | def with_same_outer_index(self) -> VariableDef: method with_matching_outer_index (line 416) | def with_matching_outer_index(self, other) -> VariableDef: method __nnx_repr__ (line 427) | def __nnx_repr__(self): method __treescope_repr__ (line 434) | def __treescope_repr__(self, path, subtree_renderer): class ArrayRefDef (line 456) | class ArrayRefDef(reprlib.Representable): method with_no_outer_index (line 460) | def with_no_outer_index(self): method with_same_outer_index (line 466) | def with_same_outer_index(self): method with_matching_outer_index (line 472) | def with_matching_outer_index(self, other): method __nnx_repr__ (line 478) | def __nnx_repr__(self): method __treescope_repr__ (line 483) | def __treescope_repr__(self, path, subtree_renderer): class NodeDef (line 497) | class NodeDef(tp.Generic[Node], reprlib.Representable): method with_no_outer_index (line 508) | def with_no_outer_index(self) -> NodeDef[Node]: method with_same_outer_index (line 517) | def with_same_outer_index(self) -> NodeDef[Node]: method with_matching_outer_index (line 526) | def with_matching_outer_index(self, other) -> NodeDef[Node]: method __nnx_repr__ (line 535) | def __nnx_repr__(self): method __treescope_repr__ (line 544) | def __treescope_repr__(self, path, subtree_renderer): class TreeNodeDef (line 567) | class TreeNodeDef(tp.Generic[Node]): method with_no_outer_index (line 572) | def with_no_outer_index(self) -> TreeNodeDef[Node]: method with_same_outer_index (line 575) | def with_same_outer_index(self) -> TreeNodeDef[Node]: method with_matching_outer_index (line 578) | def with_matching_outer_index(self, other) -> TreeNodeDef[Node]: class NodeAttr (line 591) | class NodeAttr: class LeafAttr (line 598) | class LeafAttr: class GraphDef (line 613) | class GraphDef(tp.Generic[Node]): method __hash__ (line 618) | def __hash__(self) -> int: method with_no_outer_index (line 621) | def with_no_outer_index(self) -> GraphDef[Node]: method with_matching_outer_index (line 631) | def with_matching_outer_index(self, other) -> GraphDef[Node]: method with_same_outer_index (line 641) | def with_same_outer_index(self) -> GraphDef[Node]: method apply (line 652) | def apply( function _tree_flatten (line 678) | def _tree_flatten( function flatten (line 750) | def flatten( # type: ignore[invalid-annotation] function flatten (line 759) | def flatten( # type: ignore[invalid-annotation] function flatten (line 772) | def flatten( # type: ignore[invalid-annotation] function flatten (line 785) | def flatten( # type: ignore[invalid-annotation] function flatten (line 797) | def flatten( # type: ignore[invalid-annotation] class DataElem (line 862) | class DataElem: class StaticElem (line 867) | class StaticElem: function _graph_flatten (line 870) | def _graph_flatten( function _get_sorted_leaves (line 1036) | def _get_sorted_leaves( function _tree_unflatten (line 1054) | def _tree_unflatten( function unflatten (line 1083) | def unflatten( # type: ignore[invalid-annotation] function _graph_unflatten (line 1164) | def _graph_unflatten( function graph_pop (line 1356) | def graph_pop( function _graph_pop (line 1372) | def _graph_pop( function _graph_update_dynamic (line 1428) | def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): class StaticCache (line 1500) | class StaticCache(tp.NamedTuple): method create (line 1509) | def create( class GraphContext (line 1529) | class GraphContext(threading.local): class set_graph_mode (line 1544) | class set_graph_mode(BaseConfigContext): class set_graph_updates (line 1549) | class set_graph_updates(BaseConfigContext): function static_cache (line 1555) | def static_cache(static_cache: tp.MutableMapping[tp.Any, StaticCache]): function _cached_partial (line 1571) | def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args, graph: bo... class SplitContext (line 1681) | class SplitContext: method split (line 1687) | def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: .... method split (line 1690) | def split( # type: ignore[invalid-annotation] method split (line 1695) | def split( method split (line 1704) | def split( method flatten (line 1722) | def flatten( # type: ignore[invalid-annotation] method flatten (line 1731) | def flatten( # type: ignore[invalid-annotation] method flatten (line 1738) | def flatten( # type: ignore[invalid-annotation] method flatten (line 1746) | def flatten( # type: ignore[invalid-annotation] method flatten (line 1759) | def flatten( # type: ignore[invalid-annotation] function split_context (line 1841) | def split_context(ctxtag: tp.Hashable | None = None): class MergeContext (line 1858) | class MergeContext: method merge (line 1863) | def merge( # type: ignore[invalid-annotation] method unflatten (line 1887) | def unflatten( # type: ignore[invalid-annotation] function merge_context (line 1981) | def merge_context() -> tp.Generator[MergeContext, None, None]: ... # ty... function merge_context (line 1984) | def merge_context( function merge_context (line 1988) | def merge_context(ctxtag: tp.Hashable | None = None, inner: bool | None ... class UpdateContext (line 2007) | class UpdateContext: method __hash__ (line 2019) | def __hash__(self): method __eq__ (line 2022) | def __eq__(self, other): method flatten_end (line 2025) | def flatten_end(self, ref_index: RefMap): method unflatten_end (line 2037) | def unflatten_end(self, index_ref: IndexMap, inner_merge: bool): class UpdateContextManager (line 2045) | class UpdateContextManager: method __enter__ (line 2048) | def __enter__(self): method __exit__ (line 2069) | def __exit__(self, *args): method __call__ (line 2086) | def __call__(self, f: F) -> F: function update_context (line 2095) | def update_context(tag: tp.Hashable): function current_update_context (line 2198) | def current_update_context(tag: tp.Hashable) -> UpdateContext: function _split_state (line 2210) | def _split_state( function split (line 2224) | def split( # type: ignore[invalid-annotation] function split (line 2228) | def split( # type: ignore[invalid-annotation] function split (line 2232) | def split( # type: ignore[invalid-annotation] function split (line 2244) | def split( # type: ignore[invalid-annotation] function _to_nested_state (line 2325) | def _to_nested_state( function _merge_to_flat_state (line 2340) | def _merge_to_flat_state(states: tp.Iterable[tp.Any]): function merge (line 2355) | def merge( # type: ignore[invalid-annotation] function update (line 2416) | def update(node, state: tp.Any, /, *states: tp.Any) -> None: function state (line 2461) | def state(node, /, *, graph: bool | None = None) -> GraphState: ... function state (line 2463) | def state(node, first: filterlib.Filter, /, *, graph: bool | None = None... function state (line 2465) | def state( function state (line 2473) | def state( function map (line 2530) | def map( function graphdef (line 2572) | def graphdef( function pop (line 2601) | def pop( function pop (line 2609) | def pop( function pop (line 2618) | def pop( function clone (line 2681) | def clone(node: Node, variables: bool = True, *, graph: bool | None = No... function vars_as (line 2708) | def vars_as( function pure (line 2759) | def pure(tree: A) -> A: function call (line 2810) | def call( function set_metadata (line 2900) | def set_metadata( function iter_graph (line 2937) | def iter_graph( function _iter_graph (line 2987) | def _iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: function _iter_tree (line 3009) | def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: function iter_children (line 3059) | def iter_children( function recursive_map (line 3131) | def recursive_map( function _recursive_map_graph (line 3181) | def _recursive_map_graph( function _recursive_map_tree (line 3221) | def _recursive_map_tree( function find_duplicates (line 3275) | def find_duplicates(node: tp.Any, /, *, only: filterlib.Filter = ...) ->... function _node_paths (line 3330) | def _node_paths( class Static (line 3359) | class Static(tp.Generic[A]): class GenericPytree (line 3370) | class GenericPytree: ... function is_pytree_node (line 3376) | def is_pytree_node( function _key_path_to_key (line 3391) | def _key_path_to_key(key: tp.Any) -> Key: function jax_to_nnx_path (line 3408) | def jax_to_nnx_path(jax_path: tuple, /): class IndexesPytreeDef (line 3412) | class IndexesPytreeDef(tp.NamedTuple): function _flatten_pytree (line 3417) | def _flatten_pytree(pytree: tp.Any): function _unflatten_pytree (line 3432) | def _unflatten_pytree( function _list_set_key (line 3448) | def _list_set_key(x: list[tp.Any], key: int, value: tp.Any): function _mutable_mapping_set_key (line 3467) | def _mutable_mapping_set_key( function _mutable_mapping_pop_key (line 3473) | def _mutable_mapping_pop_key(x: tp.MutableMapping[Key, tp.Any], key: Key): FILE: flax/nnx/helpers.py class Dict (line 36) | class Dict(reprlib.MappingReprMixin, Module, tp.MutableMapping[str, A]): method __init__ (line 54) | def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ... method __init__ (line 57) | def __init__( method __init__ (line 61) | def __init__(self, *args, **kwargs): method __getitem__ (line 65) | def __getitem__(self, key) -> A: method __setitem__ (line 71) | def __setitem__(self, key, value): method __iter__ (line 74) | def __iter__(self) -> tp.Iterator[str]: method __len__ (line 77) | def __len__(self) -> int: method __hash__ (line 83) | def __hash__(self) -> int: method __delitem__ (line 86) | def __delitem__(self, key: str) -> None: method __getattr__ (line 93) | def __getattr__(self, key: str) -> A: method __setattr__ (line 95) | def __setattr__(self, key: str, value: A) -> None: class List (line 99) | class List(reprlib.SequenceReprMixin, Module, tp.MutableSequence[A]): method __init__ (line 116) | def __init__(self, it: tp.Iterable[A] | None = None, /): method _get_elem (line 126) | def _get_elem(self, key: int) -> A: method _set_elem (line 129) | def _set_elem(self, key: int, value: A) -> None: method _del_elem (line 132) | def _del_elem(self, key: int) -> None: method __len__ (line 135) | def __len__(self) -> int: method append (line 138) | def append(self, value: A) -> None: method insert (line 142) | def insert(self, index: int, value: A) -> None: method __iter__ (line 158) | def __iter__(self) -> tp.Iterator[A]: method __getitem__ (line 163) | def __getitem__(self, index: int) -> A: ... method __getitem__ (line 165) | def __getitem__(self, index: slice) -> tp.List[A]: ... method __getitem__ (line 166) | def __getitem__(self, index: int | slice) -> A | tp.List[A]: method __setitem__ (line 179) | def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -... method _graph_node_set_key (line 198) | def _graph_node_set_key(self, key: str, value: tp.Any): method __delitem__ (line 209) | def __delitem__(self, index: int | slice) -> None: class Sequential (line 230) | class Sequential(Module): method __init__ (line 253) | def __init__(self, *fns: tp.Callable[..., tp.Any]): method __call__ (line 260) | def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) ->... class ModuleDefApply (line 294) | class ModuleDefApply(tp.Protocol, tp.Generic[M]): method __call__ (line 295) | def __call__( class TrainState (line 300) | class TrainState(tp.Generic[M], struct.PyTreeNode): method create (line 308) | def create( method __getattr__ (line 328) | def __getattr__(self, key: str) -> tp.Any: ... method apply (line 330) | def apply( method apply_gradients (line 349) | def apply_gradients(self: TS, grads: State, **kwargs) -> TS: function has_keyword_arg (line 361) | def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool: FILE: flax/nnx/ids.py class UUIDManager (line 19) | class UUIDManager: method __init__ (line 31) | def __init__(self): method __call__ (line 35) | def __call__(self): class UUID (line 44) | class UUID: method __init__ (line 47) | def __init__(self, rawid): method __eq__ (line 50) | def __eq__(self, other): method __hash__ (line 53) | def __hash__(self): method __repr__ (line 56) | def __repr__(self): method __deepcopy__ (line 59) | def __deepcopy__(self, memo): method __copy__ (line 63) | def __copy__(self): FILE: flax/nnx/module.py class ModuleMeta (line 49) | class ModuleMeta(PytreeMeta): class Module (line 55) | class Module(Pytree, metaclass=ModuleMeta): method sow (line 86) | def sow( method perturb (line 187) | def perturb( method iter_modules (line 279) | def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]: method iter_children (line 293) | def iter_children(self) -> tp.Iterator[tuple[Key, Module]]: method set_attributes (line 308) | def set_attributes( method train (line 369) | def train(self, **attributes): method eval (line 405) | def eval(self, **attributes): function view (line 440) | def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found... function with_attributes (line 518) | def with_attributes( function _parse_docstring_args (line 589) | def _parse_docstring_args(doc_str: str) -> dict[str, str]: function view_info (line 620) | def view_info(node: Module, /, *, only: filterlib.Filter = ..., graph: b... function first_from (line 703) | def first_from(*args: tp.Optional[A], error_msg: str) -> A: function iter_modules (line 719) | def iter_modules( function capture (line 774) | def capture( function capture (line 782) | def capture( function capture (line 789) | def capture(fn: tp.Callable[P, R] | type[variableslib.Variable], *var_ty... function _collect_state_by_path (line 928) | def _collect_state_by_path(state): function _navigate_to_path (line 946) | def _navigate_to_path(state, path): function _extract_captures (line 955) | def _extract_captures(module, state, var_types): function _add_capturing (line 969) | def _add_capturing(cls, variable_type): function _remove_capturing (line 986) | def _remove_capturing(cls): FILE: flax/nnx/nn/activations.py class PReLU (line 81) | class PReLU(nnx.Module): method __init__ (line 112) | def __init__( method __call__ (line 128) | def __call__(self, inputs: Array) -> Array: FILE: flax/nnx/nn/attention.py function dot_product_attention_weights (line 52) | def dot_product_attention_weights( function dot_product_attention (line 190) | def dot_product_attention( class MultiHeadAttention (line 322) | class MultiHeadAttention(Module): method __init__ (line 407) | def __init__( method __call__ (line 577) | def __call__( method init_cache (line 747) | def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): method set_view (line 781) | def set_view( function make_attention_mask (line 828) | def make_attention_mask( function make_causal_mask (line 860) | def make_causal_mask( function combine_masks (line 888) | def combine_masks( FILE: flax/nnx/nn/dtypes.py function canonicalize_dtype (line 22) | def canonicalize_dtype( function promote_dtype (line 54) | def promote_dtype(args: T, /, *, dtype=None, inexact=True) -> T: FILE: flax/nnx/nn/initializers.py function zeros_init (line 41) | def zeros_init() -> Initializer: function ones_init (line 54) | def ones_init() -> Initializer: FILE: flax/nnx/nn/linear.py function canonicalize_padding (line 52) | def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: function _conv_dimension_numbers (line 76) | def _conv_dimension_numbers(input_shape): function _normalize_axes (line 85) | def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]: function _canonicalize_tuple (line 90) | def _canonicalize_tuple(x: tp.Sequence[int] | int) -> tuple[int, ...]: class LinearGeneral (line 97) | class LinearGeneral(Module): method __init__ (line 156) | def __init__( method __call__ (line 253) | def __call__(self, inputs: Array, out_sharding = None) -> Array: class Linear (line 313) | class Linear(Module): method __init__ (line 357) | def __init__( method __call__ (line 400) | def __call__(self, inputs: Array, out_sharding = None) -> Array: class Einsum (line 435) | class Einsum(Module): method __init__ (line 482) | def __init__( method __call__ (line 527) | def __call__( method _infer_broadcasted_bias_shape (line 579) | def _infer_broadcasted_bias_shape( method _einsum_str_check (line 611) | def _einsum_str_check(self, einsum_str): class Conv (line 624) | class Conv(Module): method __init__ (line 715) | def __init__( method __call__ (line 782) | def __call__(self, inputs: Array, out_sharding=None) -> Array: class ConvTranspose (line 912) | class ConvTranspose(Module): method __init__ (line 1018) | def __init__( method __call__ (line 1079) | def __call__(self, inputs: Array) -> Array: class Embed (line 1217) | class Embed(Module): method __init__ (line 1271) | def __init__( method __call__ (line 1294) | def __call__(self, inputs: Array, out_sharding=None) -> Array: method attend (line 1321) | def attend(self, query: Array, out_sharding=None) -> Array: FILE: flax/nnx/nn/lora.py class LoRAParam (line 36) | class LoRAParam(variablelib.Param[A]): class LoRA (line 40) | class LoRA(Module): method __init__ (line 89) | def __init__( method __call__ (line 123) | def __call__(self, x: jax.Array): class LoRALinear (line 135) | class LoRALinear(Linear): method __init__ (line 179) | def __init__( method __call__ (line 212) | def __call__(self, x: jax.Array, out_sharding = None): FILE: flax/nnx/nn/normalization.py function _canonicalize_axes (line 35) | def _canonicalize_axes(rank: int, axes: Axes) -> tp.Tuple[int, ...]: function _abs_sq (line 42) | def _abs_sq(x): function _compute_stats (line 50) | def _compute_stats( function _normalize (line 134) | def _normalize( function _l2_normalize (line 186) | def _l2_normalize(x, axis=None, eps=1e-12): class BatchNorm (line 201) | class BatchNorm(Module): method __init__ (line 289) | def __init__( method __call__ (line 343) | def __call__( method set_view (line 414) | def set_view( class LayerNorm (line 428) | class LayerNorm(Module): method __init__ (line 491) | def __init__( method __call__ (line 541) | def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): class RMSNorm (line 579) | class RMSNorm(Module): method __init__ (line 636) | def __init__( method __call__ (line 675) | def __call__(self, x, mask: tp.Optional[jax.Array] = None): class GroupNorm (line 713) | class GroupNorm(Module): method __init__ (line 793) | def __init__( method __call__ (line 872) | def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): class WeightNorm (line 928) | class WeightNorm(nnx.Module): method __init__ (line 979) | def __init__( method _weightnorm_inplace (line 1014) | def _weightnorm_inplace(self, path, param): method __call__ (line 1051) | def __call__(self, x: Array, *args, **kwargs) -> Array: class InstanceNorm (line 1071) | class InstanceNorm(Module): method __init__ (line 1149) | def __init__( method __call__ (line 1196) | def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): class SpectralNorm (line 1242) | class SpectralNorm(Module): method __init__ (line 1304) | def __init__( method __call__ (line 1370) | def __call__( method _spectral_normalize_inplace (line 1400) | def _spectral_normalize_inplace(self, path, orig_param, update_stats): FILE: flax/nnx/nn/recurrent.py class RNNCellBase (line 47) | class RNNCellBase(Module): method initialize_carry (line 50) | def initialize_carry( method __call__ (line 68) | def __call__( method num_feature_axes (line 86) | def num_feature_axes(self) -> int: function modified_orthogonal (line 90) | def modified_orthogonal(key: Array, shape: Shape, dtype: Dtype = jnp.flo... class LSTMCell (line 95) | class LSTMCell(RNNCellBase): method __init__ (line 114) | def __init__( method __call__ (line 195) | def __call__( method initialize_carry (line 218) | def initialize_carry( method num_feature_axes (line 252) | def num_feature_axes(self) -> int: class OptimizedLSTMCell (line 256) | class OptimizedLSTMCell(RNNCellBase): method __init__ (line 303) | def __init__( method __call__ (line 373) | def __call__( method initialize_carry (line 406) | def initialize_carry( method num_feature_axes (line 441) | def num_feature_axes(self) -> int: class SimpleCell (line 445) | class SimpleCell(RNNCellBase): method __init__ (line 467) | def __init__( method __call__ (line 537) | def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]... method initialize_carry (line 544) | def initialize_carry( method num_feature_axes (line 578) | def num_feature_axes(self) -> int: class GRUCell (line 582) | class GRUCell(RNNCellBase): method __init__ (line 624) | def __init__( method __call__ (line 694) | def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]... method initialize_carry (line 727) | def initialize_carry( method num_feature_axes (line 762) | def num_feature_axes(self) -> int: class RNN (line 766) | class RNN(Module): method __init__ (line 774) | def __init__( method __call__ (line 808) | def __call__( function _select_last_carry (line 920) | def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A: function _expand_dims_like (line 929) | def _expand_dims_like(x, target): function flip_sequences (line 934) | def flip_sequences( function _concatenate (line 994) | def _concatenate(a: Array, b: Array) -> Array: class RNNBase (line 999) | class RNNBase(Protocol): method __call__ (line 1000) | def __call__( class Bidirectional (line 1014) | class Bidirectional(Module): method __init__ (line 1046) | def __init__( method __call__ (line 1075) | def __call__( FILE: flax/nnx/nn/stochastic.py class Dropout (line 27) | class Dropout(Module): method __init__ (line 71) | def __init__( method __call__ (line 96) | def __call__( method set_view (line 159) | def set_view( FILE: flax/nnx/proxy_caller.py function _identity (line 25) | def _identity(x): class GetItem (line 29) | class GetItem: class GetAttr (line 34) | class GetAttr: class DelayedAccessor (line 39) | class DelayedAccessor: method __call__ (line 42) | def __call__(self, x): method __getattr__ (line 50) | def __getattr__(self, name): method __getitem__ (line 53) | def __getitem__(self, key): class _AccessorCall (line 59) | class _AccessorCall(tp.Protocol): method __call__ (line 60) | def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> t... class CallableProxy (line 63) | class CallableProxy: method __init__ (line 64) | def __init__( method __call__ (line 70) | def __call__(self, *args, **kwargs): method __getattr__ (line 73) | def __getattr__(self, name) -> CallableProxy: method __getitem__ (line 76) | def __getitem__(self, key) -> CallableProxy: class ApplyCaller (line 80) | class ApplyCaller(tp.Protocol, tp.Generic[A]): method __getattr__ (line 81) | def __getattr__(self, __name) -> ApplyCaller[A]: method __getitem__ (line 84) | def __getitem__(self, __name) -> ApplyCaller[A]: method __call__ (line 87) | def __call__(self, *args, **kwargs) -> tuple[tp.Any, A]: FILE: flax/nnx/pytreelib.py function data (line 60) | def data(value: A, /) -> A: ... function data (line 62) | def data( function data (line 73) | def data(value: tp.Any = MISSING, /, **kwargs) -> tp.Any: function register_data_type (line 116) | def register_data_type(type_: T, /) -> T: function is_data (line 156) | def is_data(value: tp.Any, /) -> bool: function has_data (line 201) | def has_data(value: tp.Any, /) -> list[tp.Any]: function static (line 221) | def static(value: A, /) -> A: ... function static (line 223) | def static( function static (line 234) | def static(value: tp.Any = MISSING, /, **kwargs) -> tp.Any: function dataclass (line 272) | def dataclass(cls: type[A], /) -> type[A]: ... function dataclass (line 274) | def dataclass( function dataclass (line 287) | def dataclass( function _collect_stats (line 312) | def _collect_stats( class ObjectContext (line 350) | class ObjectContext(threading.local): class PytreeState (line 358) | class PytreeState(reprlib.Representable): method __init__ (line 361) | def __init__(self, initializing: bool = False, is_setup: bool = False): method trace_state (line 367) | def trace_state(self) -> tracers.TraceState: method initializing (line 371) | def initializing(self) -> bool: method is_setup (line 375) | def is_setup(self) -> bool: method __nnx_repr__ (line 378) | def __nnx_repr__(self): method __treescope_repr__ (line 382) | def __treescope_repr__(self, path, subtree_renderer): function _flatten_pytree_state (line 391) | def _flatten_pytree_state(state: PytreeState): function _unflatten_pytree_state (line 395) | def _unflatten_pytree_state(static: tuple[bool, bool], _): function check_pytree (line 407) | def check_pytree(pytree): class PytreeMeta (line 416) | class PytreeMeta(ABCMeta): method __call__ (line 419) | def __call__(cls, *args: Any, **kwargs: Any) -> Any: method _pytree_meta_construct (line 422) | def _pytree_meta_construct(cls, self, *args, **kwargs): function _graph_node_meta_call (line 427) | def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P: class ArrayRepr (line 447) | class ArrayRepr(reprlib.Representable): method from_array (line 452) | def from_array(array: jax.Array | np.ndarray) -> ArrayRepr: method __nnx_repr__ (line 455) | def __nnx_repr__(self): class VariableRepr (line 461) | class VariableRepr(reprlib.Representable): method __nnx_repr__ (line 466) | def __nnx_repr__(self): class MutableArrayRepr (line 472) | class MutableArrayRepr(reprlib.Representable): method from_array (line 477) | def from_array(array: jax.Array | np.ndarray) -> MutableArrayRepr: method __nnx_repr__ (line 480) | def __nnx_repr__(self): function _to_shape_dtype (line 485) | def _to_shape_dtype(x): class AttributeStatus (line 500) | class AttributeStatus(tp.NamedTuple): class Pytree (line 505) | class Pytree(reprlib.Representable, metaclass=PytreeMeta): method __init_subclass__ (line 513) | def __init_subclass__( method _object__nodes (line 625) | def _object__nodes(self): method _object__state (line 634) | def _object__state(self): method __setattr__ (line 644) | def __setattr__(self, name: str, value: Any) -> None: method _setattr (line 647) | def _setattr(self, name, value: tp.Any) -> None: method _check_value (line 671) | def _check_value(self, key, value, new_status: AttributeStatus | None): method _check_valid_context (line 774) | def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None: method __deepcopy__ (line 778) | def __deepcopy__(self: P, memo=None) -> P: method __nnx_repr__ (line 784) | def __nnx_repr__(self): method __treescope_repr__ (line 834) | def __treescope_repr__(self, path, subtree_renderer): method __getstate__ (line 889) | def __getstate__(self): method __setstate__ (line 892) | def __setstate__(self, state): method _pytree__flatten_with_paths (line 900) | def _pytree__flatten_with_paths(self): method _pytree__flatten (line 933) | def _pytree__flatten(self): method _pytree__unflatten (line 962) | def _pytree__unflatten( method _graph_node_flatten (line 978) | def _graph_node_flatten(self): method _graph_node_set_key (line 999) | def _graph_node_set_key(self, key, value: tp.Any): method _graph_node_pop_key (line 1013) | def _graph_node_pop_key(self, key): method __delattr__ (line 1020) | def __delattr__(self, name: str) -> None: method _graph_node_create_empty (line 1030) | def _graph_node_create_empty(node_type: tp.Type[P]) -> P: method _graph_node_clear (line 1034) | def _graph_node_clear(self): method _graph_node_init (line 1037) | def _graph_node_init(self, attributes: tp.Iterable[tuple[str | int, tp... method __call__ (line 1042) | def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ... class Object (line 1045) | class Object(Pytree, pytree=False): method __init_subclass__ (line 1048) | def __init_subclass__(cls, **kwargs): function _maybe_int (line 1057) | def _maybe_int(x): function _get_str (line 1063) | def _get_str(x): FILE: flax/nnx/reprlib.py function supports_color (line 26) | def supports_color() -> bool: class Color (line 42) | class Color(tp.NamedTuple): class ReprContext (line 92) | class ReprContext(threading.local): function colorized (line 100) | def colorized(x, /): class Object (line 133) | class Object: method elem_sep (line 144) | def elem_sep(self): class Attr (line 149) | class Attr: class Representable (line 158) | class Representable: method __nnx_repr__ (line 161) | def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]: method __repr__ (line 164) | def __repr__(self) -> str: method __str__ (line 172) | def __str__(self) -> str: function get_repr (line 176) | def get_repr(obj: Representable) -> str: class MappingReprMixin (line 234) | class MappingReprMixin(Representable): method __nnx_repr__ (line 235) | def __nnx_repr__(self): class PrettyMapping (line 242) | class PrettyMapping(Representable): method __nnx_repr__ (line 245) | def __nnx_repr__(self): class SequenceReprMixin (line 253) | class SequenceReprMixin(Representable): method __nnx_repr__ (line 254) | def __nnx_repr__(self): class PrettySequence (line 262) | class PrettySequence(Representable): method __nnx_repr__ (line 265) | def __nnx_repr__(self): FILE: flax/nnx/rnglib.py class KeylessInitializer (line 45) | class KeylessInitializer(tp.Protocol): method __call__ (line 46) | def __call__( function _to_keyless (line 55) | def _to_keyless( function _function_to_method (line 61) | def _function_to_method(random_f): function _initializer_to_method (line 69) | def _initializer_to_method( class RngState (line 85) | class RngState(Variable[jax.Array]): class RngCount (line 89) | class RngCount(RngState): ... class RngKey (line 92) | class RngKey(RngState): ... class RngStream (line 98) | class RngStream(Pytree): method __init__ (line 100) | def __init__( method __call__ (line 120) | def __call__(self) -> jax.Array: method split (line 126) | def split(self, k: int | tuple[int, ...]): method fork (line 130) | def fork(self, *, split: int | tuple[int, ...] | None = None): class Rngs (line 323) | class Rngs(Pytree): method __init__ (line 372) | def __init__( method _get_stream (line 401) | def _get_stream(self, name: str, error_type: type[Exception]) -> RngSt... method __getitem__ (line 412) | def __getitem__(self, name: str): method __getattr__ (line 415) | def __getattr__(self, name: str): method __call__ (line 418) | def __call__(self): method __iter__ (line 421) | def __iter__(self) -> tp.Iterator[str]: method __len__ (line 426) | def __len__(self) -> int: method __contains__ (line 431) | def __contains__(self, name: tp.Any) -> bool: method items (line 434) | def items(self): method split (line 439) | def split(self, k: tp.Mapping[filterlib.Filter, int | tuple[int, ...]]... method fork (line 487) | def fork( class SplitBackups (line 716) | class SplitBackups(struct.PyTreeNode, tp.Iterable[StreamBackup]): method __iter__ (line 719) | def __iter__(self) -> tp.Iterator[StreamBackup]: method __enter__ (line 722) | def __enter__(self): method __exit__ (line 725) | def __exit__(self, *args): function split_rngs (line 730) | def split_rngs( function split_rngs (line 740) | def split_rngs( function split_rngs (line 750) | def split_rngs( function split_rngs (line 757) | def split_rngs( function _graph_split_rngs (line 899) | def _graph_split_rngs( function _tree_split_rngs (line 933) | def _tree_split_rngs( function fork_rngs (line 968) | def fork_rngs( function fork_rngs (line 978) | def fork_rngs( function fork_rngs (line 985) | def fork_rngs( function backup_keys (line 1073) | def backup_keys(node: tp.Any, /, *, graph: bool | None = None): function _scalars_only (line 1080) | def _scalars_only( function _match_shape (line 1093) | def _match_shape( function reseed (line 1101) | def reseed( function restore_rngs (line 1172) | def restore_rngs(backups: tp.Iterable[StreamBackup], /): FILE: flax/nnx/spmd.py function add_axis (line 35) | def add_axis(tree: A, index: int, transform_metadata: tp.Mapping) -> A: function remove_axis (line 65) | def remove_axis( function _get_partition_name_and_metadata (line 101) | def _get_partition_name_and_metadata( function with_partitioning (line 118) | def with_partitioning( function get_var_pspec (line 133) | def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None: function get_partition_spec (line 149) | def get_partition_spec(tree: A) -> A: function get_named_sharding (line 164) | def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A: function get_abstract_model (line 174) | def get_abstract_model(init_fn, mesh, *, graph: bool | None = None): function abstract_with_sharding (line 185) | def abstract_with_sharding( FILE: flax/nnx/statelib.py class NestedStateRepr (line 38) | class NestedStateRepr(reprlib.Representable): method __init__ (line 39) | def __init__(self, state: State): method __nnx_repr__ (line 42) | def __nnx_repr__(self): method __treescope_repr__ (line 50) | def __treescope_repr__(self, path, subtree_renderer): class FlatState (line 59) | class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.Representable): method __init__ (line 65) | def __init__(self, items: tp.Iterable[tuple[PathParts, V]], /, *, sort... method from_sorted_keys_values (line 76) | def from_sorted_keys_values( method paths (line 85) | def paths(self) -> tp.Tuple[PathParts, ...]: method leaves (line 89) | def leaves(self) -> list[V]: method __nnx_repr__ (line 92) | def __nnx_repr__(self): method __getitem__ (line 99) | def __getitem__(self, index: int) -> tuple[PathParts, V]: ... method __getitem__ (line 101) | def __getitem__(self, index: slice) -> FlatState[V]: ... method __getitem__ (line 102) | def __getitem__( method __len__ (line 109) | def __len__(self) -> int: method __iter__ (line 112) | def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]: method to_nested_state (line 115) | def to_nested_state(self) -> State[Key, V]: method split (line 119) | def split(self, first: filterlib.Filter, /) -> FlatState[V]: ... method split (line 122) | def split( method split (line 131) | def split( method split (line 135) | def split( # type: ignore[misc] method filter (line 155) | def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ... method filter (line 158) | def filter( method filter (line 166) | def filter( method merge (line 185) | def merge( function _flat_state_pytree_flatten (line 201) | def _flat_state_pytree_flatten(x: FlatState[V]): function _flat_state_pytree_unflatten (line 205) | def _flat_state_pytree_unflatten( class State (line 221) | class State(MutableMapping[K, V], reprlib.Representable): method __init__ (line 225) | def __init__( method raw_mapping (line 251) | def raw_mapping(self) -> dict[K, tp.Mapping[K, tp.Any] | V]: method __contains__ (line 254) | def __contains__(self, key) -> bool: method __getitem__ (line 257) | def __getitem__(self, key: K) -> State | V: # type: ignore method __getattr__ (line 263) | def __getattr__(self, key: K) -> State | V: # type: ignore[misc] method __setitem__ (line 268) | def __setitem__(self, key: K, value: State | V) -> None: method __delitem__ (line 278) | def __delitem__(self, key: K) -> None: method __iter__ (line 281) | def __iter__(self) -> tp.Iterator[K]: method __len__ (line 284) | def __len__(self) -> int: method __nnx_repr__ (line 287) | def __nnx_repr__(self): method __treescope_repr__ (line 295) | def __treescope_repr__(self, path, subtree_renderer): method map (line 308) | def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]: method flat_state (line 316) | def flat_state(self) -> FlatState[V]: method from_flat_path (line 325) | def from_flat_path( method to_pure_dict (line 337) | def to_pure_dict(self, method replace_by_pure_dict (line 347) | def replace_by_pure_dict(self, method split (line 359) | def split(self, first: filterlib.Filter, /) -> State[K, V]: ... method split (line 362) | def split( method split (line 371) | def split( method split (line 375) | def split( # type: ignore[misc] method filter (line 386) | def filter( method filter (line 393) | def filter( method filter (line 401) | def filter( method merge (line 415) | def merge(cls, state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V]): method __or__ (line 423) | def __or__(self, other: State[K, V]) -> State[K, V]: method __sub__ (line 428) | def __sub__(self, other: State[K, V]) -> State[K, V]: method __init_subclass__ (line 436) | def __init_subclass__(cls) -> None: function _state_flatten_with_keys (line 447) | def _state_flatten_with_keys(x: State): function _state_unflatten (line 453) | def _state_unflatten( function map_state (line 467) | def map_state(f: tp.Callable[[tuple, tp.Any], tp.Any], state: State) -> ... function to_flat_state (line 483) | def to_flat_state(state: State) -> FlatState: function from_flat_state (line 494) | def from_flat_state( function to_pure_dict (line 511) | def to_pure_dict( function restore_int_paths (line 529) | def restore_int_paths(pure_dict: dict[str, tp.Any]): function replace_by_pure_dict (line 571) | def replace_by_pure_dict( function split_state (line 606) | def split_state(state: State, first: filterlib.Filter, /) -> State: ... function split_state (line 610) | def split_state( function split_state (line 620) | def split_state( function split_state (line 625) | def split_state( # type: ignore[misc] function filter_state (line 674) | def filter_state( function filter_state (line 682) | def filter_state( function filter_state (line 691) | def filter_state( function merge_state (line 739) | def merge_state(state: tp.Mapping, /, *states: tp.Mapping, function diff (line 788) | def diff(state: State, other: State) -> State: function _split_state (line 799) | def _split_state( function create_path_filters (line 831) | def create_path_filters(state: State): FILE: flax/nnx/summary.py class NoneDumper (line 48) | class NoneDumper(yaml.SafeDumper): class SizeBytes (line 56) | class SizeBytes(typing.SizeBytes): method __repr__ (line 57) | def __repr__(self) -> str: class ObjectInfo (line 61) | class ObjectInfo(tp.NamedTuple): function _collect_stats (line 70) | def _collect_stats( class ArrayRepr (line 121) | class ArrayRepr: method from_array (line 126) | def from_array(cls, x: jax.Array | np.ndarray): method __str__ (line 129) | def __str__(self): class CallInfo (line 135) | class CallInfo: class SimpleObjectRepr (line 145) | class SimpleObjectRepr: method __init__ (line 146) | def __init__(self, obj: tp.Any): method __str__ (line 149) | def __str__(self): method __repr__ (line 152) | def __repr__(self): function _to_dummy_array (line 156) | def _to_dummy_array(x): function _pure_nnx_vjp (line 166) | def _pure_nnx_vjp(f, model, *args, **kwargs): function filter_rng_streams (line 174) | def filter_rng_streams(row: CallInfo): function _create_obj_env (line 177) | def _create_obj_env(object_types): function _get_inputs_repr (line 186) | def _get_inputs_repr(args, kwargs): function _save_call_info (line 202) | def _save_call_info(counter, tracer_args, f, node_stats, compute_flops, ... function _overwrite_methods (line 246) | def _overwrite_methods(env): function _get_flops (line 251) | def _get_flops(e) -> int: function tabulate (line 255) | def tabulate( function _get_rich_repr (line 503) | def _get_rich_repr(obj, console_kwargs): function _size_and_bytes (line 510) | def _size_and_bytes(pytree: tp.Any) -> tuple[int, int]: function _size_and_bytes_repr (line 519) | def _size_and_bytes_repr(size: int, num_bytes: int) -> str: function _bytes_repr (line 526) | def _bytes_repr(num_bytes): function _has_shape_dtype (line 540) | def _has_shape_dtype(value): function _normalize_values (line 544) | def _normalize_values(x): function _maybe_pytree_to_dict (line 552) | def _maybe_pytree_to_dict(pytree: tp.Any): function _unflatten_to_simple_structure (line 566) | def _unflatten_to_simple_structure( function _as_yaml_str (line 645) | def _as_yaml_str(value) -> str: function _render_array (line 665) | def _render_array(x): function _sort_variable_types (line 671) | def _sort_variable_types(types: tp.Iterable[type]) -> list[type]: FILE: flax/nnx/tracers.py function current_jax_trace (line 25) | def current_jax_trace(): class TraceState (line 32) | class TraceState(reprlib.Representable): method __init__ (line 35) | def __init__(self): method jax_trace (line 39) | def jax_trace(self): method is_valid (line 42) | def is_valid(self) -> bool: method __nnx_repr__ (line 45) | def __nnx_repr__(self): method __treescope_repr__ (line 49) | def __treescope_repr__(self, path, subtree_renderer): method __eq__ (line 57) | def __eq__(self, other): method __getstate__ (line 64) | def __getstate__(self): method __setstate__ (line 67) | def __setstate__(self, state): function _flatten_trace_state (line 70) | def _flatten_trace_state(trace_state: TraceState): function _unflatten_trace_state (line 74) | def _unflatten_trace_state(_1, _2): FILE: flax/nnx/training/metrics.py class MetricState (line 29) | class MetricState(Variable): class Metric (line 35) | class Metric(Pytree): method __init__ (line 39) | def __init__(self): method reset (line 42) | def reset(self) -> None: method update (line 46) | def update(self, **kwargs) -> None: method compute (line 50) | def compute(self): method split (line 54) | def split(self, *filters: filterlib.Filter): class Average (line 58) | class Average(Metric): method __init__ (line 83) | def __init__(self, argname: str = 'values'): method reset (line 97) | def reset(self) -> None: method update (line 102) | def update(self, mask: jax.Array | None = None, **kwargs) -> None: method compute (line 129) | def compute(self) -> jax.Array: class Statistics (line 135) | class Statistics: class Welford (line 141) | class Welford(Metric): method __init__ (line 166) | def __init__(self, argname: str = 'values'): method reset (line 181) | def reset(self) -> None: method update (line 187) | def update(self, **kwargs) -> None: method compute (line 209) | def compute(self) -> Statistics: class Accuracy (line 223) | class Accuracy(Average): method __init__ (line 260) | def __init__(self, threshold: float | None = None, *args, **kwargs): method update (line 277) | def update( # type: ignore[override] class MultiMetric (line 317) | class MultiMetric(Metric): method __init__ (line 396) | def __init__(self, **metrics): method reset (line 410) | def reset(self) -> None: method update (line 415) | def update(self, **updates) -> None: method compute (line 436) | def compute(self) -> dict[str, tp.Any]: FILE: flax/nnx/training/optimizer.py class OptState (line 31) | class OptState(Variable): class OptArray (line 37) | class OptArray(OptState): class OptVariable (line 43) | class OptVariable(OptState): function to_opt_state (line 49) | def to_opt_state(tree): class _Missing (line 64) | class _Missing: function _check_grads_arg_passed (line 69) | def _check_grads_arg_passed(f: F) -> F: function _check_wrt_arg_passed (line 80) | def _check_wrt_arg_passed(f: F) -> F: class Optimizer (line 91) | class Optimizer(Pytree, tp.Generic[M]): method __init__ (line 129) | def __init__( method __getattribute__ (line 159) | def __getattribute__(self, name: str) -> tp.Any: method update (line 168) | def update(self, model: M, grads, /, **kwargs): class ModelAndOptimizer (line 217) | class ModelAndOptimizer(Optimizer[M]): method __init__ (line 224) | def __init__(self, model: M, tx: optax.GradientTransformation, *, wrt:... method update (line 228) | def update(self, grads, /, **kwargs): # type: ignore FILE: flax/nnx/transforms/autodiff.py class DiffState (line 60) | class DiffState: class SimpleGradFn (line 66) | class SimpleGradFn: method __post_init__ (line 71) | def __post_init__(self): method __call__ (line 75) | def __call__(self, *args, **kwargs): class GradFn (line 93) | class GradFn: method __post_init__ (line 98) | def __post_init__(self): method __call__ (line 101) | def __call__(self, *pure_args): function _grad_general (line 132) | def _grad_general( function grad (line 290) | def grad( function grad (line 302) | def grad( function grad (line 312) | def grad( function value_and_grad (line 441) | def value_and_grad( function value_and_grad (line 453) | def value_and_grad( function value_and_grad (line 463) | def value_and_grad( class SimpleVjpFn (line 559) | class SimpleVjpFn: method __post_init__ (line 564) | def __post_init__(self): method __call__ (line 568) | def __call__(self, *args): function vjp (line 585) | def vjp( function vjp (line 594) | def vjp( function vjp (line 601) | def vjp( class SimpleJvpFn (line 725) | class SimpleJvpFn: method __post_init__ (line 730) | def __post_init__(self): method __call__ (line 734) | def __call__(self, *args): function jvp (line 751) | def jvp( function jvp (line 761) | def jvp( function jvp (line 768) | def jvp( function jvp (line 775) | def jvp( class SimpleCustomVjpFn (line 902) | class SimpleCustomVjpFn: method __post_init__ (line 907) | def __post_init__(self): method __call__ (line 911) | def __call__(self, *args): class SimpleFwdFn (line 942) | class SimpleFwdFn: method __post_init__ (line 946) | def __post_init__(self): method __call__ (line 950) | def __call__(self, *args): class SimpleBwdFn (line 964) | class SimpleBwdFn: method __post_init__ (line 968) | def __post_init__(self): method __call__ (line 972) | def __call__(self, *args): class SimpleCustomVjp (line 983) | class SimpleCustomVjp(tp.Generic[A]): method __init__ (line 984) | def __init__( method __call__ (line 999) | def __call__( method defvjp (line 1016) | def defvjp( function _custom_vjp_merge_fn (line 1046) | def _custom_vjp_merge_fn( function _custom_vjp_split_fn (line 1058) | def _custom_vjp_split_fn( function _extract_nodedefs (line 1094) | def _extract_nodedefs(x, *, nodedefs: deque[graphlib.GraphDef]): class CustomVjpFnWrapper (line 1101) | class CustomVjpFnWrapper: method __post_init__ (line 1108) | def __post_init__(self): method __call__ (line 1111) | def __call__(self, *pure_args): class FwdFn (line 1144) | class FwdFn: method __post_init__ (line 1151) | def __post_init__(self): method __call__ (line 1154) | def __call__(self, *pure_args): class BwdFn (line 1197) | class BwdFn: method __post_init__ (line 1201) | def __post_init__(self): method __call__ (line 1204) | def __call__(self, *args): class CustomVjp (line 1233) | class CustomVjp(tp.Generic[A]): method __init__ (line 1234) | def __init__( method __call__ (line 1266) | def __call__( method defvjp (line 1343) | def defvjp( function custom_vjp (line 1355) | def custom_vjp( function custom_vjp (line 1363) | def custom_vjp( function custom_vjp (line 1369) | def custom_vjp( class SimpleRematFn (line 1558) | class SimpleRematFn: method __post_init__ (line 1562) | def __post_init__(self): method __call__ (line 1566) | def __call__(self, *args, **kwargs): function remat (line 1578) | def remat( function remat (line 1587) | def remat( function remat (line 1596) | def remat( FILE: flax/nnx/transforms/compilation.py class StateSharding (line 50) | class StateSharding(extract.PrefixMapping): method __init__ (line 51) | def __init__( method filters (line 70) | def filters(self) -> tuple[filterlib.Filter, ...]: method shardings (line 74) | def shardings(self) -> tuple[tp.Any, ...]: method map_prefix (line 77) | def map_prefix( method __repr__ (line 86) | def __repr__(self): method __eq__ (line 89) | def __eq__(self, other): method __hash__ (line 96) | def __hash__(self): function _jit_split_fn (line 100) | def _jit_split_fn(ctx: graphlib.SplitContext, path, prefix, x): function _jit_merge_fn (line 107) | def _jit_merge_fn(ctx: graphlib.MergeContext, path, prefix, leaf) -> tp.... class JitFn (line 114) | class JitFn: method __post_init__ (line 121) | def __post_init__(self): method __call__ (line 127) | def __call__(self, *pure_args, **pure_kwargs): function jit (line 149) | def jit( function jit (line 165) | def jit( function jit (line 181) | def jit( class PartialState (line 407) | class PartialState: function _flatten_to_partial_state (line 425) | def _flatten_to_partial_state( class SimpleJitFn (line 440) | class SimpleJitFn: method __post_init__ (line 447) | def __post_init__(self): method __call__ (line 451) | def __call__(self, *args, **kwargs): class SimpleJitWrapped (line 474) | class SimpleJitWrapped(tp.Generic[P, R]): method __init__ (line 476) | def __init__( method _maybe_to_tree (line 547) | def _maybe_to_tree(self, args, kwargs): method _maybe_from_tree (line 558) | def _maybe_from_tree(self, out): method __call__ (line 563) | def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: method __get__ (line 572) | def __get__(self, obj, objtype=None): method eval_shape (line 577) | def eval_shape(self, *args, **kwargs): method trace (line 585) | def trace(self, *args, **kwargs): method lower (line 593) | def lower(self, *args, **kwargs): function jit_partial (line 600) | def jit_partial( class JitWrapped (line 751) | class JitWrapped(tp.Generic[P, R]): method __init__ (line 760) | def __init__( method __get__ (line 824) | def __get__(self, obj, objtype=None): method _get_pure_args_kwargs (line 829) | def _get_pure_args_kwargs(self, args, kwargs): method _get_non_pure_out (line 842) | def _get_non_pure_out(self, pure_args_out, pure_kwargs_out, pure_out, /): method __call__ (line 851) | def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: method eval_shape (line 861) | def eval_shape(self, *args, **kwargs): method trace (line 872) | def trace(self, *args, **kwargs) -> Traced: method lower (line 886) | def lower(self, *args, **kwargs) -> Lowered: class Stage (line 904) | class Stage: method _inner_obj (line 908) | def _inner_obj(self) -> tp.Any: method in_tree (line 912) | def in_tree(self) -> jax.tree_util.PyTreeDef: method in_avals (line 916) | def in_avals(self): method donate_argnums (line 920) | def donate_argnums(self): class Compiled (line 924) | class Compiled(Stage): method _inner_obj (line 937) | def _inner_obj(self): method args_info (line 941) | def args_info(self) -> tp.Any: # PyTree of ArgInfo method call (line 945) | def call(*args, **kwargs): method __call__ (line 948) | def __call__(self, *args, **kwargs): method out_tree (line 962) | def out_tree(self) -> jax.tree_util.PyTreeDef: method as_text (line 965) | def as_text(self) -> str | None: method cost_analysis (line 976) | def cost_analysis(self) -> tp.Any | None: method memory_analysis (line 990) | def memory_analysis(self) -> tp.Any | None: method runtime_executable (line 1004) | def runtime_executable(self) -> tp.Any | None: method input_shardings (line 1017) | def input_shardings(self): # PyTree[sharding.Sharding] method output_shardings (line 1021) | def output_shardings(self): # PyTree[sharding.Sharding] method input_layouts (line 1025) | def input_layouts(self): class Lowered (line 1030) | class Lowered(Stage): method _inner_obj (line 1044) | def _inner_obj(self): method args_info (line 1048) | def args_info(self) -> tp.Any: # PyTree of ArgInfo method out_tree (line 1052) | def out_tree(self): method from_flat_info (line 1056) | def from_flat_info( method compile (line 1067) | def compile( method as_text (line 1074) | def as_text( method compiler_ir (line 1091) | def compiler_ir(self, dialect: str | None = None) -> tp.Any | None: method cost_analysis (line 1108) | def cost_analysis(self) -> tp.Any | None: class Traced (line 1123) | class Traced(Stage): method _inner_obj (line 1135) | def _inner_obj(self): method out_info (line 1139) | def out_info(self): method lower (line 1142) | def lower( class SimpleCompiled (line 1151) | class SimpleCompiled(Stage): method _inner_obj (line 1156) | def _inner_obj(self): method args_info (line 1160) | def args_info(self) -> tp.Any: method call (line 1164) | def call(*args, **kwargs): method __call__ (line 1167) | def __call__(self, *args, **kwargs): method out_tree (line 1177) | def out_tree(self) -> jax.tree_util.PyTreeDef: method as_text (line 1180) | def as_text(self) -> str | None: method cost_analysis (line 1183) | def cost_analysis(self) -> tp.Any | None: method memory_analysis (line 1186) | def memory_analysis(self) -> tp.Any | None: method runtime_executable (line 1189) | def runtime_executable(self) -> tp.Any | None: method input_shardings (line 1193) | def input_shardings(self): method output_shardings (line 1197) | def output_shardings(self): method input_layouts (line 1201) | def input_layouts(self): class SimpleLowered (line 1206) | class SimpleLowered(Stage): method _inner_obj (line 1211) | def _inner_obj(self): method args_info (line 1215) | def args_info(self) -> tp.Any: method out_tree (line 1219) | def out_tree(self): method compile (line 1222) | def compile( method as_text (line 1228) | def as_text( method compiler_ir (line 1233) | def compiler_ir(self, dialect: str | None = None) -> tp.Any | None: method cost_analysis (line 1236) | def cost_analysis(self) -> tp.Any | None: class SimpleTraced (line 1241) | class SimpleTraced(Stage): method _inner_obj (line 1246) | def _inner_obj(self): method out_info (line 1250) | def out_info(self): method lower (line 1253) | def lower( class SimpleShardMapFn (line 1268) | class SimpleShardMapFn: method __post_init__ (line 1273) | def __post_init__(self): method __call__ (line 1277) | def __call__(self, *args): class ShardMapFn (line 1290) | class ShardMapFn: method __post_init__ (line 1297) | def __post_init__(self): method __call__ (line 1300) | def __call__(self, *pure_args, **pure_kwargs): function shard_map (line 1322) | def shard_map( function shard_map (line 1337) | def shard_map( function shard_map (line 1350) | def shard_map( function _fun_signature (line 1642) | def _fun_signature(fun: tp.Callable) -> inspect.Signature | None: function _resolve_argnums (line 1649) | def _resolve_argnums( FILE: flax/nnx/transforms/general.py function split_inputs (line 33) | def split_inputs( function split_inputs (line 38) | def split_inputs( function split_inputs (line 43) | def split_inputs( function merge_inputs (line 162) | def merge_inputs( function merge_inputs (line 167) | def merge_inputs( function merge_inputs (line 172) | def merge_inputs( FILE: flax/nnx/transforms/iteration.py class Carry (line 56) | class Carry: function _apply_axis_fn (line 65) | def _apply_axis_fn( function transform_metadata (line 82) | def transform_metadata( function transform_metadata (line 93) | def transform_metadata( function transform_metadata (line 104) | def transform_metadata( class StateAxes (line 162) | class StateAxes(extract.PrefixMapping, tp.Mapping): method __init__ (line 164) | def __init__( method filters (line 185) | def filters(self) -> tuple[filterlib.Filter, ...]: method axes (line 189) | def axes(self) -> tuple[Index | type[Carry] | None, ...]: method map_prefix (line 192) | def map_prefix( method __repr__ (line 201) | def __repr__(self): method items (line 204) | def items(self): method __getitem__ (line 207) | def __getitem__(self, key): method __iter__ (line 210) | def __iter__(self): method __len__ (line 213) | def __len__(self): method __eq__ (line 216) | def __eq__(self, other): method __hash__ (line 223) | def __hash__(self): function _update_variable_sharding_metadata (line 233) | def _update_variable_sharding_metadata( function _vmap_split_fn (line 260) | def _vmap_split_fn(ctx: graphlib.SplitContext, path, prefix, x): class SimpleVmapFn (line 269) | class SimpleVmapFn: method __post_init__ (line 274) | def __post_init__(self): method __call__ (line 278) | def __call__(self, *args, **kwargs): class SimplePmapFn (line 291) | class SimplePmapFn: method __post_init__ (line 296) | def __post_init__(self): method __call__ (line 300) | def __call__(self, *args, **kwargs): class VmapFn (line 313) | class VmapFn: method __post_init__ (line 319) | def __post_init__(self): method __call__ (line 322) | def __call__(self, *pure_args: tuple[tp.Any, ...]): function vmap (line 346) | def vmap( function vmap (line 362) | def vmap( function vmap (line 378) | def vmap( class PmapFn (line 590) | class PmapFn: method __post_init__ (line 596) | def __post_init__(self): method __call__ (line 599) | def __call__(self, *pure_args: tuple[tp.Any, ...]): function pmap (line 623) | def pmap( function pmap (line 642) | def pmap( function pmap (line 661) | def pmap( class Broadcasted (line 856) | class Broadcasted(struct.PyTreeNode): function _get_carry_argnum (line 859) | def _get_carry_argnum(axes, is_in_axes: bool): function _check_out_axes (line 889) | def _check_out_axes(out_axes): function _check_carry_same_references (line 912) | def _check_carry_same_references(carry_arg, carry_arg_out): function _scan_split_in (line 932) | def _scan_split_in( function _scan_split_out (line 1019) | def _scan_split_out( function _scan_merge_in (line 1112) | def _scan_merge_in( function _scan_merge_out (line 1135) | def _scan_merge_out( class ScanFn (line 1221) | class ScanFn: method __post_init__ (line 1229) | def __post_init__(self): method __call__ (line 1232) | def __call__( class SimpleScanFn (line 1361) | class SimpleScanFn: method __post_init__ (line 1370) | def __post_init__(self): method __call__ (line 1374) | def __call__(self, *args): function scan (line 1414) | def scan( function scan (line 1432) | def scan( function scan (line 1450) | def scan( function _simple_scan (line 1599) | def _simple_scan( function _graph_updates_scan (line 1700) | def _graph_updates_scan( function pure_jax_fancy_scan (line 1827) | def pure_jax_fancy_scan( class SimpleWhileLoopBodyFn (line 2001) | class SimpleWhileLoopBodyFn: method __post_init__ (line 2005) | def __post_init__(self): method __call__ (line 2009) | def __call__(self, val): class SimpleWhileLoopCondFn (line 2021) | class SimpleWhileLoopCondFn: method __post_init__ (line 2025) | def __post_init__(self): method __call__ (line 2028) | def __call__(self, val): class WhileLoopCondFn (line 2035) | class WhileLoopCondFn: method __post_init__ (line 2038) | def __post_init__(self): method __call__ (line 2041) | def __call__(self, pure_val): function _reconsile_index_mapping (line 2047) | def _reconsile_index_mapping(tree_to_fix, example_tree): function _add_fake_index_mapping (line 2060) | def _add_fake_index_mapping(tree: tp.Any): function _remove_index_mapping (line 2075) | def _remove_index_mapping(tree: tp.Any): class WhileLoopBodyFn (line 2094) | class WhileLoopBodyFn: method __post_init__ (line 2097) | def __post_init__(self): method __call__ (line 2101) | def __call__(self, pure_val): function while_loop (line 2128) | def while_loop(cond_fun: tp.Callable[[T], tp.Any], class SimpleForiLoopBodyFn (line 2198) | class SimpleForiLoopBodyFn: method __post_init__ (line 2202) | def __post_init__(self): method __call__ (line 2206) | def __call__(self, i, val): class ForiLoopBodyFn (line 2218) | class ForiLoopBodyFn: method __post_init__ (line 2221) | def __post_init__(self): method __call__ (line 2225) | def __call__(self, i, pure_val_in): function fori_loop (line 2233) | def fori_loop(lower: int, upper: int, FILE: flax/nnx/transforms/transforms.py function resolve_kwargs (line 56) | def resolve_kwargs( function resolve_kwargs (line 62) | def resolve_kwargs() -> tp.Callable[[F], F]: ... function resolve_kwargs (line 63) | def resolve_kwargs( function _resolve_bound_callable (line 101) | def _resolve_bound_callable( function _raise_bound_method_error (line 144) | def _raise_bound_method_error(transform_name: str): class LiftedModule (line 157) | class LiftedModule(tp.Generic[M], Module): # type: ignore[ignored-abstr... method _call (line 159) | def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any: method _submodule (line 164) | def _submodule(self) -> M: method __call__ (line 167) | def __call__(self, *args, **kwargs) -> tp.Any: method call (line 171) | def call(self) -> tp.Any: class ValueMetadata (line 190) | class ValueMetadata: function _flatten_value_metadata (line 196) | def _flatten_value_metadata( function _unflatten_value_metadata (line 203) | def _unflatten_value_metadata(aux_data, children): function _to_value_metadata (line 216) | def _to_value_metadata(node): function _to_variable (line 233) | def _to_variable(node): class SimpleEvalShapeFn (line 258) | class SimpleEvalShapeFn: method __post_init__ (line 262) | def __post_init__(self): method __call__ (line 266) | def __call__(self, *args, **kwargs): function eval_shape (line 276) | def eval_shape( class CheckifyFn (line 337) | class CheckifyFn: method __post_init__ (line 340) | def __post_init__(self): method __call__ (line 343) | def __call__(self, *pure_args, **pure_kwargs): class SimpleCheckifyFn (line 357) | class SimpleCheckifyFn: method __post_init__ (line 361) | def __post_init__(self): method __call__ (line 365) | def __call__(self, *args): function checkify (line 376) | def checkify( class SimpleCondFn (line 470) | class SimpleCondFn: method __post_init__ (line 474) | def __post_init__(self): method __call__ (line 478) | def __call__(self, *args): function cond (line 489) | def cond( function switch (line 547) | def switch( FILE: flax/nnx/traversals.py class _EmptyNode (line 28) | class _EmptyNode: function flatten_mapping (line 40) | def flatten_mapping(xs: Mapping[Any, Any], function flatten_mapping (line 50) | def flatten_mapping(xs: Mapping[Any, Any], function flatten_mapping (line 59) | def flatten_mapping(xs: Mapping[Any, Any], function flatten_to_sequence (line 122) | def flatten_to_sequence( function unflatten_mapping (line 170) | def unflatten_mapping( function unflatten_mapping (line 177) | def unflatten_mapping( function unflatten_mapping (line 184) | def unflatten_mapping(xs: Mapping[str, Any], /, *, sep: str) -> dict[Any... function unflatten_mapping (line 188) | def unflatten_mapping(xs: Any, /, *, sep: str | None = None) -> dict[Any... FILE: flax/nnx/variablelib.py class VariableContext (line 66) | class VariableContext(threading.local): class use_eager_sharding (line 75) | class use_eager_sharding(BaseConfigContext): function using_eager_sharding (line 111) | def using_eager_sharding() -> bool: class VarDefaults (line 134) | class VarDefaults(tp.Mapping[str, tp.Any]): method __getitem__ (line 138) | def __getitem__(self, key: str) -> tp.Any: method __iter__ (line 141) | def __iter__(self) -> tp.Iterator[str]: method __len__ (line 144) | def __len__(self) -> int: function var_defaults (line 149) | def var_defaults() -> VarDefaults: ... function var_defaults (line 153) | def var_defaults( function var_defaults (line 158) | def var_defaults( class VarDefaultsContext (line 195) | class VarDefaultsContext: method __init__ (line 196) | def __init__( method __enter__ (line 209) | def __enter__(self): method __exit__ (line 215) | def __exit__(self, exc_type, exc_value, traceback): method __call__ (line 221) | def __call__(self, f: F) -> F: function is_array_ref (line 250) | def is_array_ref(x) -> tp.TypeGuard[Ref]: class VariableMetadata (line 257) | class VariableMetadata(tp.Generic[A]): class VariableQDD (line 276) | class VariableQDD: method to_tangent_qdd (line 281) | def to_tangent_qdd(self): method normalize (line 285) | def normalize(self): class VariableEffect (line 289) | class VariableEffect(jax.core.Effect): ... function _bind_new_variable (line 296) | def _bind_new_variable( function _new_hijax_from_variable (line 310) | def _new_hijax_from_variable(variable: Variable) -> HijaxVariable: class NewVariable (line 324) | class NewVariable(hjx.HiPrimitive): method is_high (line 325) | def is_high(self, *leaves, treedef, var_type, has_qdd, ref) -> bool: method impl (line 328) | def impl(self, *leaves, treedef, var_type, has_qdd, ref): method abstract_eval (line 333) | def abstract_eval(self, *leaves, treedef, var_type, has_qdd, ref): method to_lojax (line 344) | def to_lojax(self, *leaves, treedef, var_type, has_qdd, ref): method jvp (line 347) | def jvp(_, primals, tangents, *, treedef, var_type, has_qdd, ref): method transpose (line 360) | def transpose( function _set_hijax_state (line 386) | def _set_hijax_state(hijax_var, variable: Variable): class SetVariable (line 393) | class SetVariable(hjx.HiPrimitive): method is_high (line 396) | def is_high(_, *leaf_avals, treedef, var_type) -> bool: method impl (line 400) | def impl(_, hijax_var: HijaxVariable, *leaves, treedef, var_type): method abstract_eval (line 410) | def abstract_eval( method to_lojax (line 426) | def to_lojax(_, hijax_var: HijaxVariable, *leaves, treedef, var_type): method jvp (line 436) | def jvp(_, primals, tangents, *, treedef, var_type): method transpose (line 454) | def transpose(_, *args, treedef, var_type): function _get_hijax_state (line 461) | def _get_hijax_state(hijax_var: HijaxVariable | AbstractVariable) -> Var... class GetVariable (line 491) | class GetVariable(hjx.HiPrimitive): method impl (line 494) | def impl( method abstract_eval (line 499) | def abstract_eval(self, abstract_var, *, treedef, avals, var_type, has... method to_lojax (line 505) | def to_lojax( method jvp (line 510) | def jvp(_, primals, tangents, *, treedef, avals, var_type, has_qdd): method transpose (line 533) | def transpose(_, out, hijax_var, *, treedef, avals, var_type, has_qdd): function _variable_has_changed (line 559) | def _variable_has_changed(old: Variable, new: Variable) -> bool: function _as_hijax_property (line 569) | def _as_hijax_property(name: str, *, get: bool, set: bool) -> property: function _as_aval_property (line 592) | def _as_aval_property(p: property) -> hjx.aval_property: function _as_hijax_attribute (line 598) | def _as_hijax_attribute(name: str) -> property: function _as_hijax_method (line 615) | def _as_hijax_method(name: str) -> tp.Any: function _as_tracer_method (line 632) | def _as_tracer_method(name: str): function _not_an_attribute_property (line 644) | def _not_an_attribute_property(name: str): class HijaxVariableMeta (line 652) | class HijaxVariableMeta(type): method __instancecheck__ (line 653) | def __instancecheck__(self, instance): class HijaxVariable (line 663) | class HijaxVariable( method _new (line 674) | def _new( method value (line 694) | def value(self) -> A: method value (line 703) | def value(self, new_value: A): method var_type (line 712) | def var_type(self) -> type[Variable[A]]: method ref (line 726) | def ref(self) -> bool: method copy_from (line 732) | def copy_from(self, other: Variable[A] | HijaxVariable[A]) -> None: method update_from_state (line 739) | def update_from_state(self, variable_state: Variable[A] | HijaxVariabl... method from_metadata (line 759) | def from_metadata(cls, value: A, metadata: dict[str, tp.Any]): method cur_qdd (line 834) | def cur_qdd(self): method type_state (line 837) | def type_state(self): function _to_abstract_variable (line 842) | def _to_abstract_variable(hijax_var: HijaxVariable): class AbstractVariable (line 864) | class AbstractVariable(tp.Generic[A], hjx.MutableHiType): method ref (line 873) | def ref(self) -> bool: method hijax (line 877) | def hijax(self): method __init__ (line 882) | def __init__( method dtype (line 900) | def dtype(self): method ndim (line 904) | def ndim(self): method size (line 908) | def size(self): method shape (line 912) | def shape(self): method __getattr__ (line 915) | def __getattr__(self, name: str): method from_metadata (line 945) | def from_metadata(self, value, metadata: dict[str, tp.Any]): method __str__ (line 954) | def __str__(self): method __repr__ (line 957) | def __repr__(self): method __treescope_repr__ (line 961) | def __treescope_repr__(self, path, subtree_renderer): method __hash__ (line 1034) | def __hash__(self): method __eq__ (line 1043) | def __eq__(self, other): method str_short (line 1048) | def str_short(self, short_dtypes=False, **_) -> str: # type: ignore method lo_ty_qdd (line 1052) | def lo_ty_qdd(self, variable_state: VariableQDD) -> list: # type: ignore method new_from_loval (line 1055) | def new_from_loval( # type: ignore[override] method read_loval (line 1073) | def read_loval(self, variable_state: VariableQDD, variable) -> list: ... method update_from_loval (line 1082) | def update_from_loval( # type: ignore[override] method to_tangent_aval (line 1093) | def to_tangent_aval(self): function _remap_sharding_metadata (line 1107) | def _remap_sharding_metadata(metadata: dict[str, tp.Any]) -> None: function _variable_operator (line 1124) | def _variable_operator(name: str) -> tp.Callable[[Variable[A], tp.Any], A]: function _variable_unary_operator (line 1135) | def _variable_unary_operator(name: str) -> tp.Callable[[Variable[A]], A]: class VariableMeta (line 1144) | class VariableMeta(type): method __new__ (line 1145) | def __new__(cls, cls_name, bases, attrs): method __instancecheck__ (line 1150) | def __instancecheck__(self, instance): method __call__ (line 1164) | def __call__(cls, *args, **kwargs): method _variable_meta_call (line 1167) | def _variable_meta_call(cls, *args, **kwargs): class Variable (line 1174) | class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableM... method var_type (line 1243) | def var_type(self): method hijax (line 1247) | def hijax(self) -> bool: method ref (line 1251) | def ref(self) -> bool: method shape (line 1255) | def shape(self: Variable[jax.Array]) -> tuple[int, ...]: method sharding_names (line 1259) | def sharding_names(self): method __init__ (line 1267) | def __init__( method _can_update (line 1363) | def _can_update(self) -> bool: method _check_can_update (line 1370) | def _check_can_update(self): method __getattr__ (line 1376) | def __getattr__(self, name: str) -> tp.Any: method __setattr__ (line 1381) | def __setattr__(self, name: str, value: tp.Any): method __delattr__ (line 1393) | def __delattr__(self, name: str): method type (line 1406) | def type(self): method get_metadata (line 1412) | def get_metadata( method get_metadata (line 1417) | def get_metadata(self, name: str, default: tp.Any = MISSING) -> tp.Any... method get_metadata (line 1419) | def get_metadata( method set_metadata (line 1454) | def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ... method set_metadata (line 1457) | def set_metadata(self, name: str, value: tp.Any, /) -> None: ... method set_metadata (line 1460) | def set_metadata(self, **metadata: tp.Any) -> None: ... method set_metadata (line 1462) | def set_metadata(self, *args, **kwargs) -> None: method has_metadata (line 1548) | def has_metadata(self, name: str) -> bool: method del_metadata (line 1558) | def del_metadata(self, name: str) -> None: method copy_from (line 1569) | def copy_from(self, other: Variable[A]) -> None: method update_from_state (line 1581) | def update_from_state(self, variable_state: Variable[A]): method get_raw_value (line 1591) | def get_raw_value(self) -> A: method set_raw_value (line 1595) | def set_raw_value(self, value: A, *, _unsafe_bypass_check: bool = False): method raw_value (line 1601) | def raw_value(self) -> A: method raw_value (line 1611) | def raw_value(self, value: A): method value (line 1621) | def value(self) -> A: method value (line 1633) | def value(self, value: A): method create_value (line 1644) | def create_value(self, value: A): method get_value (line 1647) | def get_value(self, *, index: tp.Any = MISSING) -> A: method set_value (line 1662) | def set_value(self, value: A, *, index: tp.Any = MISSING): method add_axis (line 1695) | def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): method remove_axis (line 1699) | def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): method copy (line 1704) | def copy(self, value: B, **kwargs) -> Variable[B]: ... method copy (line 1707) | def copy(self, **kwargs) -> Variable[A]: ... method copy (line 1709) | def copy( method _new (line 1738) | def _new( method from_metadata (line 1751) | def from_metadata( method __nnx_repr__ (line 1764) | def __nnx_repr__(self): method __treescope_repr__ (line 1782) | def __treescope_repr__(self, path, subtree_renderer): method on_get_value (line 1811) | def on_get_value(self, value: A) -> A: ... method on_set_value (line 1813) | def on_set_value(self, value: A) -> A: ... method on_create_value (line 1815) | def on_create_value(self, value: A) -> A: ... method on_add_axis (line 1817) | def on_add_axis( method on_remove_axis (line 1821) | def on_remove_axis( method __jax_array__ (line 1825) | def __jax_array__(self): method __getstate__ (line 1829) | def __getstate__(self): method __setstate__ (line 1836) | def __setstate__(self, state): method __getitem__ (line 1846) | def __getitem__(self: Variable[jax.Array], key) -> jax.Array: method __getitem__ (line 1850) | def __getitem__(self: Variable[dict[tp.Any, B]], key) -> B: method __getitem__ (line 1854) | def __getitem__(self: Variable[list[B]], key: int) -> B: method __getitem__ (line 1858) | def __getitem__(self: Variable[tuple[B, ...]], key: int) -> B: method __getitem__ (line 1862) | def __getitem__(self, key) -> tp.Any: method __getitem__ (line 1865) | def __getitem__(self, key): method __setitem__ (line 1868) | def __setitem__(self, key, value) -> None: method __delitem__ (line 1871) | def __delitem__(self, key) -> None: method __call__ (line 1876) | def __call__(self, *args, **kwargs) -> tp.Any: method __len__ (line 1879) | def __len__(self) -> int: method __iter__ (line 1882) | def __iter__(self) -> tp.Iterator: method __contains__ (line 1885) | def __contains__(self, item) -> bool: method __eq__ (line 1915) | def __eq__(self, other) -> bool: method __iadd__ (line 1920) | def __iadd__(self: V, other) -> V: method __isub__ (line 1926) | def __isub__(self: V, other) -> V: method __imul__ (line 1932) | def __imul__(self: V, other) -> V: method __imatmul__ (line 1938) | def __imatmul__(self: V, other) -> V: method __itruediv__ (line 1944) | def __itruediv__(self: V, other) -> V: method __ifloordiv__ (line 1950) | def __ifloordiv__(self: V, other) -> V: method __imod__ (line 1956) | def __imod__(self: V, other) -> V: method __ipow__ (line 1962) | def __ipow__(self: V, other) -> V: method __ilshift__ (line 1968) | def __ilshift__(self: V, other) -> V: method __irshift__ (line 1974) | def __irshift__(self: V, other) -> V: method __iand__ (line 1980) | def __iand__(self: V, other) -> V: method __ixor__ (line 1986) | def __ixor__(self: V, other) -> V: method __ior__ (line 1992) | def __ior__(self: V, other) -> V: method __round__ (line 2010) | def __round__(self, ndigits: int = 0) -> A: method __init_subclass__ (line 2015) | def __init_subclass__(cls) -> None: function _variable_flatten_with_keys (line 2027) | def _variable_flatten_with_keys(x: Variable[tp.Any]): function _variable_flatten (line 2033) | def _variable_flatten(x: Variable[tp.Any]): function _variable_unflatten (line 2038) | def _variable_unflatten( class Param (line 2056) | class Param(Variable[A]): class BatchStat (line 2079) | class BatchStat(Variable[A]): class Cache (line 2110) | class Cache(Variable[A]): class Intermediate (line 2143) | class Intermediate(Variable[A]): class Perturbation (line 2175) | class Perturbation(Intermediate[A]): function with_metadata (line 2207) | def with_metadata( function variable_type_from_name (line 2286) | def variable_type_from_name( function variable_name_from_type (line 2305) | def variable_name_from_type( function register_variable_name (line 2332) | def register_variable_name( function register_variable_name (line 2341) | def register_variable_name( function register_variable_name (line 2348) | def register_variable_name( FILE: flax/nnx/visualization.py function display (line 28) | def display(*args): function render_object_constructor (line 43) | def render_object_constructor( FILE: flax/serialization.py class _ErrorContext (line 32) | class _ErrorContext(threading.local): method __init__ (line 35) | def __init__(self): function _record_path (line 43) | def _record_path(name): function current_path (line 51) | def current_path(): class _NamedTuple (line 56) | class _NamedTuple: function _is_namedtuple (line 62) | def _is_namedtuple(x): function from_state_dict (line 67) | def from_state_dict(target, state: dict[str, Any], name: str = '.'): function to_state_dict (line 96) | def to_state_dict(target) -> dict[str, Any]: function is_serializable (line 113) | def is_serializable(target): function register_serialization_state (line 119) | def register_serialization_state( function _list_state_dict (line 140) | def _list_state_dict(xs: list[Any]) -> dict[str, Any]: function _restore_list (line 144) | def _restore_list(xs, state_dict: dict[str, Any]) -> list[Any]: function _dict_state_dict (line 158) | def _dict_state_dict(xs: dict[str, Any]) -> dict[str, Any]: function _restore_dict (line 168) | def _restore_dict(xs, states: dict[str, Any]) -> dict[str, Any]: function _namedtuple_state_dict (line 183) | def _namedtuple_state_dict(nt) -> dict[str, Any]: function _restore_namedtuple (line 187) | def _restore_namedtuple(xs, state_dict: dict[str, Any]): function _ndarray_to_bytes (line 249) | def _ndarray_to_bytes(arr) -> bytes: function _dtype_from_name (line 262) | def _dtype_from_name(name: str): function _ndarray_from_bytes (line 270) | def _ndarray_from_bytes(data: bytes) -> np.ndarray: class _MsgpackExtType (line 278) | class _MsgpackExtType(enum.IntEnum): function _msgpack_ext_pack (line 286) | def _msgpack_ext_pack(x): function _msgpack_ext_unpack (line 304) | def _msgpack_ext_unpack(code, data): function _np_convert_in_place (line 327) | def _np_convert_in_place(d): function _chunk (line 344) | def _chunk(arr) -> dict[str, Any]: function _unchunk (line 356) | def _unchunk(data: dict[str, Any]): function _chunk_array_leaves_in_place (line 364) | def _chunk_array_leaves_in_place(d): function _unchunk_array_leaves_in_place (line 379) | def _unchunk_array_leaves_in_place(d): function msgpack_serialize (line 396) | def msgpack_serialize(pytree, in_place: bool = False) -> bytes: function msgpack_restore (line 418) | def msgpack_restore(encoded_pytree: bytes): function from_bytes (line 437) | def from_bytes(target, encoded_bytes: bytes): function to_bytes (line 454) | def to_bytes(target) -> bytes: FILE: flax/struct.py function field (line 32) | def field(pytree_node=True, *, metadata=None, **kwargs): function dataclass (line 39) | def dataclass(clz: _T, **kwargs) -> _T: function dataclass (line 45) | def dataclass(**kwargs) -> Callable[[_T], _T]: function dataclass (line 50) | def dataclass( class PyTreeNode (line 194) | class PyTreeNode: method __init_subclass__ (line 230) | def __init_subclass__(cls, **kwargs): method __init__ (line 234) | def __init__(self, *args, **kwargs): method replace (line 238) | def replace(self: TNode, **overrides) -> TNode: FILE: flax/testing/benchmark.py function _make_events_generator (line 54) | def _make_events_generator(path): function _is_scalar_value (line 61) | def _is_scalar_value(value): function _process_event (line 69) | def _process_event(event): function _get_tensorboard_scalars (line 84) | def _get_tensorboard_scalars(path): class Benchmark (line 106) | class Benchmark(absltest.TestCase): method __init__ (line 116) | def __init__(self, *args, **kwargs): method _collect_assert_wrapper (line 131) | def _collect_assert_wrapper(self, *args, fn=None, **kwargs): method setUp (line 138) | def setUp(self): method tearDown (line 147) | def tearDown(self): method get_tmp_model_dir (line 154) | def get_tmp_model_dir(self): method has_outstanding_fails (line 173) | def has_outstanding_fails(self): method read_summaries (line 177) | def read_summaries(self, path): method report_wall_time (line 181) | def report_wall_time(self, wall_time: float): method report_metrics (line 186) | def report_metrics(self, metrics: dict[str, float]): method report_metric (line 191) | def report_metric(self, name: str, value: float): method report_extras (line 195) | def report_extras(self, extras: dict[str, str]): method report_extra (line 200) | def report_extra(self, name: str, value: str): method _get_test_name (line 204) | def _get_test_name(self, prefix='test_'): method _update_reported_name (line 235) | def _update_reported_name(self): method _report_benchmark_results (line 239) | def _report_benchmark_results(self): FILE: flax/traceback_util.py function register_exclusion (line 35) | def register_exclusion(path): function hide_flax_in_tracebacks (line 44) | def hide_flax_in_tracebacks(): function show_flax_in_tracebacks (line 53) | def show_flax_in_tracebacks(): FILE: flax/training/checkpoints.py function _is_multiprocess_array (line 90) | def _is_multiprocess_array(value: Any) -> bool: function _checkpoint_path (line 97) | def _checkpoint_path( function _checkpoint_path_step (line 103) | def _checkpoint_path_step(path: str) -> float | None: function _allowempty_listdir (line 111) | def _allowempty_listdir(path: str): function _safe_remove (line 118) | def _safe_remove(path: str): function _is_orbax_checkpoint (line 126) | def _is_orbax_checkpoint(path: str) -> bool: class AsyncManager (line 134) | class AsyncManager: method __init__ (line 142) | def __init__(self, max_workers: int = 1): method wait_previous_save (line 146) | def wait_previous_save(self): method save_async (line 155) | def save_async(self, task: Callable[[], Any]): function _split_mp_arrays (line 167) | def _split_mp_arrays( function _make_mpa_dirs (line 188) | def _make_mpa_dirs( function _save_mpas (line 205) | def _save_mpas( function _restore_mpas (line 253) | def _restore_mpas( function natural_sort (line 347) | def natural_sort(file_list: Iterable[str], signed: bool = True) -> list[... function safe_normpath (line 375) | def safe_normpath(path: str) -> str: function _remove_invalid_ckpts (line 383) | def _remove_invalid_ckpts( function _save_commit (line 444) | def _save_commit( function _check_overwrite_error (line 505) | def _check_overwrite_error( function _save_main_ckpt_file (line 531) | def _save_main_ckpt_file( function _get_checkpoint_paths (line 564) | def _get_checkpoint_paths( function save_checkpoint (line 580) | def save_checkpoint( function save_checkpoint_multiprocess (line 748) | def save_checkpoint_multiprocess( function _all_checkpoints (line 924) | def _all_checkpoints( function latest_checkpoint (line 955) | def latest_checkpoint( function available_steps (line 974) | def available_steps( function restore_checkpoint (line 1001) | def restore_checkpoint( function convert_pre_linen (line 1189) | def convert_pre_linen(params: PyTree) -> PyTree: FILE: flax/training/common_utils.py function shard (line 26) | def shard(xs): function shard_prng_key (line 41) | def shard_prng_key(prng_key): function stack_forest (line 57) | def stack_forest(forest): function get_metrics (line 71) | def get_metrics(device_metrics): function onehot (line 90) | def onehot(labels, num_classes, on_value=1.0, off_value=0.0): FILE: flax/training/dynamic_scale.py class DynamicScaleResult (line 29) | class DynamicScaleResult(NamedTuple): class DynamicScale (line 36) | class DynamicScale(struct.PyTreeNode): method value_and_grad (line 91) | def value_and_grad( FILE: flax/training/early_stopping.py class EarlyStopping (line 22) | class EarlyStopping(struct.PyTreeNode): method reset (line 66) | def reset(self): method update (line 74) | def update(self, metric): FILE: flax/training/lr_schedule.py function _piecewise_constant (line 31) | def _piecewise_constant(boundaries, values, t): function create_constant_learning_rate_schedule (line 36) | def create_constant_learning_rate_schedule( function create_stepped_learning_rate_schedule (line 79) | def create_stepped_learning_rate_schedule( function create_cosine_learning_rate_schedule (line 142) | def create_cosine_learning_rate_schedule( FILE: flax/training/orbax_utils.py function is_multi_device_array (line 28) | def is_multi_device_array(value: Any) -> bool: function save_args_from_target (line 35) | def save_args_from_target(target: Any) -> Any: function maybe_construct_transformations (line 39) | def maybe_construct_transformations( function restore_args_from_target (line 52) | def restore_args_from_target(target: Any, mesh: Mesh | None = None) -> Any: FILE: flax/training/prefetch_iterator.py class PrefetchIterator (line 21) | class PrefetchIterator: method __init__ (line 38) | def __init__(self, data_iter, buffer_size=1): method __iter__ (line 60) | def __iter__(self): method __next__ (line 63) | def __next__(self): method close (line 75) | def close(self): method _prefetch_loop (line 80) | def _prefetch_loop(self): FILE: flax/training/train_state.py class TrainState (line 25) | class TrainState(struct.PyTreeNode): method apply_gradients (line 81) | def apply_gradients(self, *, grads, **kwargs): method create (line 125) | def create(cls, *, apply_fn, params, tx, **kwargs): FILE: flax/traverse_util.py class _EmptyNode (line 74) | class _EmptyNode: function _flatten (line 81) | def _flatten(xs, prefix, keep_empty_nodes, is_leaf, sep): function flatten_dict (line 104) | def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None): function unflatten_dict (line 144) | def unflatten_dict(xs, sep=None): function path_aware_map (line 181) | def path_aware_map( class Traversal (line 211) | class Traversal(abc.ABC): method __new__ (line 214) | def __new__(cls, *args, **kwargs): method update (line 226) | def update(self, fn, inputs): method iterate (line 239) | def iterate(self, inputs): method set (line 249) | def set(self, values, inputs): method compose (line 269) | def compose(self, other): method merge (line 273) | def merge(self, *traversals): method each (line 277) | def each(self): method tree (line 281) | def tree(self): method filter (line 285) | def filter(self, fn): method __getattr__ (line 289) | def __getattr__(self, attr): method __getitem__ (line 292) | def __getitem__(self, key): class TraverseId (line 296) | class TraverseId(Traversal): method update (line 299) | def update(self, fn, inputs): method iterate (line 302) | def iterate(self, inputs): class TraverseMerge (line 311) | class TraverseMerge(Traversal): method __init__ (line 314) | def __init__(self, *traversals): method update (line 317) | def update(self, fn, inputs): method iterate (line 322) | def iterate(self, inputs): class TraverseCompose (line 327) | class TraverseCompose(Traversal): method __init__ (line 330) | def __init__(self, x, y): method update (line 334) | def update(self, fn, inputs): method iterate (line 340) | def iterate(self, inputs): class TraverseFilter (line 345) | class TraverseFilter(Traversal): method __init__ (line 348) | def __init__(self, fn): method update (line 351) | def update(self, fn, inputs): method iterate (line 357) | def iterate(self, inputs): function _is_namedtuple (line 362) | def _is_namedtuple(t): class TraverseAttr (line 366) | class TraverseAttr(Traversal): method __init__ (line 369) | def __init__(self, attr): method update (line 372) | def update(self, fn, inputs): method iterate (line 383) | def iterate(self, inputs): class TraverseItem (line 387) | class TraverseItem(Traversal): method __init__ (line 390) | def __init__(self, key): method update (line 393) | def update(self, fn, inputs): method iterate (line 414) | def iterate(self, inputs): class TraverseEach (line 421) | class TraverseEach(Traversal): method update (line 424) | def update(self, fn, inputs): method iterate (line 432) | def iterate(self, inputs): class TraverseTree (line 439) | class TraverseTree(Traversal): method update (line 442) | def update(self, fn, inputs): method iterate (line 445) | def iterate(self, inputs): function _get_params_dict (line 449) | def _get_params_dict(inputs): function _sorted_items (line 459) | def _sorted_items(x): class ModelParamTraversal (line 464) | class ModelParamTraversal(Traversal): method __init__ (line 475) | def __init__(self, filter_fn): method iterate (line 486) | def iterate(self, inputs): method update (line 494) | def update(self, fn, inputs): FILE: flax/typing.py class Key (line 50) | class Key(Hashable, Protocol): method __lt__ (line 51) | def __lt__(self: K, value: K, /) -> bool: function is_key_like (line 54) | def is_key_like(x: Any) -> TypeGuard[Key]: class In (line 105) | class In(Generic[T]): class Out (line 111) | class Out(Generic[T]): class PytreeDeque (line 145) | class PytreeDeque(deque[A]): function _pytree_deque_flatten (line 149) | def _pytree_deque_flatten(xs: PytreeDeque, *, with_path: bool): function _pytree_deque_unflatten (line 157) | def _pytree_deque_unflatten(_, nodes): class Missing (line 168) | class Missing: function _bytes_repr (line 175) | def _bytes_repr(num_bytes): class ShapeDtype (line 189) | class ShapeDtype(Protocol): function has_shape_dtype (line 194) | def has_shape_dtype(x: Any) -> TypeGuard[ShapeDtype]: class SizeBytes (line 199) | class SizeBytes: # type: ignore[misc] method from_array (line 204) | def from_array(cls, x: ShapeDtype): method __add__ (line 214) | def __add__(self, other: SizeBytes): method __bool__ (line 217) | def __bool__(self) -> bool: method __repr__ (line 220) | def __repr__(self) -> str: method from_any (line 225) | def from_any(cls, x): class PromoteDtypeFn (line 238) | class PromoteDtypeFn(Protocol): method __call__ (line 239) | def __call__( class HashableMapping (line 244) | class HashableMapping(Mapping[HA, HB], Hashable): method __init__ (line 247) | def __init__(self, mapping: Mapping[HA, HB], copy: bool = True): method __contains__ (line 250) | def __contains__(self, key: object) -> bool: method __getitem__ (line 253) | def __getitem__(self, key: HA) -> HB: method __iter__ (line 256) | def __iter__(self) -> Iterator[HA]: method __len__ (line 259) | def __len__(self) -> int: method __hash__ (line 262) | def __hash__(self) -> int: method __eq__ (line 275) | def __eq__(self, other: Any) -> bool: method __repr__ (line 280) | def __repr__(self) -> str: method update (line 283) | def update(self, other: Mapping[HA, HB]) -> HashableMapping[HA, HB]: class BaseConfigContext (line 293) | class BaseConfigContext(abc.ABC): method get_default (line 296) | def get_default(cls): method get_stack (line 301) | def get_stack(cls) -> list: method __init__ (line 304) | def __init__(self, value, /): method current_value (line 315) | def current_value(cls): method __enter__ (line 321) | def __enter__(self): method __exit__ (line 325) | def __exit__(self, exc_type, exc_value, traceback): method __call__ (line 328) | def __call__(self, f: F) -> F: FILE: flaxlib_src/src/flaxlib/flaxlib_cpp.pyi class NodeDef (line 20) | class NodeDef: method with_no_outer_index (line 27) | def with_no_outer_index(self) -> NodeDef: ... method with_same_outer_index (line 28) | def with_same_outer_index(self) -> NodeDef: ... method __eq__ (line 29) | def __eq__(self, other: tp.Any) -> bool: ... method __hash__ (line 30) | def __hash__(self) -> int: ... method __getstate__ (line 31) | def __getstate__( method __setstate__ (line 35) | def __setstate__( class VariableDef (line 39) | class VariableDef: method with_no_outer_index (line 45) | def with_no_outer_index(self) -> VariableDef: ... method with_same_outer_index (line 46) | def with_same_outer_index(self) -> VariableDef: ... method __eq__ (line 47) | def __eq__(self, other: tp.Any) -> bool: ... method __hash__ (line 48) | def __hash__(self) -> int: ... method __getstate__ (line 49) | def __getstate__( method __setstate__ (line 53) | def __setstate__( class NodeRef (line 57) | class NodeRef: method __eq__ (line 60) | def __eq__(self, other: tp.Any) -> bool: ... method __hash__ (line 61) | def __hash__(self) -> int: ... method __getstate__ (line 62) | def __getstate__(self) -> tuple[int]: ... method __setstate__ (line 64) | def __setstate__(noderef: NodeRef, state: tuple[int]) -> None: ... FILE: flaxlib_src/src/lib.cc function nb_id (line 34) | intptr_t nb_id(const nb::object &obj) function vector_to_tuple (line 40) | nb::tuple vector_to_tuple(const std::vector &vec) type NbObjectHash (line 54) | struct NbObjectHash type NbObjectEqual (line 63) | struct NbObjectEqual type flaxlib (line 71) | namespace flaxlib type PythonContext (line 77) | struct PythonContext method PythonContext (line 91) | PythonContext() function PythonContext (line 122) | PythonContext &get_python_context() method PythonContext (line 91) | PythonContext() type IndexMap (line 135) | struct IndexMap : public std::unordered_map type RefMapKeysIterator (line 143) | struct RefMapKeysIterator method RefMapKeysIterator (line 148) | RefMapKeysIterator(std::unordered_map get(const nb::object &key, std::optional def... function IndexMap (line 253) | static IndexMap indexmap_from_refmap(const RefMap &refmap) function RefMap (line 265) | static RefMap refmap_from_indexmap(const IndexMap &indexmap) method RefMap (line 192) | RefMap() {} method RefMap (line 194) | RefMap(const nb::object &iterable) : RefMap() method update (line 204) | void update(const RefMap &other) method __getitem__ (line 212) | int __getitem__(const nb::object &key) method __setitem__ (line 217) | void __setitem__(const nb::object &key, int value) method __len__ (line 222) | int __len__() const method __contains__ (line 227) | bool __contains__(const nb::object &key) const method RefMapKeysIterator (line 232) | RefMapKeysIterator __iter__() method RefMapItemsIterator (line 237) | RefMapItemsIterator items() method get (line 242) | std::optional get(const nb::object &key, std::optional def... type NodeDef (line 279) | struct NodeDef method NodeDef (line 287) | NodeDef(nb::object type, std::optional index, std::optional outer_ind... method VariableDef (line 341) | VariableDef with_no_outer_index() const method VariableDef (line 346) | VariableDef with_same_outer_index() const method __eq__ (line 351) | bool __eq__(const nb::object &other_obj) const method __hash__ (line 361) | int __hash__() const method __getstate__ (line 367) | nb::tuple __getstate__() const method __setstate__ (line 372) | static void __setstate__(VariableDef &variabledef, nb::tuple &t) type NodeRef (line 382) | struct NodeRef method NodeRef (line 386) | NodeRef(int index) method __eq__ (line 389) | bool __eq__(const nb::object &other_obj) const method __hash__ (line 399) | int __hash__() const method __getstate__ (line 404) | nb::tuple __getstate__() const method __setstate__ (line 409) | static void __setstate__(NodeRef &noderef, nb::tuple &t) function NB_MODULE (line 415) | NB_MODULE(flaxlib_cpp, m) FILE: tests/checkpoints_test.py function check_eq (line 39) | def check_eq(xs, ys): function shuffle (line 45) | def shuffle(l): class Inner (line 52) | class Inner(nn.Module): method __call__ (line 56) | def __call__(self, x): class Model (line 62) | class Model(nn.Module): method __call__ (line 66) | def __call__(self, inputs): class CustomDC (line 77) | class CustomDC: class CheckpointsTest (line 82) | class CheckpointsTest(parameterized.TestCase): method setUp (line 83) | def setUp(self): method test_naturalsort (line 87) | def test_naturalsort(self): method test_safe_normpath (line 99) | def test_safe_normpath(self): method test_save_restore_checkpoints (line 106) | def test_save_restore_checkpoints(self, use_orbax): method test_overwrite_checkpoints (line 176) | def test_overwrite_checkpoints(self, use_orbax): method test_keep (line 199) | def test_keep(self, use_orbax, keep_every_n_steps): method test_save_restore_checkpoints_w_float_steps (line 232) | def test_save_restore_checkpoints_w_float_steps(self, use_orbax): method test_save_restore_checkpoints_target_none (line 266) | def test_save_restore_checkpoints_target_none(self, use_orbax): method test_save_restore_checkpoints_target_singular (line 288) | def test_save_restore_checkpoints_target_singular(self, use_orbax): method test_save_restore_checkpoints_target_empty (line 307) | def test_save_restore_checkpoints_target_empty(self, use_orbax): method test_async_save_checkpoints (line 324) | def test_async_save_checkpoints(self): method test_last_checkpoint (line 367) | def test_last_checkpoint(self): method test_available_steps (line 399) | def test_available_steps(self, step_type, steps): method test_complex_pytree (line 416) | def test_complex_pytree(self, use_orbax): method test_auto_restore (line 432) | def test_auto_restore(self): method test_smaller_target (line 454) | def test_smaller_target(self, use_orbax): method test_convert_pre_linen (line 464) | def test_convert_pre_linen(self): FILE: tests/configurations_test.py class MyTestCase (line 22) | class MyTestCase(absltest.TestCase): method setUp (line 23) | def setUp(self): method test_duplicate_flag (line 28) | def test_duplicate_flag(self): method test_default (line 32) | def test_default(self): method test_typed_update (line 36) | def test_typed_update(self): method test_untyped_update (line 41) | def test_untyped_update(self): method test_update_unknown_flag (line 46) | def test_update_unknown_flag(self): method test_temp_flip (line 50) | def test_temp_flip(self): FILE: tests/core/core_frozen_dict_test.py class FrozenDictTest (line 21) | class FrozenDictTest(parameterized.TestCase): method test_frozen_dict_copies (line 22) | def test_frozen_dict_copies(self): method test_frozen_dict_maps (line 29) | def test_frozen_dict_maps(self): method test_frozen_dict_pop (line 35) | def test_frozen_dict_pop(self): method test_frozen_dict_partially_maps (line 41) | def test_frozen_dict_partially_maps(self): method test_frozen_dict_hash (line 47) | def test_frozen_dict_hash(self): method test_frozen_items (line 52) | def test_frozen_items(self): method test_frozen_dict_repr (line 58) | def test_frozen_dict_repr(self): method test_frozen_dict_reduce (line 71) | def test_frozen_dict_reduce(self): method test_frozen_dict_copy_reserved_name (line 79) | def test_frozen_dict_copy_reserved_name(self): method test_utility_pop (line 97) | def test_utility_pop(self, x, key, actual_new_x, actual_value): method test_utility_copy (line 118) | def test_utility_copy(self, x, add_or_replace, actual_new_x): method test_utility_copy_singlearg (line 132) | def test_utility_copy_singlearg(self, x): method test_utility_pretty_repr (line 152) | def test_utility_pretty_repr(self, x, pretty_str): method test_flatten (line 155) | def test_flatten(self): FILE: tests/core/core_lift_test.py class LiftTest (line 25) | class LiftTest(absltest.TestCase): method test_aliasing (line 26) | def test_aliasing(self): method test_undefined_param (line 38) | def test_undefined_param(self): method test_jit_cache (line 53) | def test_jit_cache(self): method test_vjp (line 81) | def test_vjp(self): method test_jvp (line 104) | def test_jvp(self): method test_while_loop (line 124) | def test_while_loop(self): method test_cond (line 180) | def test_cond(self): method test_switch (line 202) | def test_switch(self): method test_subscope_var_aliasing (line 234) | def test_subscope_var_aliasing(self): FILE: tests/core/core_meta_test.py class MetaTest (line 25) | class MetaTest(absltest.TestCase): method test_boxed_param (line 26) | def test_boxed_param(self): method test_boxed_variable (line 51) | def test_boxed_variable(self): method test_partition_axis_unspecified (line 84) | def test_partition_axis_unspecified(self): method test_unbox (line 104) | def test_unbox(self): method test_scan_over_layers (line 124) | def test_scan_over_layers(self): method test_get_partition_spec (line 152) | def test_get_partition_spec(self): method test_get_sharding (line 168) | def test_get_sharding(self): method test_boxed_param_with_mesh (line 188) | def test_boxed_param_with_mesh(self): FILE: tests/core/core_scope_test.py class ScopeTest (line 27) | class ScopeTest(absltest.TestCase): method test_rng (line 28) | def test_rng(self): method test_in_filter (line 39) | def test_in_filter(self): method test_union_filter (line 52) | def test_union_filter(self): method test_intersect_filter (line 70) | def test_intersect_filter(self): method test_subtract_filter (line 86) | def test_subtract_filter(self): method test_group_collections (line 104) | def test_group_collections(self): method test_inconsistent_param_shapes (line 121) | def test_inconsistent_param_shapes(self): method test_apply_variables_bad_pytree (line 133) | def test_apply_variables_bad_pytree(self): method test_mutate_undefined_collection (line 151) | def test_mutate_undefined_collection(self): method test_undefined_param (line 162) | def test_undefined_param(self): method test_variable_is_mutable (line 170) | def test_variable_is_mutable(self): method test_rngs_check_w_frozen_dict (line 178) | def test_rngs_check_w_frozen_dict(self): method test_rng_check_w_old_and_new_keys (line 184) | def test_rng_check_w_old_and_new_keys(self): method test_rng_check_w_lazy_rng (line 201) | def test_rng_check_w_lazy_rng(self): method test_jax_leak_detector (line 205) | def test_jax_leak_detector(self): method test_rng_counter_reuse (line 216) | def test_rng_counter_reuse(self): method test_empty_col_error (line 227) | def test_empty_col_error(self): method test_variable_no_init (line 242) | def test_variable_no_init(self): method test_variable_alias (line 252) | def test_variable_alias(self): method test_lazy_init (line 261) | def test_lazy_init(self): method test_lazy_init_fails_on_data_dependence (line 276) | def test_lazy_init_fails_on_data_dependence(self): method test_fold_in_static_seperator (line 287) | def test_fold_in_static_seperator(self): FILE: tests/core/design/core_attention_test.py function softmax_attn (line 26) | def softmax_attn(scope: Scope, weights: Array): function with_dropout (line 35) | def with_dropout(fn, rate: float, deterministic: bool = False): function _dot_product_attention (line 45) | def _dot_product_attention( function dot_product_attention (line 72) | def dot_product_attention( function multi_head_dot_product_attention (line 99) | def multi_head_dot_product_attention( class AttentionTest (line 145) | class AttentionTest(absltest.TestCase): method test_attention (line 146) | def test_attention(self): FILE: tests/core/design/core_auto_encoder_test.py function mlp (line 26) | def mlp(scope: Scope, x: Array, hidden: int, out: int): class AutoEncoder (line 33) | class AutoEncoder: method __call__ (line 38) | def __call__(self, scope, x): method encode (line 42) | def encode(self, scope, x): method decode (line 45) | def decode(self, scope, z): function module_method (line 49) | def module_method(fn, name=None): class AutoEncoder2 (line 62) | class AutoEncoder2: method __call__ (line 68) | def __call__(self, x): method encode (line 73) | def encode(self, scope, x): method decode (line 77) | def decode(self, scope, z): class AutoEncoder3 (line 82) | class AutoEncoder3: method create (line 87) | def create(scope, hidden: int, latents: int, features: int): method __call__ (line 92) | def __call__(self, x): class AutoEncoderTest (line 97) | class AutoEncoderTest(absltest.TestCase): method test_auto_encoder_hp_struct (line 98) | def test_auto_encoder_hp_struct(self): method test_auto_encoder_with_scope (line 120) | def test_auto_encoder_with_scope(self): method test_auto_encoder_bind_method (line 145) | def test_auto_encoder_bind_method(self): FILE: tests/core/design/core_big_resnets_test.py function residual_block (line 28) | def residual_block(scope: Scope, x: Array, conv, norm, act, features: int): function big_resnet (line 38) | def big_resnet( class BigResnetTest (line 66) | class BigResnetTest(absltest.TestCase): method test_big_resnet (line 67) | def test_big_resnet(self): FILE: tests/core/design/core_custom_vjp_test.py function mlp_custom_grad (line 27) | def mlp_custom_grad( class CustomVJPTest (line 59) | class CustomVJPTest(absltest.TestCase): method test_custom_vjp (line 60) | def test_custom_vjp(self): FILE: tests/core/design/core_dense_test.py class Dense (line 28) | class Dense: method __call__ (line 34) | def __call__(self, scope, x): class ExplicitDense (line 46) | class ExplicitDense: method create (line 52) | def create( method create_in_scope (line 70) | def create_in_scope( method __call__ (line 85) | def __call__(self, x): function explicit_mlp (line 92) | def explicit_mlp(scope, x, sizes=(3, 1)): function semi_explicit_mlp (line 101) | def semi_explicit_mlp(scope, x, sizes=(3, 1)): class DenseTest (line 112) | class DenseTest(absltest.TestCase): method test_dense (line 113) | def test_dense(self): method test_explicit_dense (line 129) | def test_explicit_dense(self): method test_explicit_dense (line 144) | def test_explicit_dense(self): method test_semi_explicit_dense (line 159) | def test_semi_explicit_dense(self): FILE: tests/core/design/core_flow_test.py class DenseFlow (line 32) | class DenseFlow: method params (line 36) | def params(self, scope: Scope, features: int): method forward (line 41) | def forward(self, scope: Scope, x: Array): method backward (line 45) | def backward(self, scope: Scope, y: Array): class StackFlow (line 51) | class StackFlow: method forward (line 54) | def forward(self, scope: Scope, x: Array): method backward (line 59) | def backward(self, scope: Scope, x: Array): class FlowTest (line 65) | class FlowTest(absltest.TestCase): method test_flow (line 66) | def test_flow(self): FILE: tests/core/design/core_resnet_test.py function residual_block (line 27) | def residual_block( function resnet (line 49) | def resnet( class ResNetTest (line 85) | class ResNetTest(absltest.TestCase): method test_resnet (line 86) | def test_resnet(self): FILE: tests/core/design/core_scan_test.py function mlp_scan (line 23) | def mlp_scan(scope: Scope, xs: Array, share_params: bool = False): class ScanTest (line 51) | class ScanTest(absltest.TestCase): method test_scan_unshared_params (line 52) | def test_scan_unshared_params(self): method test_scan_shared_params (line 72) | def test_scan_shared_params(self): FILE: tests/core/design/core_tied_autoencoder_test.py function transpose (line 25) | def transpose(fn): class TiedAutoEncoder (line 35) | class TiedAutoEncoder: method __call__ (line 39) | def __call__(self, scope, x): method encode (line 43) | def encode(self, scope, x): method decode (line 46) | def decode(self, scope, z): class TiedAutoEncoderTest (line 50) | class TiedAutoEncoderTest(absltest.TestCase): method test_tied_auto_encoder (line 51) | def test_tied_auto_encoder(self): method test_init_from_decoder (line 67) | def test_init_from_decoder(self): FILE: tests/core/design/core_vmap_test.py function mlp_vmap (line 25) | def mlp_vmap( class VMapTest (line 56) | class VMapTest(absltest.TestCase): method test_vmap_shared (line 57) | def test_vmap_shared(self): method test_vmap_unshared (line 76) | def test_vmap_unshared(self): FILE: tests/core/design/core_weight_std_test.py function weight_std (line 26) | def weight_std(fn, kernel_name='kernel', eps=1e-8): function mlp (line 45) | def mlp(scope: Scope, x: Array, sizes: Sequence[int] = (8, 1)): class WeightStdTest (line 54) | class WeightStdTest(absltest.TestCase): method test_weight_std (line 55) | def test_weight_std(self): FILE: tests/cursor_test.py class GenericTuple (line 37) | class GenericTuple(NamedTuple): class GenericDataClass (line 44) | class GenericDataClass: class CursorTest (line 50) | class CursorTest(absltest.TestCase): method test_repr (line 51) | def test_repr(self): method test_magic_methods (line 125) | def test_magic_methods(self): method test_path (line 311) | def test_path(self): method test_traverse_tree (line 336) | def test_traverse_tree(self): method test_set_and_build (line 401) | def test_set_and_build(self): method test_apply_update (line 489) | def test_apply_update(self): method test_apply_update_root_node_unmodified (line 594) | def test_apply_update_root_node_unmodified(self): method test_multi_modify (line 605) | def test_multi_modify(self): method test_hidden_change (line 619) | def test_hidden_change(self): method test_named_tuple_multi_access (line 658) | def test_named_tuple_multi_access(self): method test_find (line 686) | def test_find(self): method test_find_all (line 721) | def test_find_all(self): FILE: tests/early_stopping_test.py class EarlyStoppingTests (line 27) | class EarlyStoppingTests(absltest.TestCase): method test_update (line 28) | def test_update(self): method test_patience (line 46) | def test_patience(self): method test_delta (line 65) | def test_delta(self): FILE: tests/io_test.py class IOTest (line 30) | class IOTest(parameterized.TestCase): method test_override (line 35) | def test_override(self, backend_mode): method test_GFile (line 43) | def test_GFile(self, write_mode, read_mode): method test_listdir (line 56) | def test_listdir(self): method test_isdir (line 76) | def test_isdir(self, create_temp_fn): method test_copy (line 88) | def test_copy(self): method test_copy_raises_error (line 117) | def test_copy_raises_error(self, backend_mode, error_type): method test_rename (line 123) | def test_rename(self): method test_rename_raises_error (line 151) | def test_rename_raises_error(self, backend_mode, error_type): method test_exists (line 157) | def test_exists(self): method test_makedirs (line 171) | def test_makedirs(self, backend_mode): method test_glob (line 181) | def test_glob(self): method test_remove (line 201) | def test_remove(self, backend_mode): method test_rmtree (line 217) | def test_rmtree(self, backend_mode): method test_getsize (line 238) | def test_getsize(self, backend_mode): FILE: tests/jax_utils_test.py function assert_max_traces (line 31) | def assert_max_traces(n): class PadShardUnpadTest (line 52) | class PadShardUnpadTest(parameterized.TestCase): method test_basics (line 59) | def test_basics(self, dtype, bs): method test_trees (line 72) | def test_trees(self, dtype, bs): method test_min_device_batch_avoids_recompile (line 84) | def test_min_device_batch_avoids_recompile(self, dtype): method test_static_argnum (line 99) | def test_static_argnum(self, dtype, bs): method test_static_argnames (line 110) | def test_static_argnames(self, dtype, bs): FILE: tests/linen/initializers_test.py class InitializersTest (line 30) | class InitializersTest(parameterized.TestCase): method test_call_builder (line 43) | def test_call_builder(self, builder_fn, params_shape, expected_params): method test_kernel_builder (line 57) | def test_kernel_builder(self, builder_fn, expected_params): FILE: tests/linen/kw_only_dataclasses_test.py class KwOnlyDataclassesTest (line 25) | class KwOnlyDataclassesTest(absltest.TestCase): method test_kwonly_args_moved_to_end (line 26) | def test_kwonly_args_moved_to_end(self): method test_base_optional_subclass_required (line 48) | def test_base_optional_subclass_required(self): method test_subclass_overrides_base (line 68) | def test_subclass_overrides_base(self): method test_kwonly_marker (line 108) | def test_kwonly_marker(self): method test_whatever (line 126) | def test_whatever(self): FILE: tests/linen/linen_activation_test.py class ActivationTest (line 29) | class ActivationTest(absltest.TestCase): method test_prelu (line 31) | def test_prelu(self): FILE: tests/linen/linen_attention_test.py class AttentionTest (line 37) | class AttentionTest(parameterized.TestCase): method test_multihead_self_attention (line 38) | def test_multihead_self_attention(self): method test_dtype_infer (line 52) | def test_dtype_infer(self): method test_multihead_encoder_decoder_attention (line 66) | def test_multihead_encoder_decoder_attention(self): method test_mha_out_initializers (line 79) | def test_mha_out_initializers(self): method test_multihead_self_attention_w_dropout (line 104) | def test_multihead_self_attention_w_dropout(self): method test_multihead_self_attention_explicit_dropout (line 120) | def test_multihead_self_attention_explicit_dropout(self): method test_multihead_self_attention_w_dropout_disabled (line 167) | def test_multihead_self_attention_w_dropout_disabled(self): method test_causal_mask_1d (line 208) | def test_causal_mask_1d(self): method test_decoding (line 221) | def test_decoding(self, spatial_shape, attn_dims): method test_autoregressive_receptive_field_1d (line 260) | def test_autoregressive_receptive_field_1d(self): method test_multihead_kv_args (line 307) | def test_multihead_kv_args(self): method test_multihead_mask_warning (line 355) | def test_multihead_mask_warning(self): method test_multihead_sow_attention_weights (line 382) | def test_multihead_sow_attention_weights(self): method test_autoregressive_decode_with_x64 (line 433) | def test_autoregressive_decode_with_x64(self): method test_attention_alias_equivalence (line 459) | def test_attention_alias_equivalence(self): method test_attention_alias_submodule (line 482) | def test_attention_alias_submodule(self): method test_mixed_precision_multihead_attention (line 560) | def test_mixed_precision_multihead_attention( method test_dot_product_attention_precision_and_einsum_override (line 588) | def test_dot_product_attention_precision_and_einsum_override( method test_dot_product_attention_specify_einsums_together (line 610) | def test_dot_product_attention_specify_einsums_together( FILE: tests/linen/linen_batch_apply_test.py class BatchApplyTest (line 28) | class BatchApplyTest(parameterized.TestCase): method test_batchapply (line 33) | def test_batchapply(self, fn): method test_batchapply_accepts_float (line 51) | def test_batchapply_accepts_float(self): method test_batchapply_accepts_none (line 60) | def test_batchapply_accepts_none(self): method test_batchapply_raises (line 71) | def test_batchapply_raises(self): FILE: tests/linen/linen_combinators_test.py class MLP (line 32) | class MLP(nn.Module): method __call__ (line 38) | def __call__(self, inputs): class AttentionTuple (line 54) | class AttentionTuple(nn.Module): method __call__ (line 59) | def __call__(self, query, key_value): class AttentionDict (line 66) | class AttentionDict(nn.Module): method __call__ (line 71) | def __call__(self, query, key_value): class SequentialTest (line 78) | class SequentialTest(absltest.TestCase): method test_construction (line 79) | def test_construction(self): method test_fails_if_layers_empty (line 87) | def test_fails_if_layers_empty(self): method test_same_output_as_mlp (line 92) | def test_same_output_as_mlp(self): method test_same_output_as_mlp_with_activation (line 111) | def test_same_output_as_mlp_with_activation(self): method test_tuple_output (line 138) | def test_tuple_output(self): method test_dict_output (line 156) | def test_dict_output(self): method test_sequential_compact (line 174) | def test_sequential_compact(self): FILE: tests/linen/linen_dtypes_test.py class DtypesTest (line 32) | class DtypesTest(absltest.TestCase): method test_no_inexact_dtype (line 33) | def test_no_inexact_dtype(self): method test_inexact_dtype (line 37) | def test_inexact_dtype(self): method test_explicit_downcast (line 46) | def test_explicit_downcast(self): FILE: tests/linen/linen_linear_test.py class LinearTest (line 32) | class LinearTest(parameterized.TestCase): method test_dense (line 33) | def test_dense(self): method test_dense_extra_batch_dims (line 46) | def test_dense_extra_batch_dims(self): method test_dense_no_bias (line 57) | def test_dense_no_bias(self): method test_dense_is_dense_general (line 68) | def test_dense_is_dense_general(self): method test_dense_general_batch_dim_raises (line 85) | def test_dense_general_batch_dim_raises(self): method test_dense_general_two_out (line 97) | def test_dense_general_two_out(self): method test_dense_general_two_in (line 108) | def test_dense_general_two_in(self): method test_dense_general_batch_dim (line 120) | def test_dense_general_batch_dim(self): method test_dense_general_vs_numpy (line 151) | def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): method test_complex_params_dense (line 166) | def test_complex_params_dense(self): method test_complex_input_dense (line 175) | def test_complex_input_dense(self): method test_einsum_init_apply (line 284) | def test_einsum_init_apply( method test_einsum_ellipsis_equivalence (line 331) | def test_einsum_ellipsis_equivalence( method test_einsum_str_arg (line 350) | def test_einsum_str_arg(self): method test_einsum_space_str (line 383) | def test_einsum_space_str(self): method test_einsum_error (line 417) | def test_einsum_error(self, einsum_str, error_msg): method test_conv (line 425) | def test_conv(self, use_bias): method test_multibatch_input_conv (line 442) | def test_multibatch_input_conv(self, use_bias): method test_conv_local (line 458) | def test_conv_local(self): method test_single_input_conv (line 472) | def test_single_input_conv(self): method test_single_input_masked_conv (line 486) | def test_single_input_masked_conv(self): method test_single_input_conv_local (line 512) | def test_single_input_conv_local(self): method test_group_conv (line 526) | def test_group_conv(self): method test_circular_conv_1d_constant (line 549) | def test_circular_conv_1d_constant( method _get_kernel_shape (line 589) | def _get_kernel_shape(self, input_shape, kernel_size, module, n_featur... method test_circular_conv_2d_constant (line 610) | def test_circular_conv_2d_constant( method test_circular_conv_1d_custom (line 653) | def test_circular_conv_1d_custom(self): method test_circular_conv_local_1d_custom (line 677) | def test_circular_conv_local_1d_custom(self): method test_circular_conv_1d_dilation (line 702) | def test_circular_conv_1d_dilation(self): method test_circular_conv_local_1d_dilation (line 734) | def test_circular_conv_local_1d_dilation(self): method test_circular_conv_2d_custom (line 770) | def test_circular_conv_2d_custom(self): method test_circular_conv_local_2d_custom (line 799) | def test_circular_conv_local_2d_custom(self): method test_reflect_conv_1d_custom (line 849) | def test_reflect_conv_1d_custom(self): method test_reflect_conv_2d_custom (line 873) | def test_reflect_conv_2d_custom(self): method test_causal_conv1d (line 902) | def test_causal_conv1d(self): method test_conv_transpose (line 933) | def test_conv_transpose(self, use_bias): method test_multibatch_input_conv_transpose (line 969) | def test_multibatch_input_conv_transpose(self, use_bias): method test_single_input_conv_transpose (line 1004) | def test_single_input_conv_transpose(self): method test_single_input_masked_conv_transpose (line 1032) | def test_single_input_masked_conv_transpose(self): method test_circular_conv_transpose_1d_constant (line 1069) | def test_circular_conv_transpose_1d_constant( method test_circular_conv_transpose_2d_constant (line 1107) | def test_circular_conv_transpose_2d_constant( method test_circular_conv_transpose_2d_with_vmap (line 1144) | def test_circular_conv_transpose_2d_with_vmap(self): method test_circular_conv_transpose_1d_custom (line 1159) | def test_circular_conv_transpose_1d_custom(self): method test_circular_conv_transpose_2d_custom (line 1201) | def test_circular_conv_transpose_2d_custom(self): method test_circular_conv_transpose_2d_custom_bias (line 1236) | def test_circular_conv_transpose_2d_custom_bias(self): method test_transpose_kernel_conv_transpose (line 1270) | def test_transpose_kernel_conv_transpose(self, use_bias): method test_int_kernel_equality (line 1286) | def test_int_kernel_equality(self, module): method test_embed (line 1300) | def test_embed(self): method test_embed_numpy (line 1318) | def test_embed_numpy(self): method test_embed_hash (line 1336) | def test_embed_hash(self): method test_non_final_axis (line 1340) | def test_non_final_axis(self): method test_non_final_axes (line 1354) | def test_non_final_axes(self): method test_canonicalize_padding (line 1368) | def test_canonicalize_padding(self): FILE: tests/linen/linen_meta_test.py class LinenMetaTest (line 27) | class LinenMetaTest(absltest.TestCase): method test_boxed_param (line 28) | def test_boxed_param(self): method test_boxed_variable (line 58) | def test_boxed_variable(self): method test_pjit_scan_over_layers (line 119) | def test_pjit_scan_over_layers(self): FILE: tests/linen/linen_module_test.py function tree_equals (line 53) | def tree_equals(x, y): function set_config (line 58) | def set_config(option: str, value: bool): class DummyModule (line 67) | class DummyModule(nn.Module): method __call__ (line 69) | def __call__(self, x): class Dense (line 74) | class Dense(nn.Module): method __call__ (line 78) | def __call__(self, x): class IdentityModule (line 86) | class IdentityModule(nn.Module): method __call__ (line 87) | def __call__(self, x): class RaisesModule (line 91) | class RaisesModule(nn.Module): method __call__ (line 92) | def __call__(self): class ModuleTest (line 96) | class ModuleTest(absltest.TestCase): method test_init_module (line 97) | def test_init_module(self): method test_lazy_init (line 108) | def test_lazy_init(self): method test_lazy_init_fails_on_data_dependence (line 124) | def test_lazy_init_fails_on_data_dependence(self): method test_arg_module (line 134) | def test_arg_module(self): method test_util_fun (line 144) | def test_util_fun(self): method test_nested_module_reuse (line 169) | def test_nested_module_reuse(self): method test_setup_dict_assignment (line 207) | def test_setup_dict_assignment(self): method test_setup_dict_nonstring_keys (line 235) | def test_setup_dict_nonstring_keys(self): method test_setup_frozen_dict_nonstring_keys (line 252) | def test_setup_frozen_dict_nonstring_keys(self): method test_setup_dict_nonstring_keys_in_state (line 268) | def test_setup_dict_nonstring_keys_in_state(self): method test_setup_cloning (line 282) | def test_setup_cloning(self): method test_submodule_attr (line 290) | def test_submodule_attr(self): method test_param_in_setup (line 321) | def test_param_in_setup(self): method test_init_outside_setup_without_compact (line 342) | def test_init_outside_setup_without_compact(self): method test_init_outside_call (line 355) | def test_init_outside_call(self): method test_setup_call_var_collision (line 373) | def test_setup_call_var_collision(self): method test_call_var_collision (line 393) | def test_call_var_collision(self): method test_setup_var_collision (line 411) | def test_setup_var_collision(self): method test_setattr_name_var_disagreement_allowed_in_lists (line 430) | def test_setattr_name_var_disagreement_allowed_in_lists(self): method test_setattr_name_var_disagreement_allowed_in_dicts (line 450) | def test_setattr_name_var_disagreement_allowed_in_dicts(self): method test_submodule_var_collision_with_scope (line 475) | def test_submodule_var_collision_with_scope(self): method test_submodule_var_collision_with_submodule (line 494) | def test_submodule_var_collision_with_submodule(self): method test_submodule_var_collision_with_params (line 515) | def test_submodule_var_collision_with_params(self): method test_attr_empty_container (line 536) | def test_attr_empty_container(self): method test_multiple_compact_methods (line 546) | def test_multiple_compact_methods(self): method test_only_one_compact_method_subclass (line 570) | def test_only_one_compact_method_subclass(self): method test_forgotten_compact_annotation (line 588) | def test_forgotten_compact_annotation(self): method test_forgotten_compact_annotation_with_explicit_parent (line 609) | def test_forgotten_compact_annotation_with_explicit_parent(self): method test_numpy_array_shape_class_args (line 629) | def test_numpy_array_shape_class_args(self): method test_get_local_methods (line 643) | def test_get_local_methods(self): method test_inheritance_dataclass_attribs (line 678) | def test_inheritance_dataclass_attribs(self): method test_get_suffix_value_pairs (line 742) | def test_get_suffix_value_pairs(self): method test_mixed_list_assignment_in_setup (line 762) | def test_mixed_list_assignment_in_setup(self): method test_module_is_hashable (line 779) | def test_module_is_hashable(self): method test_module_custom_hash (line 786) | def test_module_custom_hash(self): method test_module_with_scope_is_not_hashable (line 800) | def test_module_with_scope_is_not_hashable(self): method test_module_trace (line 806) | def test_module_trace(self): method test_default_params_rng_equivalence (line 857) | def test_default_params_rng_equivalence(self): method test_module_apply_method (line 912) | def test_module_apply_method(self): method test_module_apply_method_submodule (line 957) | def test_module_apply_method_submodule(self): method test_call_unbound_compact_module_methods (line 970) | def test_call_unbound_compact_module_methods(self): method test_call_unbound_has_variable (line 976) | def test_call_unbound_has_variable(self): method test_call_unbound_make_rng (line 985) | def test_call_unbound_make_rng(self): method test_call_unbound_variables (line 994) | def test_call_unbound_variables(self): method test_call_unbound_noncompact_module_methods (line 1003) | def test_call_unbound_noncompact_module_methods(self): method test_call_unbound_noncompact_module_methods_depending_on_setup (line 1015) | def test_call_unbound_noncompact_module_methods_depending_on_setup(self): method test_module_with_attrs (line 1028) | def test_module_with_attrs(self): method test_noncompact_module_frozen (line 1043) | def test_noncompact_module_frozen(self): method test_compact_module_frozen (line 1058) | def test_compact_module_frozen(self): method test_submodule_frozen (line 1071) | def test_submodule_frozen(self): method test_module_call_not_implemented (line 1085) | def test_module_call_not_implemented(self): method test_is_mutable_collection (line 1093) | def test_is_mutable_collection(self): method test_module_lazy_getattr_setup (line 1102) | def test_module_lazy_getattr_setup(self): method test_module_lazy_dir_setup (line 1125) | def test_module_lazy_dir_setup(self): method test_module_unbound_getattr (line 1147) | def test_module_unbound_getattr(self): method test_unbound_setup_call (line 1165) | def test_unbound_setup_call(self): method test_module_pass_as_attr (line 1179) | def test_module_pass_as_attr(self): method test_module_pass_in_closure (line 1207) | def test_module_pass_in_closure(self): method test_toplevel_submodule_adoption (line 1230) | def test_toplevel_submodule_adoption(self): method test_toplevel_submodule_adoption_pytree (line 1284) | def test_toplevel_submodule_adoption_pytree(self): method test_toplevel_submodule_adoption_sharing (line 1329) | def test_toplevel_submodule_adoption_sharing(self): method test_toplevel_named_submodule_adoption (line 1378) | def test_toplevel_named_submodule_adoption(self): method test_toplevel_submodule_pytree_adoption_sharing (line 1431) | def test_toplevel_submodule_pytree_adoption_sharing(self): method test_inner_class_def (line 1465) | def test_inner_class_def(self): method test_sow (line 1478) | def test_sow(self): method test_capture_intermediates (line 1505) | def test_capture_intermediates(self): method test_perturb (line 1521) | def test_perturb(self): method test_perturb_setup (line 1547) | def test_perturb_setup(self): method test_perturb_noop (line 1575) | def test_perturb_noop(self): method test_functional_apply (line 1601) | def test_functional_apply(self): method test_bind (line 1619) | def test_bind(self): method test_bind_stateful (line 1636) | def test_bind_stateful(self): method test_unbind (line 1664) | def test_unbind(self): method test_bind_unbind_equality (line 1705) | def test_bind_unbind_equality(self): method test_passing_mutable_variables (line 1731) | def test_passing_mutable_variables(self): method test_super_compact (line 1742) | def test_super_compact(self): method test_super_setup (line 1769) | def test_super_setup(self): method test_freeze_attr (line 1790) | def test_freeze_attr(self): method test_generic_multiple_inheritance (line 1802) | def test_generic_multiple_inheritance(self): method test_jit_rng_equivalance (line 1817) | def test_jit_rng_equivalance(self): method test_rng_reuse_after_rewind (line 1824) | def test_rng_reuse_after_rewind(self): method test_module_get_put_has_variable (line 1851) | def test_module_get_put_has_variable(self): method test_generic_module (line 1893) | def test_generic_module(self): method test_modifying_attribs_in_post_init (line 1911) | def test_modifying_attribs_in_post_init(self): method test_has_rng (line 1933) | def test_has_rng(self): method test_is_initializing (line 1945) | def test_is_initializing(self): method test_throws_invalid_instance_module_error (line 1955) | def test_throws_invalid_instance_module_error(self): method test_throws_incorrect_post_init_override_error (line 1977) | def test_throws_incorrect_post_init_override_error(self): method test_deepcopy_unspecified_parent (line 1993) | def test_deepcopy_unspecified_parent(self): method test_type_hints (line 2001) | def test_type_hints(self): method test_incorrect_property (line 2008) | def test_incorrect_property(self): method test_custom_descriptor (line 2023) | def test_custom_descriptor(self): method test_custom_descriptor_error (line 2038) | def test_custom_descriptor_error(self): method test_nested_external_modules (line 2055) | def test_nested_external_modules(self): method test_getattribute_triggers_setup (line 2082) | def test_getattribute_triggers_setup(self): method test_nested_sequential_in_call (line 2103) | def test_nested_sequential_in_call(self): method test_setup_called_bounded_submodules (line 2115) | def test_setup_called_bounded_submodules(self): method test_call_bounded_toplevel_mutable (line 2137) | def test_call_bounded_toplevel_mutable(self): method test_nested_init (line 2164) | def test_nested_init(self): method test_nested_shared (line 2202) | def test_nested_shared(self): method test_repr (line 2243) | def test_repr(self): method test_repr_should_not_cause_setup (line 2260) | def test_repr_should_not_cause_setup(self): method test_kw_only (line 2282) | def test_kw_only(self): method test_positional_cannot_be_kw_only (line 2306) | def test_positional_cannot_be_kw_only(self): method test_module_path_empty (line 2318) | def test_module_path_empty(self): method test_module_path_unbound_module_error (line 2334) | def test_module_path_unbound_module_error(self): method test_module_path_in_nested_module (line 2339) | def test_module_path_in_nested_module(self): method test_intercept_methods (line 2405) | def test_intercept_methods(self): method test_intercept_methods_compact (line 2431) | def test_intercept_methods_compact(self): method test_intercept_methods_setup (line 2457) | def test_intercept_methods_setup(self): method test_intercept_methods_calling_underlying_optional (line 2487) | def test_intercept_methods_calling_underlying_optional(self): method test_intercept_methods_run_in_lifo_order (line 2503) | def test_intercept_methods_run_in_lifo_order(self): method test_intercept_methods_subclasses (line 2530) | def test_intercept_methods_subclasses(self): method test_intercept_methods_nested_module (line 2557) | def test_intercept_methods_nested_module(self): method test_cloudpickle_class (line 2588) | def test_cloudpickle_class(self): method test_cloudpickle_module (line 2599) | def test_cloudpickle_module(self): method test_module_paths (line 2624) | def test_module_paths(self): method test_init_apply_default_rng (line 2659) | def test_init_apply_default_rng(self): method test_default_make_rng (line 2739) | def test_default_make_rng(self): method test_default_rng_error (line 2773) | def test_default_rng_error(self): method test_compact_name_scope (line 2796) | def test_compact_name_scope(self): method test_compact_name_scope_outside_compact (line 2832) | def test_compact_name_scope_outside_compact(self): class LeakTests (line 2867) | class LeakTests(absltest.TestCase): method test_tracer_leaks (line 2868) | def test_tracer_leaks(self): class RelaxedNamingTests (line 2892) | class RelaxedNamingTests(absltest.TestCase): method test_relaxed_adoption (line 2893) | def test_relaxed_adoption(self): method test_class_optional_adoption_name_preservation (line 2924) | def test_class_optional_adoption_name_preservation(self): method test_nested_class_optional_adoption_name_preservation (line 2963) | def test_nested_class_optional_adoption_name_preservation(self): method test_relaxed_adoption_still_conflict_checks (line 2995) | def test_relaxed_adoption_still_conflict_checks(self): method test_relaxed_adoption_unnamed_adoptee (line 3018) | def test_relaxed_adoption_unnamed_adoptee(self): method test_relaxed_python_conflict (line 3049) | def test_relaxed_python_conflict(self): method test_relaxed_intercollection_conflict (line 3063) | def test_relaxed_intercollection_conflict(self): method test_relaxed_intercollection_conflict_set (line 3076) | def test_relaxed_intercollection_conflict_set(self): method test_internal_deep_clone (line 3091) | def test_internal_deep_clone(self): method test_copy_method (line 3119) | def test_copy_method(self): method test_copy_from_template (line 3139) | def test_copy_from_template(self): method test_nonstring_keys_in_dict_on_module (line 3172) | def test_nonstring_keys_in_dict_on_module(self): class FrozenDictTests (line 3187) | class FrozenDictTests(absltest.TestCase): method test_frozendict_flag (line 3188) | def test_frozendict_flag(self): class ShareScopeTest (line 3202) | class ShareScopeTest(absltest.TestCase): method test_basic (line 3203) | def test_basic(self): method test_child_scope (line 3227) | def test_child_scope(self): method test_in_compact (line 3251) | def test_in_compact(self): method test_adopt_child_name (line 3280) | def test_adopt_child_name(self): method test_other_scope_is_none (line 3309) | def test_other_scope_is_none(self): method test_external_grandchild_scope_correct (line 3342) | def test_external_grandchild_scope_correct(self): FILE: tests/linen/linen_recurrent_test.py class RNNTest (line 30) | class RNNTest(absltest.TestCase): method test_rnn_basic_forward (line 31) | def test_rnn_basic_forward(self): method test_rnn_multiple_batch_dims (line 54) | def test_rnn_multiple_batch_dims(self): method test_rnn_unroll (line 77) | def test_rnn_unroll(self): method test_rnn_time_major (line 100) | def test_rnn_time_major(self): method test_rnn_with_spatial_dimensions (line 129) | def test_rnn_with_spatial_dimensions(self): method test_numerical_equivalence (line 164) | def test_numerical_equivalence(self): method test_numerical_equivalence_with_mask (line 187) | def test_numerical_equivalence_with_mask(self): method test_numerical_equivalence_single_batch (line 224) | def test_numerical_equivalence_single_batch(self): method test_numerical_equivalence_single_batch_nn_scan (line 252) | def test_numerical_equivalence_single_batch_nn_scan(self): method test_numerical_equivalence_single_batch_jax_scan (line 290) | def test_numerical_equivalence_single_batch_jax_scan(self): method test_reverse (line 321) | def test_reverse(self): method test_reverse_but_keep_order (line 352) | def test_reverse_but_keep_order(self): method test_flip_sequence (line 390) | def test_flip_sequence(self): method test_flip_sequence_more_feature_dims (line 400) | def test_flip_sequence_more_feature_dims(self): method test_flip_sequence_time_major (line 410) | def test_flip_sequence_time_major(self): method test_flip_sequence_time_major_more_feature_dims (line 420) | def test_flip_sequence_time_major_more_feature_dims(self): method test_basic_seq_lengths (line 430) | def test_basic_seq_lengths(self): class BidirectionalTest (line 437) | class BidirectionalTest(absltest.TestCase): method test_bidirectional (line 438) | def test_bidirectional(self): method test_shared_cell (line 454) | def test_shared_cell(self): method test_custom_merge_fn (line 469) | def test_custom_merge_fn(self): method test_return_carry (line 487) | def test_return_carry(self): FILE: tests/linen/linen_test.py function check_eq (line 36) | def check_eq(xs, ys): class PoolTest (line 42) | class PoolTest(parameterized.TestCase): method test_pool_custom_reduce (line 43) | def test_pool_custom_reduce(self): method test_avg_pool (line 52) | def test_avg_pool(self, count_include_pad): method test_avg_pool_no_batch (line 70) | def test_avg_pool_no_batch(self, count_include_pad): method test_max_pool (line 85) | def test_max_pool(self): method test_avg_pool_padding_same (line 109) | def test_avg_pool_padding_same(self, count_include_pad): method test_pooling_variable_batch_dims (line 125) | def test_pooling_variable_batch_dims(self): method test_pooling_no_batch_dims (line 131) | def test_pooling_no_batch_dims(self): class NormalizationTest (line 138) | class NormalizationTest(parameterized.TestCase): method test_layer_norm_mask (line 139) | def test_layer_norm_mask(self): method test_rms_norm_mask (line 156) | def test_rms_norm_mask(self): method test_group_norm_mask (line 170) | def test_group_norm_mask(self): method test_batch_norm (line 193) | def test_batch_norm(self, test_mask): method test_batch_norm_complex (line 226) | def test_batch_norm_complex(self, test_mask): method test_layer_norm (line 269) | def test_layer_norm(self, reduction_axes, use_fast_variance=True): method test_rms_norm (line 294) | def test_rms_norm(self, reduction_axes): method test_group_norm (line 310) | def test_group_norm(self): method test_group_norm_unbatched (line 331) | def test_group_norm_unbatched(self): method test_group_norm_batched (line 356) | def test_group_norm_batched(self): method test_group_norm_raises (line 386) | def test_group_norm_raises(self): method test_group_norm_raises_incorrect_reduction_axes (line 398) | def test_group_norm_raises_incorrect_reduction_axes(self): method test_batch_norm_multi_init (line 414) | def test_batch_norm_multi_init(self): method test_instance_norm (line 438) | def test_instance_norm(self, feature_axes, use_fast_variance=True): method test_instance_norm_raise_error (line 477) | def test_instance_norm_raise_error(self, feature_axes): method test_normalization_equivalence (line 534) | def test_normalization_equivalence(self, layer1, layer2): method test_spectral_norm_train (line 581) | def test_spectral_norm_train(self, model_index, key_paths): method test_spectral_norm_sigma (line 664) | def test_spectral_norm_sigma(self, n_steps, update_stats, result): method test_spectral_norm_3d_tensor (line 693) | def test_spectral_norm_3d_tensor(self, error_on_non_matrix): method test_manual_weight_norm (line 735) | def test_manual_weight_norm( method test_weight_norm_variable_filter (line 812) | def test_weight_norm_variable_filter(self, variable_filters, key_paths): method test_weight_norm_train (line 850) | def test_weight_norm_train(self, model_index, key_paths): method test_weight_norm_compatibility_with_partitioning (line 921) | def test_weight_norm_compatibility_with_partitioning(self): class StochasticTest (line 965) | class StochasticTest(parameterized.TestCase): method test_dropout (line 966) | def test_dropout(self): method test_dropout_rate_stats (line 986) | def test_dropout_rate_stats(self): method test_dropout_rate_limits (line 1005) | def test_dropout_rate_limits(self): method test_dropout_broadcast (line 1064) | def test_dropout_broadcast( method test_dropout_manual_rng (line 1084) | def test_dropout_manual_rng(self): class RecurrentTest (line 1102) | class RecurrentTest(parameterized.TestCase): method test_lstm (line 1103) | def test_lstm(self): method test_gated_units (line 1159) | def test_gated_units(self, module_cls, expected_param_shapes): method test_complex_input_gated_units (line 1190) | def test_complex_input_gated_units(self, module_cls): method test_convlstm (line 1201) | def test_convlstm(self): method test_optimized_lstm_cell_matches_regular (line 1222) | def test_optimized_lstm_cell_matches_regular(self): method test_mgu_reset_gate (line 1246) | def test_mgu_reset_gate(self): class IdsTest (line 1271) | class IdsTest(absltest.TestCase): method test_hashable (line 1272) | def test_hashable(self): function get_fp8_dtypes (line 1283) | def get_fp8_dtypes(fp8_genre): class Fp8Test (line 1293) | class Fp8Test(parameterized.TestCase): method test_fp8_einsum (line 1299) | def test_fp8_einsum(self, x_shape, y_shape, g_shape, eqn): method test_fp8_dot_general_injection (line 1338) | def test_fp8_dot_general_injection(self, fp8_genre): method test_fp8_train_state (line 1408) | def test_fp8_train_state(self, fp8_genre): method test_fp8_meta_dtype (line 1499) | def test_fp8_meta_dtype(self, fp8_genre, use_jit): FILE: tests/linen/linen_transforms_test.py function tree_equals (line 41) | def tree_equals(x, y): function tree_allclose (line 45) | def tree_allclose(x, y): class TransformedMLP (line 54) | class TransformedMLP(nn.Module): method __call__ (line 59) | def __call__(self, inputs): function decorated_MLP (line 69) | def decorated_MLP(transform: Callable = id_fn): class TransformTest (line 87) | class TransformTest(parameterized.TestCase): method assert_keys_equal (line 89) | def assert_keys_equal(self, key1, key2): method test_jit (line 93) | def test_jit(self): method test_jit_decorated (line 105) | def test_jit_decorated(self): method test_jit_init_fn (line 117) | def test_jit_init_fn(self): method test_remat (line 130) | def test_remat(self): method test_remat_decorated (line 142) | def test_remat_decorated(self): method test_remat_kwargs (line 154) | def test_remat_kwargs(self): method test_remat_static_argnums (line 174) | def test_remat_static_argnums(self): method test_remat_decorator_static_argnums (line 206) | def test_remat_decorator_static_argnums(self): method test_vmap (line 242) | def test_vmap(self): method test_vmap_decorated (line 266) | def test_vmap_decorated(self): method test_vmap_batchnorm (line 290) | def test_vmap_batchnorm(self): method test_scan (line 323) | def test_scan(self): method test_scan_decorated (line 358) | def test_scan_decorated(self): method test_scan_negative_axes (line 396) | def test_scan_negative_axes(self): method test_multiscope_lifting_simple (line 427) | def test_multiscope_lifting_simple(self): method test_multiscope_lifting_simple_decorator (line 472) | def test_multiscope_lifting_simple_decorator(self): method test_multiscope_lifting_argtree (line 520) | def test_multiscope_lifting_argtree(self): method test_multiscope_lifting_argtree_decorator (line 572) | def test_multiscope_lifting_argtree_decorator(self): method test_multiscope_lifting_simple_decorator_w_jit (line 627) | def test_multiscope_lifting_simple_decorator_w_jit(self): method test_vmapped_outer_module (line 677) | def test_vmapped_outer_module(self): method test_module_transform_with_setup (line 721) | def test_module_transform_with_setup(self): method test_nested_module_args_vmap (line 740) | def test_nested_module_args_vmap(self): method test_nested_module_args_vmap_2 (line 785) | def test_nested_module_args_vmap_2(self): method test_nested_setup_calls_count (line 838) | def test_nested_setup_calls_count(self): method test_multimethod_setup_calls (line 880) | def test_multimethod_setup_calls(self): method test_toplevel_submodule_adoption_transform (line 918) | def test_toplevel_submodule_adoption_transform(self): method test_toplevel_submodule_adoption_pytree_transform (line 989) | def test_toplevel_submodule_adoption_pytree_transform(self): method test_partially_applied_module_constructor_transform (line 1041) | def test_partially_applied_module_constructor_transform(self): method test_partial_module_method (line 1057) | def test_partial_module_method(self): method test_variable_in_args_transform (line 1081) | def test_variable_in_args_transform(self): method test_module_instance_in_args_transform (line 1115) | def test_module_instance_in_args_transform(self): method test_module_instance_in_args_transform_nested (line 1157) | def test_module_instance_in_args_transform_nested(self): method test_nested_variable_passing (line 1207) | def test_nested_variable_passing(self): method test_returned_module_warning (line 1253) | def test_returned_module_warning(self): method test_returned_variable_warning (line 1275) | def test_returned_variable_warning(self): method test_nowrap (line 1291) | def test_nowrap(self): method test_map_variables_tied_autoencoder (line 1307) | def test_map_variables_tied_autoencoder(self): method test_map_variables_bit_weights (line 1343) | def test_map_variables_bit_weights(self): method test_remat_scan (line 1360) | def test_remat_scan(self): method test_vjp (line 1377) | def test_vjp(self): method test_jvp (line 1406) | def test_jvp(self): method test_complicated_alias_mutation (line 1433) | def test_complicated_alias_mutation(self): method test_custom_vjp (line 1481) | def test_custom_vjp(self): method test_transform_with_setup_and_methods_on_submodules (line 1506) | def test_transform_with_setup_and_methods_on_submodules(self): method test_transform_methods_on_submodules_still_reserve_names (line 1546) | def test_transform_methods_on_submodules_still_reserve_names(self): method test_transform_setup_still_reserve_names (line 1565) | def test_transform_setup_still_reserve_names(self): method test_transform_with_setup_and_methods_on_submodule_pytrees (line 1588) | def test_transform_with_setup_and_methods_on_submodule_pytrees(self): method test_transform_setup_still_reserve_names_pytrees (line 1621) | def test_transform_setup_still_reserve_names_pytrees(self): method test_scan_of_setup_parameter (line 1645) | def test_scan_of_setup_parameter(self): method test_multi_method_class_transform (line 1663) | def test_multi_method_class_transform(self): method test_compact_aliasing_collision (line 1708) | def test_compact_aliasing_collision(self): method test_compact_aliasing_collision_arg_and_attrib (line 1730) | def test_compact_aliasing_collision_arg_and_attrib(self): method test_jit_with_setup_helpers (line 1751) | def test_jit_with_setup_helpers(self): method test_jit_kwargs (line 1786) | def test_jit_kwargs(self): method test_jit_static_argnames (line 1798) | def test_jit_static_argnames(self): method test_jit_and_sow (line 1815) | def test_jit_and_sow(self): method test_fold_rngs (line 1841) | def test_fold_rngs(self): method test_same_key (line 1861) | def test_same_key(self): method test_jit_repr_hash (line 1896) | def test_jit_repr_hash(self): method test_jit_reuse (line 1915) | def test_jit_reuse(self): method test_jit_recursive (line 1936) | def test_jit_recursive(self): method test_jit_reuse_hash (line 1960) | def test_jit_reuse_hash(self, jit_class: bool): method test_jit_reuse_submodules (line 1989) | def test_jit_reuse_submodules(self, jit_class: bool): method test_jit_stateful_submodules (line 2028) | def test_jit_stateful_submodules(self, jit_class: bool): method test_jit_reuse_nested_submodules (line 2076) | def test_jit_reuse_nested_submodules(self, jit_class: bool): method test_jit_hashes_serializable_types (line 2115) | def test_jit_hashes_serializable_types(self): method test_while_loop (line 2158) | def test_while_loop(self): method test_while_loop_denylist_split_rngs (line 2221) | def test_while_loop_denylist_split_rngs(self): method test_cond (line 2250) | def test_cond(self): method test_switch (line 2268) | def test_switch(self): method test_switch_multihead (line 2304) | def test_switch_multihead(self): method test_lift_instance_error (line 2367) | def test_lift_instance_error(self): method test_scan_compact_count (line 2377) | def test_scan_compact_count(self): method test_bound_methods_in_direct_transforms (line 2400) | def test_bound_methods_in_direct_transforms(self): method test_add_metadata_axis (line 2423) | def test_add_metadata_axis(self): method test_outer_setup_called_with_sharing_across_transforms (line 2483) | def test_outer_setup_called_with_sharing_across_transforms(self): method test_grad_simple (line 2519) | def test_grad_simple(self): method test_grad_simple_with_aux (line 2544) | def test_grad_simple_with_aux(self): method test_value_and_grad_simple (line 2570) | def test_value_and_grad_simple(self): method test_value_and_grad_simple_with_aux (line 2595) | def test_value_and_grad_simple_with_aux(self): method test_value_and_grad_multiscope (line 2621) | def test_value_and_grad_multiscope(self): method test_value_and_grad_multiscope_adopted (line 2657) | def test_value_and_grad_multiscope_adopted(self): method test_vmap_add_remove_axis_transforms (line 2686) | def test_vmap_add_remove_axis_transforms(self): method test_vjp_tracer_leak (line 2728) | def test_vjp_tracer_leak(self): method test_jit_scan_retracing (line 2750) | def test_jit_scan_retracing(self, retracing_scan: bool): FILE: tests/linen/partitioning_test.py class PartitioningTest (line 37) | class PartitioningTest(parameterized.TestCase): method test_axis_rules (line 38) | def test_axis_rules(self): method test_axis_rules_context (line 45) | def test_axis_rules_context(self): method test_logical_to_mesh_axes_resolves_to_none_or_unconstrained (line 52) | def test_logical_to_mesh_axes_resolves_to_none_or_unconstrained(self): method test_logical_to_mesh_axes (line 67) | def test_logical_to_mesh_axes(self): method test_logical_to_mesh_axes_priorities (line 97) | def test_logical_to_mesh_axes_priorities(self): method test_logical_to_mesh_axes_cases (line 164) | def test_logical_to_mesh_axes_cases(self, rules, axes, expected): method test_with_sharding_constraint (line 170) | def test_with_sharding_constraint(self, wsc_fn): method test_with_sharding_constraint_fallback (line 190) | def test_with_sharding_constraint_fallback(self, wsc_fn): method test_param_with_axes_no_axes (line 222) | def test_param_with_axes_no_axes(self, axes_spec): method test_param_with_axes (line 239) | def test_param_with_axes(self): method test_param_pytree_with_axes (line 268) | def test_param_pytree_with_axes(self): method test_variable_with_axes_no_axes (line 309) | def test_variable_with_axes_no_axes(self, axes_spec): method test_variable_with_empty_tuple_has_empty_axes (line 322) | def test_variable_with_empty_tuple_has_empty_axes(self): method test_variable_with_axes (line 337) | def test_variable_with_axes(self): method test_variable_with_axes_fallback (line 363) | def test_variable_with_axes_fallback(self, wsc_fn): method test_scan_with_axes (line 400) | def test_scan_with_axes(self): method test_vmap_with_axes (line 485) | def test_vmap_with_axes(self): method test_logical_with_mesh_and_rules (line 551) | def test_logical_with_mesh_and_rules(self): FILE: tests/linen/summary_test.py function _get_shapes (line 34) | def _get_shapes(pytree): function _get_obj_repr_value (line 40) | def _get_obj_repr_value(x): class ConvBlock (line 46) | class ConvBlock(nn.Module): method setup (line 51) | def setup(self) -> None: method block_method (line 56) | def block_method(self, x: Array, training: bool) -> Array: method __call__ (line 67) | def __call__(self, x: Array, training: bool) -> Array: class CNN (line 79) | class CNN(nn.Module): method setup (line 82) | def setup(self) -> None: method cnn_method (line 87) | def cnn_method(self, x: Array, training: bool) -> Array: method __call__ (line 99) | def __call__(self, x: Array, training: bool) -> Array: class SummaryTest (line 112) | class SummaryTest(absltest.TestCase): method test_module_summary (line 113) | def test_module_summary(self): method test_module_summary_with_depth (line 229) | def test_module_summary_with_depth(self): method test_tabulate (line 304) | def test_tabulate(self): method test_tabulate_with_sow (line 355) | def test_tabulate_with_sow(self): method test_tabulate_with_method (line 374) | def test_tabulate_with_method(self): method test_tabulate_function (line 393) | def test_tabulate_function(self): method test_lifted_transform (line 439) | def test_lifted_transform(self): method test_lifted_transform_no_rename (line 476) | def test_lifted_transform_no_rename(self): method test_module_reuse (line 513) | def test_module_reuse(self): method test_empty_input (line 594) | def test_empty_input(self): method test_numpy_scalar (line 614) | def test_numpy_scalar(self): method test_partitioned_params (line 641) | def test_partitioned_params(self): method test_non_array_variables (line 694) | def test_non_array_variables(self): method test_tabulate_param_count_and_flops (line 717) | def test_tabulate_param_count_and_flops(self): method test_tabulate_enum (line 738) | def test_tabulate_enum(self): method test_tabulate_norm_wrapper (line 758) | def test_tabulate_norm_wrapper(self): FILE: tests/nnx/bridge/module_test.py class TestBridgeModule (line 33) | class TestBridgeModule(absltest.TestCase): method test_update (line 34) | def test_update(self): method test_module_stack (line 42) | def test_module_stack(self): method test_compact_basic (line 60) | def test_compact_basic(self): method test_mutable_state (line 108) | def test_mutable_state(self): method test_compact_parent_none (line 139) | def test_compact_parent_none(self): method test_dense_port (line 164) | def test_dense_port(self): method test_metadata (line 243) | def test_metadata(self): method test_pure_nnx_submodule (line 279) | def test_pure_nnx_submodule(self): method test_pure_nnx_submodule_modified_rng (line 334) | def test_pure_nnx_submodule_modified_rng(self): method test_linen_submodule (line 356) | def test_linen_submodule(self): method test_name (line 408) | def test_name(self): method test_transforms (line 421) | def test_transforms(self): method test_shared_modules (line 463) | def test_shared_modules(self): method test_linen_layer_naming (line 497) | def test_linen_layer_naming(self): FILE: tests/nnx/bridge/wrappers_test.py class TestCompatibility (line 29) | class TestCompatibility(absltest.TestCase): method setUp (line 30) | def setUp(self): method test_functional (line 36) | def test_functional(self): method test_linen_to_nnx (line 47) | def test_linen_to_nnx(self): method test_linen_to_nnx_submodule (line 62) | def test_linen_to_nnx_submodule(self): method test_linen_to_nnx_noncall_method (line 88) | def test_linen_to_nnx_noncall_method(self): method test_linen_to_nnx_mutable (line 113) | def test_linen_to_nnx_mutable(self): method test_linen_to_nnx_transform (line 129) | def test_linen_to_nnx_transform(self): method test_linen_to_nnx_metadata (line 151) | def test_linen_to_nnx_metadata(self): method test_linen_to_nnx_state_structure_consistency (line 192) | def test_linen_to_nnx_state_structure_consistency(self): method test_adding_new_attributes (line 236) | def test_adding_new_attributes(self): method test_nnx_to_linen (line 262) | def test_nnx_to_linen(self): method test_nnx_to_linen_multiple_rngs (line 269) | def test_nnx_to_linen_multiple_rngs(self): method test_nnx_to_linen_multiple_collections (line 290) | def test_nnx_to_linen_multiple_collections(self): method test_nnx_to_linen_mutable (line 308) | def test_nnx_to_linen_mutable(self): method test_to_linen_method_call (line 328) | def test_to_linen_method_call(self): method test_to_linen_nnx_method_arg (line 348) | def test_to_linen_nnx_method_arg(self): method test_nnx_to_linen_mutated_static_data (line 360) | def test_nnx_to_linen_mutated_static_data(self): method test_nnx_to_linen_transforms (line 383) | def test_nnx_to_linen_transforms(self): method test_nnx_to_linen_metadata (line 403) | def test_nnx_to_linen_metadata(self): method test_nnx_to_linen_metadata_transform (line 418) | def test_nnx_to_linen_metadata_transform(self): method test_nnx_to_linen_pytree_structure_consistency (line 422) | def test_nnx_to_linen_pytree_structure_consistency(self): method test_nnx_linen_nnx (line 473) | def test_nnx_linen_nnx(self): method test_linen_nnx_linen (line 525) | def test_linen_nnx_linen(self): FILE: tests/nnx/containers_test.py class TestContainers (line 21) | class TestContainers(absltest.TestCase): method test_unbox (line 22) | def test_unbox(self): method test_on_set_value (line 30) | def test_on_set_value(self): method test_module_unbox (line 39) | def test_module_unbox(self): method test_module_box (line 49) | def test_module_box(self): FILE: tests/nnx/filters_test.py class TestFilters (line 20) | class TestFilters(absltest.TestCase): method test_path_contains (line 21) | def test_path_contains(self): FILE: tests/nnx/graph_utils_test.py class StatefulLinear (line 27) | class StatefulLinear(nnx.Module): method __init__ (line 28) | def __init__(self, din, dout, rngs): method increment (line 33) | def increment(self): method __call__ (line 36) | def __call__(self, x): class TestGraphUtils (line 41) | class TestGraphUtils(parameterized.TestCase): method test_flatten (line 42) | def test_flatten(self): method test_flatten_no_paths (line 56) | def test_flatten_no_paths(self): method test_unflatten (line 72) | def test_unflatten(self): method test_flatten_unflatten_unkown_leaves (line 82) | def test_flatten_unflatten_unkown_leaves(self, graph): method test_split_merge_unkown_leaves (line 92) | def test_split_merge_unkown_leaves(self, graph): method test_split_merge_unkown_leaves_with_filters (line 102) | def test_split_merge_unkown_leaves_with_filters(self, graph): method test_unflatten_pure_dict (line 111) | def test_unflatten_pure_dict(self): method test_unflatten_pytree (line 122) | def test_unflatten_pytree(self): method test_unflatten_empty (line 131) | def test_unflatten_empty(self): method test_unflatten_return_variables (line 140) | def test_unflatten_return_variables(self): method test_update_dynamic (line 153) | def test_update_dynamic(self): method test_update_from_pure_dict (line 165) | def test_update_from_pure_dict(self): method test_module_list (line 179) | def test_module_list(self, graph): method test_shared_variables (line 195) | def test_shared_variables(self): method test_tied_weights (line 207) | def test_tied_weights(self): method test_tied_weights_example (line 225) | def test_tied_weights_example(self): method test_state_variables_shared_with_graph (line 259) | def test_state_variables_shared_with_graph(self): method test_shared_state_variables_shared_with_graph (line 279) | def test_shared_state_variables_shared_with_graph(self): method test_pytree_flatten (line 309) | def test_pytree_flatten(self): method test_pytree_node (line 329) | def test_pytree_node(self): method test_cached_unflatten (line 354) | def test_cached_unflatten(self): method test_cached_unflatten_swap_variables (line 394) | def test_cached_unflatten_swap_variables(self): method test_cached_unflatten_add_self_reference (line 434) | def test_cached_unflatten_add_self_reference(self): method test_call_jit_update (line 470) | def test_call_jit_update(self): method test_stateful_linear (line 494) | def test_stateful_linear(self): method test_getitem (line 511) | def test_getitem(self): method test_object_state_propagation (line 525) | def test_object_state_propagation(self): method test_object_state_propagation_nested (line 537) | def test_object_state_propagation_nested(self): method test_split_merge_context (line 558) | def test_split_merge_context(self): method test_split_merge_context_example (line 579) | def test_split_merge_context_example(self): method test_split_merge_context_nested (line 598) | def test_split_merge_context_nested(self): method test_split_merge_update_context (line 618) | def test_split_merge_update_context(self): method test_to_tree_simple (line 674) | def test_to_tree_simple(self): method test_to_tree_update_context (line 701) | def test_to_tree_update_context(self): method test_graph_flatten_with_data_wrapper (line 769) | def test_graph_flatten_with_data_wrapper(self): method test_to_tree_consistent_prefix (line 783) | def test_to_tree_consistent_prefix(self): method test_simple_vmap (line 793) | def test_simple_vmap(self): method test_split_variable (line 867) | def test_split_variable(self, graph): method test_split_filter_variable (line 879) | def test_split_filter_variable(self, graph): method test_split_update_variable (line 897) | def test_split_update_variable(self, graph): method test_split_update_filter_variable (line 911) | def test_split_update_filter_variable(self, graph): method test_split_leaf (line 935) | def test_split_leaf(self, leaf_fn): method test_jit_variable (line 945) | def test_jit_variable(self, graph, graph_updates): method test_jit_pytree_of_variables (line 956) | def test_jit_pytree_of_variables(self): method test_variable_reference_in_module (line 974) | def test_variable_reference_in_module(self): method test_variables_example (line 993) | def test_variables_example(self, graph, graph_updates): method test_array_attributes (line 1017) | def test_array_attributes(self, graph): method test_transform_array_attributes (line 1036) | def test_transform_array_attributes(self): method test_data_after_init (line 1053) | def test_data_after_init(self): method test_update_dict (line 1065) | def test_update_dict(self): method test_pop_dict (line 1087) | def test_pop_dict(self): method test_iter_graph (line 1104) | def test_iter_graph(self): method test_cached_partial_docstring_example (line 1147) | def test_cached_partial_docstring_example(self): method test_find_duplicates (line 1169) | def test_find_duplicates(self): method test_resursive_map (line 1182) | def test_resursive_map(self): method test_resursive_map_replace (line 1205) | def test_resursive_map_replace(self): method test_recursive_map_with_list (line 1229) | def test_recursive_map_with_list(self, graph): method test_graphdef_hash_with_sequential (line 1240) | def test_graphdef_hash_with_sequential(self): method test_split_graph_error (line 1248) | def test_split_graph_error(self): class SimpleModule (line 1255) | class SimpleModule(nnx.Module): class TestThreading (line 1259) | class TestThreading(parameterized.TestCase): method test_threading (line 1260) | def test_threading(self): class TestTreeFlatten (line 1272) | class TestTreeFlatten(parameterized.TestCase): method test_tree_flatten_unflatten (line 1273) | def test_tree_flatten_unflatten(self): method test_tree_flatten_no_paths (line 1292) | def test_tree_flatten_no_paths(self): method test_tree_split_merge (line 1301) | def test_tree_split_merge(self): method test_tree_split_merge_module (line 1316) | def test_tree_split_merge_module(self): method test_tree_shared_variables_raises (line 1327) | def test_tree_shared_variables_raises(self): method test_tree_shared_refs_raises (line 1334) | def test_tree_shared_refs_raises(self): method test_tree_shared_variables_state_raises (line 1341) | def test_tree_shared_variables_state_raises(self): method test_tree_shared_variables_graphdef_raises (line 1348) | def test_tree_shared_variables_graphdef_raises(self): method test_tree_shared_variables_clone_raises (line 1355) | def test_tree_shared_variables_clone_raises(self): method test_tree_flatten_unflatten_ordering (line 1362) | def test_tree_flatten_unflatten_ordering(self): method test_tree_flatten_dict (line 1375) | def test_tree_flatten_dict(self): method test_tree_flatten_tuple (line 1382) | def test_tree_flatten_tuple(self): method test_tree_flatten_namedtuple (line 1391) | def test_tree_flatten_namedtuple(self): method test_tree_flatten_registered_dataclass (line 1412) | def test_tree_flatten_registered_dataclass(self): method test_tree_flatten_nested_mixed (line 1437) | def test_tree_flatten_nested_mixed(self): method test_iter_graph (line 1465) | def test_iter_graph(self, graph): method test_iter_graph_tree_mode_shared_variable_raises (line 1485) | def test_iter_graph_tree_mode_shared_variable_raises(self): method test_iter_graph_tree_mode_cycle_raises (line 1496) | def test_iter_graph_tree_mode_cycle_raises(self): method test_iter_modules (line 1507) | def test_iter_modules(self, graph): method test_iter_modules_nested (line 1517) | def test_iter_modules_nested(self, graph): method test_recursive_map_tree_mode (line 1531) | def test_recursive_map_tree_mode(self): method test_recursive_map_tree_mode_replace (line 1553) | def test_recursive_map_tree_mode_replace(self): method test_recursive_map_tree_mode_with_list (line 1575) | def test_recursive_map_tree_mode_with_list(self): method test_recursive_map_tree_mode_shared_variable_raises (line 1592) | def test_recursive_map_tree_mode_shared_variable_raises(self): method test_recursive_map_tree_mode_cycle_raises (line 1601) | def test_recursive_map_tree_mode_cycle_raises(self): method test_check_valid_pytree_flatten (line 1611) | def test_check_valid_pytree_flatten(self): method test_check_valid_pytree_iter_graph (line 1622) | def test_check_valid_pytree_iter_graph(self): method test_check_valid_pytree_iter_children (line 1633) | def test_check_valid_pytree_iter_children(self): method test_check_valid_pytree_recursive_map (line 1644) | def test_check_valid_pytree_recursive_map(self): method test_map (line 1656) | def test_map(self, graph): method test_map_with_path (line 1665) | def test_map_with_path(self): method test_map_nested (line 1678) | def test_map_nested(self): method test_map_replace (line 1690) | def test_map_replace(self): FILE: tests/nnx/helpers_test.py class TrainState (line 26) | class TrainState(nnx.TrainState): class TestHelpers (line 30) | class TestHelpers(absltest.TestCase): method test_train_state (line 31) | def test_train_state(self): method test_train_state_methods (line 45) | def test_train_state_methods(self): method test_nnx_linen_sequential_equivalence (line 76) | def test_nnx_linen_sequential_equivalence(self): method test_nnx_empty_sequential_is_identity (line 106) | def test_nnx_empty_sequential_is_identity(self): method test_dict_mutable_mapping (line 113) | def test_dict_mutable_mapping(self): method test_dict_setdefault (line 130) | def test_dict_setdefault(self): method test_dict_contains (line 139) | def test_dict_contains(self): method test_list_mutable_sequence (line 151) | def test_list_mutable_sequence(self): method test_list_fori_loop (line 183) | def test_list_fori_loop(self): method test_list_pytree_default_behavior (line 197) | def test_list_pytree_default_behavior(self): method test_list_pytree_static_elements (line 205) | def test_list_pytree_static_elements(self): method test_list_pytree_data_elements (line 210) | def test_list_pytree_data_elements(self): method test_list_pytree_mixed_static_data (line 218) | def test_list_pytree_mixed_static_data(self): method test_list_pytree_flatten_unflatten (line 230) | def test_list_pytree_flatten_unflatten(self): method test_list_pytree_jit (line 243) | def test_list_pytree_jit(self): FILE: tests/nnx/ids_test.py class TestIds (line 21) | class TestIds(absltest.TestCase): method test_hashable (line 22) | def test_hashable(self): FILE: tests/nnx/integration_test.py class TestIntegration (line 31) | class TestIntegration(parameterized.TestCase): method test_basic_view_example (line 34) | def test_basic_view_example(self, graph_mode): method test_shared_modules (line 86) | def test_shared_modules(self): method test_shared_modules_view (line 136) | def test_shared_modules_view(self): method test_shared_modules_pure (line 186) | def test_shared_modules_pure(self): method test_shared_modules_pure_view (line 245) | def test_shared_modules_pure_view(self): method test_stateful_example (line 305) | def test_stateful_example(self, graph_mode): method test_functional_example (line 346) | def test_functional_example(self): method test_intermediates_example (line 388) | def test_intermediates_example(self): method test_intermediates_example_functional (line 408) | def test_intermediates_example_functional(self): method test_replace_by_pure_dict (line 430) | def test_replace_by_pure_dict(self): method test_example_mutable_arrays (line 467) | def test_example_mutable_arrays(self): method test_tree_mode_train_step (line 501) | def test_tree_mode_train_step(self): method test_tree_mode_multi_module (line 522) | def test_tree_mode_multi_module(self): method test_tree_mode_stateful (line 565) | def test_tree_mode_stateful(self): FILE: tests/nnx/metrics_test.py class TestMetrics (line 23) | class TestMetrics(parameterized.TestCase): method test_split_merge (line 25) | def test_split_merge(self): method test_welford (line 43) | def test_welford(self): method test_welford_large (line 62) | def test_welford_large(self): method test_welford_many (line 81) | def test_welford_many(self): method test_multimetric (line 95) | def test_multimetric(self, with_mask): method test_multimetric_with_custom_metric (line 133) | def test_multimetric_with_custom_metric(self, with_mask): method test_binary_classification_accuracy (line 155) | def test_binary_classification_accuracy(self): method test_accuracy_dims (line 180) | def test_accuracy_dims(self, logits, labels, threshold, error_msg): method test_average (line 189) | def test_average(self, with_mask, scalar_values): FILE: tests/nnx/module_test.py function set_graph_mode (line 36) | def set_graph_mode(mode): class PytreeTest (line 45) | class PytreeTest(absltest.TestCase): method test_pytree (line 46) | def test_pytree(self): method test_sequential_map (line 56) | def test_sequential_map(self): method test_sequential_has_leaves (line 60) | def test_sequential_has_leaves(self): method test_consistent_attrs (line 64) | def test_consistent_attrs(self): method test_assing_pytree_with_data (line 104) | def test_assing_pytree_with_data(self): method test_consistent_attrs_frozen_dataclass (line 115) | def test_consistent_attrs_frozen_dataclass(self): method test_consistent_attrs_dataclass_annotations (line 125) | def test_consistent_attrs_dataclass_annotations(self): method test_explicit_dont_change (line 165) | def test_explicit_dont_change(self): method test_no_data_in_static (line 175) | def test_no_data_in_static(self): class TestCapture (line 186) | class TestCapture(parameterized.TestCase): method test_vmap (line 188) | def test_vmap(self): method test_fwd_bwd (line 227) | def test_fwd_bwd(self, graph_mode): method test_nested_modules (line 257) | def test_nested_modules(self, graph_mode): method test_method_outputs_single_module (line 293) | def test_method_outputs_single_module(self): method test_method_outputs_nested_modules (line 320) | def test_method_outputs_nested_modules(self): method test_method_outputs_mixed_with_sow (line 353) | def test_method_outputs_mixed_with_sow(self): class SowMod (line 376) | class SowMod(nnx.Module): method __init__ (line 377) | def __init__(self, rngs: nnx.Rngs): method __call__ (line 380) | def __call__(self, x): class TestModule (line 385) | class TestModule(parameterized.TestCase): method test_has_module_state (line 386) | def test_has_module_state(self): method test_trace_level (line 393) | def test_trace_level(self): method test_tree_map (line 406) | def test_tree_map(self): method test_split_2 (line 413) | def test_split_2(self): method test_split_merge (line 420) | def test_split_merge(self): method test_call (line 434) | def test_call(self): method test_shared_module (line 450) | def test_shared_module(self): method test_module_graph (line 460) | def test_module_graph(self): method test_deref_through_jit (line 474) | def test_deref_through_jit(self): method test_cross_barrier (line 500) | def test_cross_barrier(self): method test_no_rejit (line 515) | def test_no_rejit(self): method test_deref_number_of_fields (line 545) | def test_deref_number_of_fields(self): method test_clone (line 560) | def test_clone(self): method test_sow_existing_non_variable_field (line 577) | def test_sow_existing_non_variable_field(self): method test_sow_wrong_collection (line 592) | def test_sow_wrong_collection(self): method test_sow_pop (line 607) | def test_sow_pop(self): method test_cached_partial (line 614) | def test_cached_partial(self): method test_update_static_state_submodules (line 626) | def test_update_static_state_submodules(self): method test_update_new_submodule (line 658) | def test_update_new_submodule(self): method test_update_update_submodule (line 687) | def test_update_update_submodule(self): method test_update_add_shared (line 713) | def test_update_add_shared(self): method test_create_abstract (line 741) | def test_create_abstract(self): method test_create_abstract_stateful (line 747) | def test_create_abstract_stateful(self): method test_partial_init (line 754) | def test_partial_init(self): method test_deepcopy (line 774) | def test_deepcopy(self): method test_set_attributes (line 791) | def test_set_attributes(self): method test_set_attribute_error (line 814) | def test_set_attribute_error(self): method test_view (line 844) | def test_view(self, graph): method test_with_attributes (line 868) | def test_with_attributes(self, graph): method test_with_attributes_filter (line 891) | def test_with_attributes_filter(self, graph): method test_with_attributes_error (line 908) | def test_with_attributes_error(self, graph): method test_view_error (line 938) | def test_view_error(self, graph): method test_cloud_pickle (line 956) | def test_cloud_pickle(self): method test_repr (line 987) | def test_repr(self): method test_view_info (line 1028) | def test_view_info(self, graph): method test_view_info_with_filter (line 1052) | def test_view_info_with_filter(self, graph): method test_view_info_with_custom_set_mode (line 1071) | def test_view_info_with_custom_set_mode(self, graph): class TestModuleDataclass (line 1093) | class TestModuleDataclass(absltest.TestCase): method test_basic (line 1094) | def test_basic(self): method test_field_specifiers (line 1126) | def test_field_specifiers(self): method test_field_specifiers_forced (line 1138) | def test_field_specifiers_forced(self): method test_field_specifiers_with_defaults (line 1149) | def test_field_specifiers_with_defaults(self): method test_field_specifiers_array_in_static (line 1160) | def test_field_specifiers_array_in_static(self): method test_variable_in_static_list (line 1171) | def test_variable_in_static_list(self): method test_module_in_static_list (line 1182) | def test_module_in_static_list(self): method test_post_init (line 1196) | def test_post_init(self): class TestModuleDef (line 1215) | class TestModuleDef(parameterized.TestCase): method test_apply (line 1216) | def test_apply(self): method test_derefed_mod_apply (line 1237) | def test_derefed_mod_apply(self): method test_modules_iterator (line 1261) | def test_modules_iterator(self): method test_children_modules_iterator (line 1288) | def test_children_modules_iterator(self, graph): method test_state_in_module (line 1308) | def test_state_in_module(self): FILE: tests/nnx/mutable_array_test.py class TestPytree (line 25) | class TestPytree(absltest.TestCase): method test_pytree (line 26) | def test_pytree(self): method test_pytree_data_typehint (line 39) | def test_pytree_data_typehint(self): method test_pytree_data_instance (line 54) | def test_pytree_data_instance(self): method test_pytree_dataclass (line 67) | def test_pytree_dataclass(self): method test_data_example (line 89) | def test_data_example(self): method test_register_data_type (line 99) | def test_register_data_type(self): class TestVariableRefMode (line 116) | class TestVariableRefMode(absltest.TestCase): method test_split_mutable_array (line 117) | def test_split_mutable_array(self): method test_to_arrays_example (line 127) | def test_to_arrays_example(self): method test_freeze_and_mutable_with_filter (line 145) | def test_freeze_and_mutable_with_filter(self): method test_freeze_duplicate_error (line 169) | def test_freeze_duplicate_error(self): method test_mutable_array_split (line 180) | def test_mutable_array_split(self): method test_mutable_array_split_merge_in_variable (line 197) | def test_mutable_array_split_merge_in_variable(self): method test_mutable_array_split_merge_in_variable_shared_array (line 214) | def test_mutable_array_split_merge_in_variable_shared_array(self): method test_mutable_example (line 232) | def test_mutable_example(self): method test_mutable_array_split_freeze (line 240) | def test_mutable_array_split_freeze(self): method test_update_context (line 257) | def test_update_context(self): method test_update_context_flatten (line 292) | def test_update_context_flatten(self): method test_update_context_to_tree1 (line 329) | def test_update_context_to_tree1(self): method test_update_context_to_tree2 (line 368) | def test_update_context_to_tree2(self): method test_update_context_to_tree_trivial_prefix (line 407) | def test_update_context_to_tree_trivial_prefix(self): method test_simple_jit (line 446) | def test_simple_jit(self): method test_jit_mutable (line 462) | def test_jit_mutable(self): method test_static (line 479) | def test_static(self): method test_variable_creation (line 500) | def test_variable_creation(self): method test_variable_metadata (line 506) | def test_variable_metadata(self): method test_object (line 511) | def test_object(self): method test_object_state (line 565) | def test_object_state(self): method test_rngs_create (line 590) | def test_rngs_create(self): method test_rngs_call (line 615) | def test_rngs_call(self): class TestOptimizer (line 621) | class TestOptimizer(absltest.TestCase): method test_optimize_arrays (line 622) | def test_optimize_arrays(self): method test_optimize_hijax (line 660) | def test_optimize_hijax(self): class TestHijaxVariables (line 693) | class TestHijaxVariables(parameterized.TestCase): method test_variable_to_hijax (line 694) | def test_variable_to_hijax(self): method test_from_metadata (line 723) | def test_from_metadata(self): method test_variable_to_hijax_clean (line 739) | def test_variable_to_hijax_clean(self): method test_pytree_value (line 767) | def test_pytree_value(self): method test_hijax_dynamic_structure (line 780) | def test_hijax_dynamic_structure(self): method test_hijax_and_pytree (line 793) | def test_hijax_and_pytree(self): method test_use_hijax (line 819) | def test_use_hijax(self): method test_hijax_rngs (line 832) | def test_hijax_rngs(self): method test_return_hijax_from_transform (line 847) | def test_return_hijax_from_transform(self): method test_lower (line 857) | def test_lower(self): method test_eval_shape (line 870) | def test_eval_shape(self): method test_no_qdd_grad (line 882) | def test_no_qdd_grad(self): method test_no_qdd_grad_new (line 896) | def test_no_qdd_grad_new(self): method test_variable_properties (line 913) | def test_variable_properties(self, hijax, ref): method test_variable_copy_properties (line 930) | def test_variable_copy_properties(self, hijax, ref): method test_variable_vars_as_properties (line 949) | def test_variable_vars_as_properties(self, hijax, ref): class TestVarDefaults (line 965) | class TestVarDefaults(absltest.TestCase): method test_defaults (line 966) | def test_defaults(self): method test_context_manager_hijax (line 973) | def test_context_manager_hijax(self): method test_context_manager_ref (line 983) | def test_context_manager_ref(self): method test_context_manager_nested (line 993) | def test_context_manager_nested(self): method test_mapping_protocol (line 1009) | def test_mapping_protocol(self): method test_decorator (line 1018) | def test_decorator(self): method test_variable_init_override (line 1028) | def test_variable_init_override(self): class HijaxTransformCoverageTest (line 1037) | class HijaxTransformCoverageTest(absltest.TestCase): method test_hitypes_as_grad_args (line 1042) | def test_hitypes_as_grad_args(self): method test_hitypes_as_nondiff_grad_args (line 1052) | def test_hitypes_as_nondiff_grad_args(self): method test_hitypes_as_captured_args (line 1063) | def test_hitypes_as_captured_args(self): method test_mutable_hitypes_as_grad_args (line 1075) | def test_mutable_hitypes_as_grad_args(self): method test_mutable_hitypes_as_nondiff_grad_args (line 1085) | def test_mutable_hitypes_as_nondiff_grad_args(self): method test_mutable_hitypes_as_captured_args (line 1098) | def test_mutable_hitypes_as_captured_args(self): method test_hitypes_as_scan_carry (line 1114) | def test_hitypes_as_scan_carry(self): method test_hitypes_as_scan_extensive (line 1128) | def test_hitypes_as_scan_extensive(self): method test_hitypes_as_scan_captured (line 1143) | def test_hitypes_as_scan_captured(self): method test_mutable_hitypes_as_scan_carry (line 1161) | def test_mutable_hitypes_as_scan_carry(self): method test_mutable_hitypes_as_scan_extensive (line 1173) | def test_mutable_hitypes_as_scan_extensive(self): method test_mutable_hitypes_as_scan_captured (line 1186) | def test_mutable_hitypes_as_scan_captured(self): method test_hijax_variable_in_jit_graph_updates_false (line 1196) | def test_hijax_variable_in_jit_graph_updates_false(self): FILE: tests/nnx/nn/attention_test.py class TestMultiHeadAttention (line 36) | class TestMultiHeadAttention(parameterized.TestCase): method test_basic (line 37) | def test_basic(self): method test_multihead_sow_attention_weights (line 49) | def test_multihead_sow_attention_weights(self): method test_autoregressive_decode_with_x64 (line 91) | def test_autoregressive_decode_with_x64(self): method test_keep_rngs (line 112) | def test_keep_rngs(self, keep_rngs): method test_causal_mask_equivalence (line 135) | def test_causal_mask_equivalence( class TestLinenConsistency (line 211) | class TestLinenConsistency(parameterized.TestCase): method test_nnx_attention_equivalence (line 220) | def test_nnx_attention_equivalence( class TestKVFeatures (line 276) | class TestKVFeatures(parameterized.TestCase): method test_varying_num_features (line 278) | def test_varying_num_features(self): class TestGQADotProductAttention (line 302) | class TestGQADotProductAttention(parameterized.TestCase): method test_gqa_shapes (line 304) | def test_gqa_shapes(self): method test_gqa_invalid_heads (line 319) | def test_gqa_invalid_heads(self): method test_gqa_multihead_attention (line 328) | def test_gqa_multihead_attention(self): method test_gqa_parity_with_jax (line 359) | def test_gqa_parity_with_jax(self): FILE: tests/nnx/nn/conv_test.py class TestConvLinenConsistency (line 31) | class TestConvLinenConsistency(parameterized.TestCase): method test_nnx_linen_conv_equivalence (line 45) | def test_nnx_linen_conv_equivalence( method test_nnx_linen_convtranspose_equivalence (line 137) | def test_nnx_linen_convtranspose_equivalence( FILE: tests/nnx/nn/embed_test.py class TestLinenConsistency (line 28) | class TestLinenConsistency(parameterized.TestCase): method test_nnx_linen_equivalence (line 35) | def test_nnx_linen_equivalence( FILE: tests/nnx/nn/linear_test.py class TestLinearGeneral (line 30) | class TestLinearGeneral(parameterized.TestCase): method test_basic (line 37) | def test_basic( method test_basic_multi_features (line 63) | def test_basic_multi_features(self): class TestLinenConsistency (line 73) | class TestLinenConsistency(parameterized.TestCase): method test_nnx_linear_equivalence (line 81) | def test_nnx_linear_equivalence( method test_nnx_einsum_equivalence (line 141) | def test_nnx_einsum_equivalence( method test_einsum_op (line 196) | def test_einsum_op(self): class TestPReLUConsistency (line 206) | class TestPReLUConsistency(parameterized.TestCase): method test_equivalence (line 211) | def test_equivalence(self, dtype, param_dtype): class TestLayersSameGraph (line 243) | class TestLayersSameGraph(parameterized.TestCase): method test (line 285) | def test(self, module_args_kwargs_initargs): class TestLayersParamsMetadata (line 298) | class TestLayersParamsMetadata(parameterized.TestCase): method test (line 391) | def test(self, module_args_kwargs_initargs): FILE: tests/nnx/nn/lora_test.py class TestLora (line 23) | class TestLora(absltest.TestCase): method test_basic (line 24) | def test_basic(self): method test_lora_base_module (line 34) | def test_lora_base_module(self): method test_layer_swap_lora (line 51) | def test_layer_swap_lora(self): method test_layer_swap_loralinear (line 76) | def test_layer_swap_loralinear(self): method test_lora_param_type (line 105) | def test_lora_param_type(self): method test_dtype (line 119) | def test_dtype(self): FILE: tests/nnx/nn/normalization_test.py class TestLinenConsistency (line 28) | class TestLinenConsistency(parameterized.TestCase): method test_nnx_linen_batchnorm_equivalence (line 35) | def test_nnx_linen_batchnorm_equivalence( method test_nnx_linen_layernorm_equivalence (line 134) | def test_nnx_linen_layernorm_equivalence( method test_nnx_linen_rmsnorm_equivalence (line 208) | def test_nnx_linen_rmsnorm_equivalence( method test_nnx_linen_groupnorm_equivalence (line 282) | def test_nnx_linen_groupnorm_equivalence( method test_nnx_linen_weightnorm_equivalence (line 361) | def test_nnx_linen_weightnorm_equivalence( method test_nnx_linen_instancenorm_equivalence (line 425) | def test_nnx_linen_instancenorm_equivalence( method test_nnx_linen_spectralnorm_equivalence (line 498) | def test_nnx_linen_spectralnorm_equivalence( FILE: tests/nnx/nn/recurrent_test.py class TestLSTMCell (line 27) | class TestLSTMCell(absltest.TestCase): method test_basic (line 28) | def test_basic(self): method test_lstm_sequence (line 39) | def test_lstm_sequence(self): method test_lstm_with_different_dtypes (line 55) | def test_lstm_with_different_dtypes(self): method test_lstm_with_custom_activations (line 70) | def test_lstm_with_custom_activations(self): method test_lstm_initialize_carry (line 84) | def test_lstm_initialize_carry(self): method test_lstm_with_variable_sequence_length (line 101) | def test_lstm_with_variable_sequence_length(self): method test_lstm_stateful (line 131) | def test_lstm_stateful(self): method test_lstm_equivalence_with_flax_linen (line 146) | def test_lstm_equivalence_with_flax_linen(self): class TestRNN (line 207) | class TestRNN(absltest.TestCase): method test_rnn_with_lstm_cell (line 208) | def test_rnn_with_lstm_cell(self): method test_rnn_with_gru_cell (line 233) | def test_rnn_with_gru_cell(self): method test_rnn_time_major (line 258) | def test_rnn_time_major(self): method test_rnn_reverse (line 283) | def test_rnn_reverse(self): method test_rnn_with_seq_lengths (line 311) | def test_rnn_with_seq_lengths(self): method test_rnn_with_keep_order (line 362) | def test_rnn_with_keep_order(self): method test_rnn_equivalence_with_flax_linen (line 389) | def test_rnn_equivalence_with_flax_linen(self): method test_rnn_with_unroll (line 449) | def test_rnn_with_unroll(self): method test_rnn_with_custom_cell (line 470) | def test_rnn_with_custom_cell(self): method test_rnn_with_different_dtypes (line 523) | def test_rnn_with_different_dtypes(self): method test_rnn_with_variable_batch_size (line 549) | def test_rnn_with_variable_batch_size(self): method test_recurrent_dropout (line 573) | def test_recurrent_dropout(self): FILE: tests/nnx/nn/stochastic_test.py class TestStochastic (line 24) | class TestStochastic: method test_dropout_internal_rngs (line 25) | def test_dropout_internal_rngs(self): method test_dropout_rng_override (line 56) | def test_dropout_rng_override(self): method test_dropout_arg_override (line 69) | def test_dropout_arg_override(self): method test_dropout_arg_override_view (line 91) | def test_dropout_arg_override_view(self): FILE: tests/nnx/optimizer_test.py function assert_equal (line 23) | def assert_equal(path, x, y): function assert_not_equal (line 27) | def assert_not_equal(path, x, y): class Model (line 33) | class Model(nnx.Module): method __init__ (line 35) | def __init__(self, in_features, out_features, rngs): method __call__ (line 39) | def __call__(self, x): class TestOptimizer (line 43) | class TestOptimizer(parameterized.TestCase): method test_split_merge (line 49) | def test_split_merge(self, module_cls): method test_update (line 59) | def test_update(self): method test_sharding_propagation (line 71) | def test_sharding_propagation(self): method test_jit (line 105) | def test_jit(self, module_cls, jit_decorator, optimizer): method test_jit_linesearch (line 152) | def test_jit_linesearch(self, module_cls, jit_decorator, optimizer): method test_metrics (line 207) | def test_metrics(self, module_cls, optimizer): method test_wrt_update (line 237) | def test_wrt_update(self, variable): method test_wrt_update_linesearch (line 284) | def test_wrt_update_linesearch(self, variable): FILE: tests/nnx/partitioning_test.py class TestPartitioning (line 21) | class TestPartitioning(absltest.TestCase): method test_partition (line 23) | def test_partition(self): method test_complete_partitioning (line 49) | def test_complete_partitioning(self): method test_complete_partitioning_plus_ellipsis (line 58) | def test_complete_partitioning_plus_ellipsis(self): method test_inclomplete_partition_error (line 67) | def test_inclomplete_partition_error(self): method test_ellipsis_not_last_error (line 78) | def test_ellipsis_not_last_error(self): method test_update_from (line 89) | def test_update_from(self): method test_update_from_with_array_leaf (line 108) | def test_update_from_with_array_leaf(self): method test_grad_example (line 125) | def test_grad_example(self): method test_get_paritition (line 145) | def test_get_paritition(self): FILE: tests/nnx/rngs_test.py class TestRngs (line 27) | class TestRngs(parameterized.TestCase): method test_call (line 28) | def test_call(self): method test_fallback (line 32) | def test_fallback(self): method test_fallback_error_no_default (line 36) | def test_fallback_error_no_default(self): method test_rng_stream (line 41) | def test_rng_stream(self): method test_rng_trace_level_constraints (line 56) | def test_rng_trace_level_constraints(self): method test_jit_updates (line 85) | def test_jit_updates(self): method test_lifting_rng_state (line 116) | def test_lifting_rng_state(self): method test_reseed (line 176) | def test_reseed(self, graph): method test_split_rngs (line 196) | def test_split_rngs(self, graph): method test_fork_rngs (line 212) | def test_fork_rngs(self, graph): method test_random_helpers (line 219) | def test_random_helpers(self): FILE: tests/nnx/spmd_test.py class TestSPMD (line 29) | class TestSPMD(parameterized.TestCase): method setUp (line 31) | def setUp(self): method test_init (line 35) | def test_init(self): method test_init_all_devices (line 66) | def test_init_all_devices(self): method test_shard_optimizer_state (line 95) | def test_shard_optimizer_state(self): method test_add_remove_axis_in_transform (line 128) | def test_add_remove_axis_in_transform(self): method test_transform_metadata_decorator (line 197) | def test_transform_metadata_decorator(self): method test_eager_sharding_context (line 223) | def test_eager_sharding_context(self, use_eager_sharding): method test_out_sharding_linear_layers (line 240) | def test_out_sharding_linear_layers(self): method test_out_sharding_embed (line 254) | def test_out_sharding_embed(self): method test_out_sharding_conv (line 264) | def test_out_sharding_conv(self): method test_out_sharding_embed_attend (line 273) | def test_out_sharding_embed_attend(self): method test_out_sharding_dropout (line 282) | def test_out_sharding_dropout(self): method test_logical_rules (line 301) | def test_logical_rules(self, use_hijax): method test_get_abstract_model (line 339) | def test_get_abstract_model(self): method test_sharding_axis_types (line 359) | def test_sharding_axis_types(self, mode): method test_eval_shape_with_explicit_sharding (line 388) | def test_eval_shape_with_explicit_sharding(self): method test_eval_shape_with_sharding0 (line 400) | def test_eval_shape_with_sharding0(self): method test_eval_shape_with_sharding1 (line 418) | def test_eval_shape_with_sharding1(self): method test_variable_out_sharding_types (line 431) | def test_variable_out_sharding_types(self, axis_type_name): method test_get_abstract_with_abstract_mesh (line 460) | def test_get_abstract_with_abstract_mesh(self): method test_get_abstract_with_per_variable_mesh (line 484) | def test_get_abstract_with_per_variable_mesh(self): method test_get_abstract_no_sharding_metadata (line 519) | def test_get_abstract_no_sharding_metadata(self): function has_sharding_spec (line 528) | def has_sharding_spec(array): FILE: tests/nnx/state_test.py class StateTest (line 22) | class StateTest(absltest.TestCase): method test_create_state (line 23) | def test_create_state(self): method test_get_attr (line 31) | def test_get_attr(self): method test_set_attr (line 39) | def test_set_attr(self): method test_set_attr_variables (line 50) | def test_set_attr_variables(self): method test_add_nested_attr (line 63) | def test_add_nested_attr(self): method test_delete_nested_attr (line 71) | def test_delete_nested_attr(self): method test_integer_access (line 79) | def test_integer_access(self): method test_pure_dict (line 95) | def test_pure_dict(self): method test_diff (line 110) | def test_diff(self): FILE: tests/nnx/summary_test.py class SummaryTest (line 23) | class SummaryTest(absltest.TestCase): method test_tabulate (line 24) | def test_tabulate(self): method test_multiple_inputs_and_outputs (line 77) | def test_multiple_inputs_and_outputs(self): method test_tabulate_empty_dict_first_arg (line 101) | def test_tabulate_empty_dict_first_arg(self): method test_tabulate_empty_dict_last_arg (line 117) | def test_tabulate_empty_dict_last_arg(self): method test_tabulate_empty_dict_and_none_kwarg (line 132) | def test_tabulate_empty_dict_and_none_kwarg(self): method test_tabulate_empty_dict_property (line 150) | def test_tabulate_empty_dict_property(self): method test_no_dup_flops (line 169) | def test_no_dup_flops(self): method test_flops (line 182) | def test_flops(self): method test_nested (line 201) | def test_nested(self): method test_time_complexity (line 228) | def test_time_complexity(self): method test_shared (line 254) | def test_shared(self): method test_tabulate_with_variable_hooks (line 282) | def test_tabulate_with_variable_hooks(self): method test_tabulate_concrete_shape (line 325) | def test_tabulate_concrete_shape(self): FILE: tests/nnx/test_traversals.py class TraversalTest (line 25) | class TraversalTest(absltest.TestCase): method test_flatten_mapping (line 26) | def test_flatten_mapping(self): method test_unflatten_mapping (line 53) | def test_unflatten_mapping(self): method test_flatten_mapping_keep_empty (line 71) | def test_flatten_mapping_keep_empty(self): method test_flatten_mapping_is_leaf (line 88) | def test_flatten_mapping_is_leaf(self): FILE: tests/nnx/transforms_test.py class TestJIT (line 35) | class TestJIT(parameterized.TestCase): method test_jit (line 36) | def test_jit(self): method test_mutable_array_input_output (line 49) | def test_mutable_array_input_output(self): method test_simple_double_call (line 67) | def test_simple_double_call(self, graph_mode, graph_updates): method test_jit_on_init (line 88) | def test_jit_on_init(self): method test_jit_on_call (line 117) | def test_jit_on_call(self, graph_mode, graph_updates): method test_cached_unflatten (line 147) | def test_cached_unflatten(self): method test_jit_custom_vjp (line 204) | def test_jit_custom_vjp(self, graph_mode, graph_updates): method test_cached_unflatten_same_type (line 222) | def test_cached_unflatten_same_type(self): method test_objects_in_pytree (line 252) | def test_objects_in_pytree(self): method test_cached_unflatten_swap_variables (line 286) | def test_cached_unflatten_swap_variables(self): method test_cached_unflatten_add_self_reference (line 305) | def test_cached_unflatten_add_self_reference(self): method test_cached_unflatten_ref_in_output (line 335) | def test_cached_unflatten_ref_in_output(self): method test_apply_shardings (line 369) | def test_apply_shardings(self): method test_cache_args (line 398) | def test_cache_args(self): method test_jit_wrapped (line 422) | def test_jit_wrapped(self, graph_mode, graph_updates): method test_jit_static_args_with_shardings (line 460) | def test_jit_static_args_with_shardings(self, graph_mode, graph_update... method test_with_sharding_and_static_args (line 490) | def test_with_sharding_and_static_args(self, static_args): class TestTreeJIT (line 523) | class TestTreeJIT(parameterized.TestCase): method test_tree_jit_basic (line 529) | def test_tree_jit_basic(self, graph, graph_updates): method test_tree_jit_module (line 547) | def test_tree_jit_module(self, graph, graph_updates): method test_tree_jit_variable_update (line 558) | def test_tree_jit_variable_update(self): method test_tree_jit_no_retrace (line 580) | def test_tree_jit_no_retrace(self, graph, graph_updates): method test_tree_jit_static_argnums (line 599) | def test_tree_jit_static_argnums(self): method test_tree_jit_no_input_output_aliasing (line 612) | def test_tree_jit_no_input_output_aliasing(self): method test_tree_jit_no_shared_variable_refs (line 622) | def test_tree_jit_no_shared_variable_refs(self): method test_tree_jit_new_variable_output_ok (line 634) | def test_tree_jit_new_variable_output_ok(self): method test_tree_jit_donate_argnums_unchanged_var (line 643) | def test_tree_jit_donate_argnums_unchanged_var(self): method test_tree_jit_donate_argnums_module (line 658) | def test_tree_jit_donate_argnums_module(self): method test_tree_jit_donate_argnums_with_mutation (line 675) | def test_tree_jit_donate_argnums_with_mutation(self): method test_tree_jit_donate_argnames (line 688) | def test_tree_jit_donate_argnames(self): method test_tree_jit_donate_selective (line 703) | def test_tree_jit_donate_selective(self): method test_jit_partial_basic (line 720) | def test_jit_partial_basic(self, graph_mode): method test_jit_partial_lower_compile (line 732) | def test_jit_partial_lower_compile(self, graph_mode): method test_jit_partial_variable_update (line 757) | def test_jit_partial_variable_update(self, graph_mode): method test_jit_partial_multiple_args (line 776) | def test_jit_partial_multiple_args(self, graph_mode): method test_jit_partial_no_retrace (line 789) | def test_jit_partial_no_retrace(self, graph_mode): method test_jit_partial_no_retrace_after_mutation (line 806) | def test_jit_partial_no_retrace_after_mutation(self, graph_mode): method test_jit_partial_no_partial_args (line 830) | def test_jit_partial_no_partial_args(self, graph_mode): method test_jit_partial_in_shardings_none_broadcast (line 836) | def test_jit_partial_in_shardings_none_broadcast(self, graph_mode): method test_jit_partial_in_shardings_named (line 852) | def test_jit_partial_in_shardings_named(self, graph_mode): method test_jit_partial_mixed_shardings (line 873) | def test_jit_partial_mixed_shardings(self, graph_mode): method test_jit_partial_in_shardings_non_tuple (line 894) | def test_jit_partial_in_shardings_non_tuple(self, graph_mode): method test_jit_partial_train_step (line 912) | def test_jit_partial_train_step(self, graph_mode): method test_jit_partial_shared_variable (line 929) | def test_jit_partial_shared_variable(self): method test_jit_inconsistent_aliasing (line 948) | def test_jit_inconsistent_aliasing(self): class TestEvalShape (line 965) | class TestEvalShape(parameterized.TestCase): method test_eval_shape (line 969) | def test_eval_shape(self, graph, graph_updates): method test_eval_shape_mutable_array (line 977) | def test_eval_shape_mutable_array(self): method test_eval_shape_with_module_input (line 987) | def test_eval_shape_with_module_input(self, graph, graph_updates): method test_eval_shape_no_state_update (line 1002) | def test_eval_shape_no_state_update(self, graph, graph_updates): method test_eval_shape_no_input_output_aliasing (line 1018) | def test_eval_shape_no_input_output_aliasing(self, graph, graph_updates): method test_eval_shape_no_shared_variable_refs (line 1031) | def test_eval_shape_no_shared_variable_refs(self, graph, graph_updates): class TestShardMap (line 1041) | class TestShardMap(parameterized.TestCase): method test_basic_shardmap (line 1042) | def test_basic_shardmap(self): method test_basic_shardmap_variables (line 1073) | def test_basic_shardmap_variables(self, graph, graph_updates): method test_from_state (line 1102) | def test_from_state(self): method test_simple_data_parallel (line 1135) | def test_simple_data_parallel(self, graph, graph_updates): method test_simple_tensor_parallel (line 1167) | def test_simple_tensor_parallel(self): method test_shardmap_with_sharding_names (line 1205) | def test_shardmap_with_sharding_names(self, graph, graph_updates): method test_shardmap_sharding_names_mutation (line 1234) | def test_shardmap_sharding_names_mutation(self, graph, graph_updates): method test_shardmap_shared_variable (line 1257) | def test_shardmap_shared_variable(self): method test_shardmap_module_variable_update (line 1285) | def test_shardmap_module_variable_update(self, graph, graph_updates): method test_shard_map_inconsistent_aliasing (line 1309) | def test_shard_map_inconsistent_aliasing(self): class TestGrad (line 1327) | class TestGrad(parameterized.TestCase): method test_grad (line 1328) | def test_grad(self): method test_grad_with_multiple_ref_types (line 1363) | def test_grad_with_multiple_ref_types(self): method test_grad_with_type_predicate (line 1391) | def test_grad_with_type_predicate(self): method test_multiple_inputs (line 1419) | def test_multiple_inputs(self): method test_multiple_graph_nodes (line 1443) | def test_multiple_graph_nodes(self, loss_fn, argnums): method test_multiple_args (line 1464) | def test_multiple_args(self): method test_multiple_args_in_pytrees (line 1482) | def test_multiple_args_in_pytrees(self): method test_value_and_grad_multiple_args_in_pytrees (line 1502) | def test_value_and_grad_multiple_args_in_pytrees(self): method test_value_and_grad_with_aux (line 1523) | def test_value_and_grad_with_aux(self): method test_variables_in_grad (line 1549) | def test_variables_in_grad(self): method test_tree_mode_grad (line 1582) | def test_tree_mode_grad(self, graph, graph_updates): method test_tree_mode_grad_multiple_inputs (line 1601) | def test_tree_mode_grad_multiple_inputs(self, graph, graph_updates): method test_tree_mode_grad_multiple_graph_nodes (line 1620) | def test_tree_mode_grad_multiple_graph_nodes(self, graph, graph_updates): method test_tree_mode_value_and_grad_with_aux (line 1644) | def test_tree_mode_value_and_grad_with_aux(self, graph, graph_updates): class TestCustomVJP (line 1663) | class TestCustomVJP(parameterized.TestCase): method test_basic_call (line 1669) | def test_basic_call(self, graph, graph_updates): method test_basic_call_with_state (line 1697) | def test_basic_call_with_state(self, graph): method test_jax_example (line 1731) | def test_jax_example(self, graph, graph_updates): method test_diff_state (line 1782) | def test_diff_state(self): method test_jax_example_with_remat (line 1828) | def test_jax_example_with_remat(self, graph, graph_updates): method test_two_args (line 1884) | def test_two_args(self): method test_non_diff_args (line 1942) | def test_non_diff_args(self, graph, graph_updates): method test_docs_example (line 2006) | def test_docs_example(self): method test_issue (line 2037) | def test_issue(self, use_custom_vjp: bool): method test_tree_mode_basic_call (line 2089) | def test_tree_mode_basic_call(self, graph, graph_updates): method test_tree_mode_jax_example (line 2121) | def test_tree_mode_jax_example(self, graph, graph_updates): method test_tree_mode_with_remat (line 2157) | def test_tree_mode_with_remat(self, graph, graph_updates): method test_tree_mode_non_diff_args (line 2198) | def test_tree_mode_non_diff_args(self, graph, graph_updates): method test_tree_mode_diffstate_error (line 2242) | def test_tree_mode_diffstate_error(self): method test_grad_inconsistent_aliasing (line 2251) | def test_grad_inconsistent_aliasing(self): method test_custom_vjp_inconsistent_aliasing (line 2260) | def test_custom_vjp_inconsistent_aliasing(self): method test_custom_vjp_diff_arg_mutation_error (line 2279) | def test_custom_vjp_diff_arg_mutation_error(self): class TestVjpJvp (line 2309) | class TestVjpJvp(parameterized.TestCase): method test_vjp_basic (line 2315) | def test_vjp_basic(self, graph, graph_updates): method test_vjp_has_aux (line 2341) | def test_vjp_has_aux(self, graph, graph_updates): method test_vjp_state_propagation (line 2364) | def test_vjp_state_propagation(self, graph, graph_updates): method test_vjp_matches_jax (line 2390) | def test_vjp_matches_jax(self, graph, graph_updates): method test_vjp_decorator (line 2413) | def test_vjp_decorator(self, graph, graph_updates): method test_jvp_basic (line 2434) | def test_jvp_basic(self, graph, graph_updates): method test_jvp_has_aux (line 2461) | def test_jvp_has_aux(self, graph, graph_updates): method test_jvp_state_propagation (line 2488) | def test_jvp_state_propagation(self, graph, graph_updates): method test_jvp_matches_jax (line 2518) | def test_jvp_matches_jax(self, graph, graph_updates): method test_jvp_decorator (line 2541) | def test_jvp_decorator(self, graph, graph_updates): class TestScan (line 2562) | class TestScan(parameterized.TestCase): method test_basic (line 2563) | def test_basic(self): method test_variables_in_scan (line 2594) | def test_variables_in_scan(self, graph_updates): method test_variables_as_carries_in_scan (line 2627) | def test_variables_as_carries_in_scan(self, graph, graph_updates): method test_basic_no_carry (line 2651) | def test_basic_no_carry(self): method test_all_carry (line 2682) | def test_all_carry(self, graph_updates): method test_all_carry_one_argument_error (line 2702) | def test_all_carry_one_argument_error(self): method test_all_carry_new_reference_error (line 2718) | def test_all_carry_new_reference_error(self): method test_all_scan (line 2739) | def test_all_scan(self, graph_updates): method test_all_broadcast (line 2758) | def test_all_broadcast(self): method test_input_output_carry_mismatch_error (line 2775) | def test_input_output_carry_mismatch_error(self): method test_double_carry_error (line 2792) | def test_double_carry_error(self): method test_broadcast_in_output_error (line 2801) | def test_broadcast_in_output_error(self): method test_scan_stateful (line 2823) | def test_scan_stateful(self, graph, graph_updates): method test_scan_carry_identity_error (line 2841) | def test_scan_carry_identity_error(self, graph, graph_updates): method test_tree_mode_custom_axes (line 2855) | def test_tree_mode_custom_axes(self): method test_only_carry (line 2864) | def test_only_carry(self, graph_updates): method test_out_axes (line 2879) | def test_out_axes(self): method test_in_axes_simple (line 2908) | def test_in_axes_simple(self): method test_in_axes (line 2932) | def test_in_axes(self): method test_in_axes_broadcast (line 2968) | def test_in_axes_broadcast(self): method test_complex (line 3004) | def test_complex(self): method test_complex_view (line 3036) | def test_complex_view(self): method test_complex_broadcast_dropout (line 3068) | def test_complex_broadcast_dropout(self): method test_complex_broadcast_dropout_view (line 3101) | def test_complex_broadcast_dropout_view(self): method test_complex_decorator (line 3134) | def test_complex_decorator(self): method test_complex_decorator_view (line 3169) | def test_complex_decorator_view(self): method test_scan_with_sharding (line 3205) | def test_scan_with_sharding(self): method test_cache_tracing_simple (line 3261) | def test_cache_tracing_simple(self): method test_cache_tracing_object (line 3281) | def test_cache_tracing_object(self): method test_scan_broadcast_keys (line 3311) | def test_scan_broadcast_keys(self): method test_rnn_example (line 3329) | def test_rnn_example(self): method test_carry_pytree_sow (line 3363) | def test_carry_pytree_sow(self): method test_broadcast_variable_mutation_rejected (line 3430) | def test_broadcast_variable_mutation_rejected(self): method test_broadcast_out_axes_rejected (line 3443) | def test_broadcast_out_axes_rejected(self): method test_scan_inconsistent_aliasing (line 3454) | def test_scan_inconsistent_aliasing(self): method test_scan_input_output_aliasing (line 3469) | def test_scan_input_output_aliasing(self): class TestRemat (line 3480) | class TestRemat(parameterized.TestCase): method test_remat_basic (line 3486) | def test_remat_basic(self, graph, graph_updates): method test_remat_variables (line 3514) | def test_remat_variables(self, graph, graph_updates): method test_remat_with_scan_decorator (line 3538) | def test_remat_with_scan_decorator(self): method test_tree_mode_remat_basic (line 3565) | def test_tree_mode_remat_basic(self, graph, graph_updates): method test_tree_mode_remat_stateful (line 3586) | def test_tree_mode_remat_stateful(self, graph, graph_updates): class TestVmap (line 3612) | class TestVmap(parameterized.TestCase): method test_vmap_basic (line 3616) | def test_vmap_basic(self, graph, graph_updates): method test_vmap_stateful (line 3634) | def test_vmap_stateful(self, graph, graph_updates): method test_vmap_variables (line 3662) | def test_vmap_variables(self, graph, graph_updates): method test_vmap_ensemble_forward (line 3678) | def test_vmap_ensemble_forward(self, graph, graph_updates): method test_vmap_replicate (line 3709) | def test_vmap_replicate(self, graph, graph_updates): method test_basic (line 3720) | def test_basic(self): method test_basic_variables (line 3745) | def test_basic_variables(self): method test_state_axes (line 3772) | def test_state_axes(self): method test_split_rngs_context_manager (line 3828) | def test_split_rngs_context_manager(self): method test_split_rngs_decorator (line 3875) | def test_split_rngs_decorator(self): method test_state_axes_simple (line 3924) | def test_state_axes_simple(self): method test_split_rngs_decorator_simple (line 3958) | def test_split_rngs_decorator_simple(self): method test_state_axes_super_simple (line 3998) | def test_state_axes_super_simple(self): method test_replicate (line 4030) | def test_replicate(self): method test_consistent_aliasing_inputs (line 4076) | def test_consistent_aliasing_inputs(self): method test_consistent_aliasing_input_output (line 4090) | def test_consistent_aliasing_input_output(self): method test_consistent_aliasing_shared (line 4104) | def test_consistent_aliasing_shared(self): method test_equivalent_state_axes_mapping (line 4128) | def test_equivalent_state_axes_mapping(self): method test_equivalent_state_sharding_mapping (line 4140) | def test_equivalent_state_sharding_mapping(self): method test_captured_module_in_return_error (line 4157) | def test_captured_module_in_return_error(self): method test_vmap_and_cond_passthrough (line 4175) | def test_vmap_and_cond_passthrough(self): method test_vmap_and_cond_passthrough_error (line 4205) | def test_vmap_and_cond_passthrough_error(self): method test_example (line 4239) | def test_example(self): method test_example_with_vectorization (line 4267) | def test_example_with_vectorization(self): method test_metadata (line 4284) | def test_metadata(self): method test_state_axes_from_state (line 4308) | def test_state_axes_from_state(self): method test_vmap_inconsistent_aliasing (line 4347) | def test_vmap_inconsistent_aliasing(self): class TestPmap (line 4358) | class TestPmap(parameterized.TestCase): method test_basic_single (line 4359) | def test_basic_single(self): method test_basic_demo_single (line 4406) | def test_basic_demo_single(self, graph, graph_updates): method test_replicate_single (line 4443) | def test_replicate_single(self): class TestCond (line 4489) | class TestCond(parameterized.TestCase): method test_basic (line 4490) | def test_basic(self): method test_basic_variable (line 4538) | def test_basic_variable(self, graph, graph_updates): method test_cond_and_vmap (line 4564) | def test_cond_and_vmap(self, graph, graph_updates): method test_cond_different_variable_per_branch (line 4599) | def test_cond_different_variable_per_branch(self, graph, graph_updates): method test_cond_shared_references (line 4623) | def test_cond_shared_references(self): class TestSwitch (line 4649) | class TestSwitch(parameterized.TestCase): method test_basic (line 4654) | def test_basic(self, graph, graph_updates): method test_switch_variable (line 4698) | def test_switch_variable(self, graph, graph_updates): method test_switch_shared_references (line 4719) | def test_switch_shared_references(self): class TestWhileLoop (line 4745) | class TestWhileLoop(parameterized.TestCase): method test_basic (line 4749) | def test_basic(self, graph, graph_updates): method test_multiple_objects (line 4767) | def test_multiple_objects(self, graph, graph_updates): method test_nested_module (line 4786) | def test_nested_module(self, graph, graph_updates): method test_shared_module (line 4802) | def test_shared_module(self): method test_value_changed (line 4832) | def test_value_changed(self, graph, graph_updates): method test_ref_changed (line 4854) | def test_ref_changed(self, graph, graph_updates): method test_structure_changed (line 4872) | def test_structure_changed(self, graph, graph_updates): method test_repeated_object (line 4888) | def test_repeated_object(self): method test_immut_fori_loop (line 4901) | def test_immut_fori_loop(self): method test_fori_loop_grad_accum (line 4915) | def test_fori_loop_grad_accum(self, graph, graph_updates): method test_fori_loop_basic (line 4929) | def test_fori_loop_basic(self, graph, graph_updates): method test_fori_loop_with_sharing (line 4942) | def test_fori_loop_with_sharing(self): method test_loops_multiple_modules (line 4979) | def test_loops_multiple_modules(self, graph, graph_updates): method test_tree_mode_while_loop_stateful (line 4999) | def test_tree_mode_while_loop_stateful(self, graph, graph_updates): method test_tree_mode_while_loop_inside_jit (line 5028) | def test_tree_mode_while_loop_inside_jit(self, graph, graph_updates): method test_tree_mode_fori_loop_stateful (line 5052) | def test_tree_mode_fori_loop_stateful(self, graph, graph_updates): method test_tree_mode_fori_loop_inside_jit (line 5078) | def test_tree_mode_fori_loop_inside_jit(self, graph, graph_updates): class TestSplitMergeInputs (line 5096) | class TestSplitMergeInputs(absltest.TestCase): method test_split_inputs (line 5097) | def test_split_inputs(self): method test_split_inputs_cond (line 5120) | def test_split_inputs_cond(self): method test_split_inputs_vmap (line 5146) | def test_split_inputs_vmap(self): class TestCheckify (line 5185) | class TestCheckify(parameterized.TestCase): method test_basic (line 5189) | def test_basic(self, graph, graph_updates): method test_checkify_stateful (line 5212) | def test_checkify_stateful(self, graph, graph_updates): class TestBoundMethodTransforms (line 5227) | class TestBoundMethodTransforms(parameterized.TestCase): method test_remat_with_bound_method_raises (line 5228) | def test_remat_with_bound_method_raises(self): method test_jit_with_bound_method_raises (line 5242) | def test_jit_with_bound_method_raises(self): method test_vmap_with_bound_method_raises (line 5253) | def test_vmap_with_bound_method_raises(self): method test_eval_shape_with_bound_method_raises (line 5264) | def test_eval_shape_with_bound_method_raises(self): method test_grad_with_bound_method_raises (line 5279) | def test_grad_with_bound_method_raises(self, graph_mode, graph_updates): method test_value_and_grad_with_bound_method_raises (line 5293) | def test_value_and_grad_with_bound_method_raises(self, graph_mode, gra... method test_checkify_with_bound_method_raises (line 5306) | def test_checkify_with_bound_method_raises(self): method test_pmap_with_bound_method_raises (line 5316) | def test_pmap_with_bound_method_raises(self): method test_shard_map_with_bound_method_raises (line 5326) | def test_shard_map_with_bound_method_raises(self): method test_custom_vjp_with_bound_method_raises (line 5337) | def test_custom_vjp_with_bound_method_raises(self): method test_scan_bound_method_raises (line 5347) | def test_scan_bound_method_raises(self): method test_tree_mode_pmap_basic (line 5358) | def test_tree_mode_pmap_basic(self, graph, graph_updates): method test_tree_mode_pmap_stateful (line 5377) | def test_tree_mode_pmap_stateful(self, graph, graph_updates): method test_tree_mode_pmap_split_merge (line 5403) | def test_tree_mode_pmap_split_merge(self): method test_tree_mode_pmap_replicate (line 5446) | def test_tree_mode_pmap_replicate(self): class TestPureJaxFancyScan (line 5489) | class TestPureJaxFancyScan(absltest.TestCase): method test_carry_and_scan (line 5491) | def test_carry_and_scan(self): method test_carry_only_output (line 5503) | def test_carry_only_output(self): method test_broadcast_args (line 5513) | def test_broadcast_args(self): method test_pytree_carry (line 5524) | def test_pytree_carry(self): method test_no_carry_all_scanned (line 5538) | def test_no_carry_all_scanned(self): method test_reverse (line 5548) | def test_reverse(self): method test_pytree_prefix_in_axes (line 5559) | def test_pytree_prefix_in_axes(self): method test_nested_carry_rejected (line 5571) | def test_nested_carry_rejected(self): method test_broadcast_out_axes_rejected (line 5579) | def test_broadcast_out_axes_rejected(self): method test_none_broadcast_input (line 5587) | def test_none_broadcast_input(self): method test_none_nested_in_arg (line 5598) | def test_none_nested_in_arg(self): method test_nested_carry_in_out_axes_rejected (line 5610) | def test_nested_carry_in_out_axes_rejected(self): method test_carry_in_in_axes_only_rejected (line 5618) | def test_carry_in_in_axes_only_rejected(self): method test_carry_in_out_axes_only_rejected (line 5626) | def test_carry_in_out_axes_only_rejected(self): method test_non_tuple_carry_only (line 5634) | def test_non_tuple_carry_only(self): method test_non_tuple_scan_only (line 5644) | def test_non_tuple_scan_only(self): FILE: tests/nnx/variable_test.py class TestVariable (line 27) | class TestVariable(parameterized.TestCase): method test_pytree (line 28) | def test_pytree(self): method test_overloads_module (line 38) | def test_overloads_module(self): method test_jax_array (line 53) | def test_jax_array(self): method test_proxy_access (line 68) | def test_proxy_access(self): method test_proxy_call (line 74) | def test_proxy_call(self): method test_binary_ops (line 86) | def test_binary_ops(self): method test_eq_op (line 103) | def test_eq_op(self, v1, v2): method test_mutable_array_context (line 111) | def test_mutable_array_context(self): method test_get_set_metadata (line 137) | def test_get_set_metadata(self): method test_set_module_metadata (line 174) | def test_set_module_metadata(self): method test_broadcasting (line 206) | def test_broadcasting(self): method test_set_metadata_out_sharding (line 211) | def test_set_metadata_out_sharding(self): FILE: tests/pickle_test.py class ErrorrsTest (line 21) | class ErrorrsTest(absltest.TestCase): method test_exception_can_be_pickled (line 22) | def test_exception_can_be_pickled(self): FILE: tests/serialization_test.py class Point (line 41) | class Point: class Box (line 48) | class Box: function to_state_dict (line 52) | def to_state_dict(box: Box): function from_state_dict (line 56) | def from_state_dict(box: Box, state: Any): class OriginalTuple (line 65) | class OriginalTuple(NamedTuple): class WrongTuple (line 69) | class WrongTuple(NamedTuple): class OriginalModule (line 73) | class OriginalModule(nn.Module): method __call__ (line 75) | def __call__(self, x): class WrongModule (line 80) | class WrongModule(nn.Module): method __call__ (line 82) | def __call__(self, x): class SerializationTest (line 88) | class SerializationTest(parameterized.TestCase): method test_dataclass_serialization (line 89) | def test_dataclass_serialization(self): method test_pass_through_serialization (line 108) | def test_pass_through_serialization(self): method test_model_serialization (line 119) | def test_model_serialization(self): method test_partial_serialization (line 143) | def test_partial_serialization(self): method test_optimizer_serialization (line 150) | def test_optimizer_serialization(self): method test_collection_serialization (line 175) | def test_collection_serialization(self): method test_numpy_serialization (line 268) | def test_numpy_serialization(self, dtype): method test_jax_numpy_serialization (line 294) | def test_jax_numpy_serialization(self): method test_complex_serialization (line 324) | def test_complex_serialization(self): method test_restore_chunked (line 331) | def test_restore_chunked(self): method test_restore_unchunked (line 343) | def test_restore_unchunked(self): method test_namedtuple_serialization (line 363) | def test_namedtuple_serialization(self): method test_namedtuple_restore_legacy (line 372) | def test_namedtuple_restore_legacy(self): method test_model_serialization_to_bytes (line 385) | def test_model_serialization_to_bytes(self): method test_optimizer_serialization_to_bytes (line 393) | def test_optimizer_serialization_to_bytes(self): method test_serialization_chunking (line 404) | def test_serialization_chunking(self): method test_serialization_chunking2 (line 422) | def test_serialization_chunking2(self): method test_serialization_chunking3 (line 433) | def test_serialization_chunking3(self): method test_serialization_errors (line 506) | def test_serialization_errors(self, target, wrong_target, msg): FILE: tests/struct_test.py class Point (line 31) | class Point: class StructTest (line 37) | class StructTest(parameterized.TestCase): method test_no_extra_fields (line 38) | def test_no_extra_fields(self): method test_mutation (line 43) | def test_mutation(self): method test_slots (line 50) | def test_slots(self): method test_pytree_nodes (line 62) | def test_pytree_nodes(self): method test_keypath_error (line 69) | def test_keypath_error(self): method test_double_wrap_no_op (line 75) | def test_double_wrap_no_op(self): method test_wrap_pytree_node_no_error (line 87) | def test_wrap_pytree_node_no_error(self): method test_kw_only (line 96) | def test_kw_only(self, mode): method test_metadata_pass_through (line 125) | def test_metadata_pass_through(self): method test_mutable (line 135) | def test_mutable(self, mode): method test_generic_pytreenode_base_order (line 160) | def test_generic_pytreenode_base_order(self): FILE: tests/tensorboard_test.py function _process_event (line 32) | def _process_event(event): function _disk_usage (line 37) | def _disk_usage(path: pathlib.Path): class TensorboardTest (line 50) | class TensorboardTest(absltest.TestCase): method parse_and_return_summary_value (line 51) | def parse_and_return_summary_value(self, path): method test_summarywriter_flush_after_close (line 69) | def test_summarywriter_flush_after_close(self): method test_summarywriter_scalar (line 76) | def test_summarywriter_scalar(self): method test_summarywriter_text (line 91) | def test_summarywriter_text(self): method test_summarywriter_image (line 104) | def test_summarywriter_image(self): method test_summarywriter_image_float_pixel_values (line 116) | def test_summarywriter_image_float_pixel_values(self): method test_summarywriter_2dimage_scaled (line 132) | def test_summarywriter_2dimage_scaled(self): method test_summarywriter_single_channel_image_scaled (line 145) | def test_summarywriter_single_channel_image_scaled(self): method test_summarywriter_multiple_images (line 158) | def test_summarywriter_multiple_images(self): method test_summarywriter_multiple_2dimages_scaled (line 172) | def test_summarywriter_multiple_2dimages_scaled(self): method test_summarywriter_audio (line 187) | def test_summarywriter_audio(self): method test_summarywriter_audio_sampled_output (line 218) | def test_summarywriter_audio_sampled_output(self): method test_summarywriter_clipped_audio (line 244) | def test_summarywriter_clipped_audio(self): method test_summarywriter_histogram_defaultbins (line 272) | def test_summarywriter_histogram_defaultbins(self): method test_summarywriter_histogram_2bins (line 287) | def test_summarywriter_histogram_2bins(self): method test_flatten_dict (line 306) | def test_flatten_dict(self): method test_auto_flush (line 379) | def test_auto_flush(self): method test_no_auto_flush (line 388) | def test_no_auto_flush(self): FILE: tests/traceback_util_test.py class TracebackTest (line 38) | class TracebackTest(absltest.TestCase): method test_exclusion_list (line 39) | def test_exclusion_list(self): method test_simple_exclusion_tracebackhide (line 49) | def test_simple_exclusion_tracebackhide(self): method test_simple_exclusion_remove_frames (line 87) | def test_simple_exclusion_remove_frames(self): method test_dynamic_exclusion (line 125) | def test_dynamic_exclusion(self): FILE: tests/traverse_util_test.py class Foo (line 34) | class Foo: method __init__ (line 35) | def __init__(self, foo, bar=None): method __eq__ (line 39) | def __eq__(self, other): class TraversalTest (line 46) | class TraversalTest(absltest.TestCase): method test_traversal_id (line 47) | def test_traversal_id(self): method test_traverse_item (line 54) | def test_traverse_item(self): method test_traverse_tuple_item (line 61) | def test_traverse_tuple_item(self): method test_traverse_tuple_items (line 68) | def test_traverse_tuple_items(self): method test_traverse_namedtuple_item (line 75) | def test_traverse_namedtuple_item(self): method test_traverse_attr (line 82) | def test_traverse_attr(self): method test_traverse_namedtuple_attr (line 89) | def test_traverse_namedtuple_attr(self): method test_traverse_dataclass_attr (line 96) | def test_traverse_dataclass_attr(self): method test_traverse_merge (line 103) | def test_traverse_merge(self): method test_traverse_each (line 113) | def test_traverse_each(self): method test_traverse_each_dict (line 120) | def test_traverse_each_dict(self): method test_traverse_tree (line 127) | def test_traverse_tree(self): method test_traverse_filter (line 134) | def test_traverse_filter(self): method test_traversal_set (line 141) | def test_traversal_set(self): method test_flatten_dict (line 151) | def test_flatten_dict(self): method test_unflatten_dict (line 178) | def test_unflatten_dict(self): method test_flatten_dict_keep_empty (line 196) | def test_flatten_dict_keep_empty(self): method test_flatten_dict_is_leaf (line 210) | def test_flatten_dict_is_leaf(self): class ModelParamTraversalTest (line 226) | class ModelParamTraversalTest(absltest.TestCase): method test_only_works_on_model_params (line 227) | def test_only_works_on_model_params(self): method test_param_selection (line 232) | def test_param_selection(self): method test_path_value (line 276) | def test_path_value(self): method test_path_aware_map_with_multi_transform (line 284) | def test_path_aware_map_with_multi_transform(self): method test_path_aware_map_with_masked (line 314) | def test_path_aware_map_with_masked(self): method test_path_aware_map_with_empty_nodes (line 342) | def test_path_aware_map_with_empty_nodes(self):