gitextract_07u909eo/ ├── .git-blame-ignore-revs ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ └── bug_report.md │ ├── analytics/ │ │ ├── README.md │ │ ├── get_repo_metrics.py │ │ ├── issue_activity_since_date.gql │ │ ├── pr_data_query.gql │ │ └── requirements.txt │ ├── pull_request_template.md │ └── workflows/ │ ├── flax_publish.yml │ ├── flax_test.yml │ ├── flaxlib_publish.yml │ └── jax_nightly.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── AUTHORS ├── CHANGELOG.md ├── LICENSE ├── README.md ├── benchmarks/ │ ├── README.md │ ├── nnx_graph_overhead.py │ ├── nnx_mlpmixer_training.py │ ├── nnx_simple_training.py │ ├── nnx_state_traversal.py │ └── tracing/ │ ├── README.md │ ├── __init__.py │ ├── gemma.py │ ├── imagenet.py │ ├── lm1b.py │ ├── mnist.py │ ├── nlp_seq.py │ ├── ogbg_molpcba.py │ ├── ppo.py │ ├── requirements.txt │ ├── run_all_benchmarks.sh │ ├── seq2seq.py │ ├── sst2.py │ ├── tracing_benchmark.py │ ├── vae.py │ └── wmt.py ├── contributing.md ├── docs/ │ ├── .gitignore │ ├── .readthedocs.yaml │ ├── Makefile │ ├── README.md │ ├── _ext/ │ │ ├── codediff.py │ │ ├── codediff_test.py │ │ └── flax_module.py │ ├── _static/ │ │ └── css/ │ │ └── flax_theme.css │ ├── _templates/ │ │ └── autosummary/ │ │ └── flax_module.rst │ ├── api_reference/ │ │ ├── flax.core.frozen_dict.rst │ │ ├── flax.cursor.rst │ │ ├── flax.errors.rst │ │ ├── flax.jax_utils.rst │ │ ├── flax.linen/ │ │ │ ├── activation_functions.rst │ │ │ ├── decorators.rst │ │ │ ├── index.rst │ │ │ ├── init_apply.rst │ │ │ ├── initializers.rst │ │ │ ├── inspection.rst │ │ │ ├── layers.rst │ │ │ ├── module.rst │ │ │ ├── profiling.rst │ │ │ ├── spmd.rst │ │ │ ├── transformations.rst │ │ │ └── variable.rst │ │ ├── flax.serialization.rst │ │ ├── flax.struct.rst │ │ ├── flax.traceback_util.rst │ │ ├── flax.training.rst │ │ └── index.rst │ ├── conf.py │ ├── conf_sphinx_patch.py │ ├── developer_notes/ │ │ ├── index.rst │ │ ├── lift.md │ │ └── module_lifecycle.rst │ ├── examples/ │ │ ├── community_examples.rst │ │ ├── core_examples.rst │ │ ├── google_research_examples.rst │ │ ├── index.rst │ │ └── repositories_that_use_flax.rst │ ├── faq.rst │ ├── flip/ │ │ ├── 0000-template.md │ │ ├── 1009-optimizer-api.md │ │ ├── 1777-default-dtype.md │ │ ├── 2396-rnn.md │ │ ├── 2434-general-metadata.md │ │ ├── 2974-kw-only-dataclasses.md │ │ ├── 3099-rnnbase-refactor.md │ │ ├── 4105-jax-style-nnx-transforms.md │ │ └── README.md │ ├── glossary.rst │ ├── guides/ │ │ ├── converting_and_upgrading/ │ │ │ ├── convert_pytorch_to_flax.rst │ │ │ ├── haiku_migration_guide.rst │ │ │ ├── index.rst │ │ │ ├── linen_upgrade_guide.rst │ │ │ ├── optax_update_guide.rst │ │ │ ├── orbax_upgrade_guide.rst │ │ │ ├── regular_dict_upgrade_guide.rst │ │ │ └── rnncell_upgrade_guide.rst │ │ ├── data_preprocessing/ │ │ │ ├── full_eval.rst │ │ │ ├── index.rst │ │ │ ├── loading_datasets.ipynb │ │ │ └── loading_datasets.md │ │ ├── flax_fundamentals/ │ │ │ ├── arguments.md │ │ │ ├── flax_basics.ipynb │ │ │ ├── flax_basics.md │ │ │ ├── index.rst │ │ │ ├── rng_guide.ipynb │ │ │ ├── rng_guide.md │ │ │ ├── setup_or_nncompact.rst │ │ │ └── state_params.rst │ │ ├── flax_sharp_bits.ipynb │ │ ├── flax_sharp_bits.md │ │ ├── index.rst │ │ ├── model_inspection/ │ │ │ ├── extracting_intermediates.rst │ │ │ ├── index.rst │ │ │ ├── model_surgery.ipynb │ │ │ └── model_surgery.md │ │ ├── parallel_training/ │ │ │ ├── ensembling.rst │ │ │ ├── flax_on_pjit.ipynb │ │ │ ├── flax_on_pjit.md │ │ │ └── index.rst │ │ ├── quantization/ │ │ │ ├── fp8_basics.ipynb │ │ │ ├── fp8_basics.md │ │ │ └── index.rst │ │ └── training_techniques/ │ │ ├── batch_norm.rst │ │ ├── dropout.rst │ │ ├── index.rst │ │ ├── lr_schedule.rst │ │ ├── transfer_learning.ipynb │ │ ├── transfer_learning.md │ │ ├── use_checkpointing.ipynb │ │ └── use_checkpointing.md │ ├── index.rst │ ├── linen_intro.ipynb │ ├── linen_intro.md │ ├── quick_start.ipynb │ ├── quick_start.md │ └── robots.txt ├── docs_nnx/ │ ├── .gitignore │ ├── .readthedocs.yaml │ ├── Makefile │ ├── README.md │ ├── _ext/ │ │ ├── codediff.py │ │ ├── codediff_test.py │ │ └── flax_module.py │ ├── _static/ │ │ └── css/ │ │ └── flax_theme.css │ ├── _templates/ │ │ └── autosummary/ │ │ └── flax_module.rst │ ├── api_reference/ │ │ ├── flax.config.rst │ │ ├── flax.core.frozen_dict.rst │ │ ├── flax.nnx/ │ │ │ ├── bridge.rst │ │ │ ├── filterlib.rst │ │ │ ├── graph.rst │ │ │ ├── helpers.rst │ │ │ ├── index.rst │ │ │ ├── module.rst │ │ │ ├── nn/ │ │ │ │ ├── activations.rst │ │ │ │ ├── attention.rst │ │ │ │ ├── dtypes.rst │ │ │ │ ├── index.rst │ │ │ │ ├── initializers.rst │ │ │ │ ├── linear.rst │ │ │ │ ├── lora.rst │ │ │ │ ├── normalization.rst │ │ │ │ ├── recurrent.rst │ │ │ │ └── stochastic.rst │ │ │ ├── object.rst │ │ │ ├── rnglib.rst │ │ │ ├── spmd.rst │ │ │ ├── state.rst │ │ │ ├── summary.rst │ │ │ ├── training/ │ │ │ │ ├── index.rst │ │ │ │ ├── metrics.rst │ │ │ │ └── optimizer.rst │ │ │ ├── transforms.rst │ │ │ ├── variables.rst │ │ │ └── visualization.rst │ │ ├── flax.struct.rst │ │ ├── flax.training.rst │ │ ├── flax.traverse_util.rst │ │ └── index.rst │ ├── conf.py │ ├── conf_sphinx_patch.py │ ├── contributing.md │ ├── examples/ │ │ ├── core_examples.rst │ │ ├── gemma.ipynb │ │ ├── gemma.md │ │ └── index.rst │ ├── faq.rst │ ├── flip/ │ │ ├── 0000-template.md │ │ ├── 1009-optimizer-api.md │ │ ├── 1777-default-dtype.md │ │ ├── 2396-rnn.md │ │ ├── 2434-general-metadata.md │ │ ├── 2974-kw-only-dataclasses.md │ │ ├── 3099-rnnbase-refactor.md │ │ ├── 4105-jax-style-nnx-transforms.md │ │ ├── 4844-var-eager-sharding.md │ │ ├── 5310-tree-mode-nnx.md │ │ └── README.md │ ├── guides/ │ │ ├── blog.md │ │ ├── bridge_guide.ipynb │ │ ├── bridge_guide.md │ │ ├── checkpointing.ipynb │ │ ├── checkpointing.md │ │ ├── demo.ipynb │ │ ├── demo.md │ │ ├── extracting_intermediates.ipynb │ │ ├── extracting_intermediates.md │ │ ├── filters_guide.ipynb │ │ ├── filters_guide.md │ │ ├── flax_gspmd.ipynb │ │ ├── flax_gspmd.md │ │ ├── index.rst │ │ ├── jax_and_nnx_transforms.rst │ │ ├── performance.ipynb │ │ ├── performance.md │ │ ├── pytree.ipynb │ │ ├── pytree.md │ │ ├── randomness.ipynb │ │ ├── randomness.md │ │ ├── surgery.ipynb │ │ ├── surgery.md │ │ ├── tiny_nnx.ipynb │ │ ├── transforms.ipynb │ │ ├── transforms.md │ │ ├── view.ipynb │ │ └── view.md │ ├── guides_advanced.rst │ ├── guides_basic.rst │ ├── hijax/ │ │ ├── hijax.ipynb │ │ ├── hijax.md │ │ └── index.rst │ ├── index.rst │ ├── key_concepts.ipynb │ ├── key_concepts.md │ ├── migrating/ │ │ ├── convert_pytorch_to_flax.rst │ │ ├── haiku_to_flax.rst │ │ ├── index.rst │ │ ├── linen_to_nnx.rst │ │ └── nnx_010_to_nnx_011.rst │ ├── mnist_tutorial.ipynb │ ├── mnist_tutorial.md │ ├── nnx_basics.ipynb │ ├── nnx_basics.md │ ├── nnx_glossary.rst │ ├── philosophy.md │ ├── robots.txt │ └── why.rst ├── examples/ │ ├── README.md │ ├── __init__.py │ ├── cloud/ │ │ ├── README.md │ │ ├── launch_gce.py │ │ └── startup_script.sh │ ├── gemma/ │ │ ├── README.md │ │ ├── configs/ │ │ │ ├── default.py │ │ │ ├── gemma3_4b.py │ │ │ ├── small.py │ │ │ └── tiny.py │ │ ├── helpers.py │ │ ├── helpers_test.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── layers.py │ │ ├── layers_test.py │ │ ├── main.py │ │ ├── modules.py │ │ ├── modules_test.py │ │ ├── params.py │ │ ├── positional_embeddings.py │ │ ├── positional_embeddings_test.py │ │ ├── requirements.txt │ │ ├── sampler.py │ │ ├── sampler_test.py │ │ ├── sow_lib.py │ │ ├── tokenizer.py │ │ ├── train.py │ │ ├── transformer.py │ │ ├── transformer_test.py │ │ └── utils.py │ ├── imagenet/ │ │ ├── README.md │ │ ├── configs/ │ │ │ ├── default.py │ │ │ ├── fake_data_benchmark.py │ │ │ ├── tpu.py │ │ │ ├── v100_x8.py │ │ │ └── v100_x8_mixed_precision.py │ │ ├── imagenet.ipynb │ │ ├── imagenet_benchmark.py │ │ ├── imagenet_fake_data_benchmark.py │ │ ├── input_pipeline.py │ │ ├── main.py │ │ ├── models.py │ │ ├── models_test.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── train_test.py │ ├── linen_design_test/ │ │ ├── attention_simple.py │ │ ├── autoencoder.py │ │ ├── dense.py │ │ ├── linear_regression.py │ │ ├── mlp_explicit.py │ │ ├── mlp_inline.py │ │ └── mlp_lazy.py │ ├── lm1b/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── main.py │ │ ├── models.py │ │ ├── requirements.txt │ │ ├── temperature_sampler.py │ │ ├── temperature_sampler_test.py │ │ ├── tokenizer.py │ │ ├── train.py │ │ ├── train_test.py │ │ └── utils.py │ ├── mnist/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── main.py │ │ ├── mnist.ipynb │ │ ├── mnist_benchmark.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── train_test.py │ ├── nlp_seq/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── main.py │ │ ├── models.py │ │ ├── requirements.txt │ │ └── train.py │ ├── nnx_toy_examples/ │ │ ├── 01_functional_api.py │ │ ├── 02_lifted_transforms.py │ │ ├── 03_train_state.py │ │ ├── 04_data_parallel_with_jit.py │ │ ├── 05_vae.py │ │ ├── 06_scan_over_layers.py │ │ ├── 07_array_leaves.py │ │ ├── 08_save_load_checkpoints.py │ │ ├── 09_parameter_surgery.py │ │ ├── 10_fsdp_and_optimizer.py │ │ ├── hijax_basic.py │ │ ├── hijax_demo.py │ │ └── requirements.txt │ ├── ogbg_molpcba/ │ │ ├── README.md │ │ ├── configs/ │ │ │ ├── default.py │ │ │ ├── default_graph_net.py │ │ │ ├── hparam_sweep.py │ │ │ └── test.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── main.py │ │ ├── models.py │ │ ├── models_test.py │ │ ├── ogbg_molpcba.ipynb │ │ ├── ogbg_molpcba_benchmark.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── train_test.py │ ├── ppo/ │ │ ├── README.md │ │ ├── agent.py │ │ ├── configs/ │ │ │ └── default.py │ │ ├── env_utils.py │ │ ├── models.py │ │ ├── ppo_lib.py │ │ ├── ppo_lib_test.py │ │ ├── ppo_main.py │ │ ├── requirements.txt │ │ ├── seed_rl_atari_preprocessing.py │ │ └── test_episodes.py │ ├── seq2seq/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── main.py │ │ ├── models.py │ │ ├── requirements.txt │ │ ├── seq2seq.ipynb │ │ ├── train.py │ │ └── train_test.py │ ├── sst2/ │ │ ├── README.md │ │ ├── build_vocabulary.py │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── main.py │ │ ├── models.py │ │ ├── models_test.py │ │ ├── requirements.txt │ │ ├── sst2.ipynb │ │ ├── train.py │ │ ├── train_test.py │ │ ├── vocab.txt │ │ └── vocabulary.py │ ├── vae/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── main.py │ │ ├── models.py │ │ ├── requirements.txt │ │ ├── results/ │ │ │ └── .gitignore │ │ ├── train.py │ │ └── utils.py │ └── wmt/ │ ├── README.md │ ├── bleu.py │ ├── configs/ │ │ └── default.py │ ├── decode.py │ ├── input_pipeline.py │ ├── input_pipeline_test.py │ ├── main.py │ ├── models.py │ ├── requirements.txt │ ├── tokenizer.py │ ├── train.py │ └── train_test.py ├── flax/ │ ├── __init__.py │ ├── configurations.py │ ├── core/ │ │ ├── __init__.py │ │ ├── axes_scan.py │ │ ├── flax_functional_engine.ipynb │ │ ├── frozen_dict.py │ │ ├── lift.py │ │ ├── meta.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── linear.py │ │ │ ├── normalization.py │ │ │ └── stochastic.py │ │ ├── partial_eval.py │ │ ├── scope.py │ │ ├── spmd.py │ │ ├── tracers.py │ │ └── variables.py │ ├── cursor.py │ ├── errors.py │ ├── experimental/ │ │ ├── __init__.py │ │ └── nnx.py │ ├── ids.py │ ├── io.py │ ├── jax_utils.py │ ├── linen/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── attention.py │ │ ├── batch_apply.py │ │ ├── combinators.py │ │ ├── dtypes.py │ │ ├── experimental/ │ │ │ └── layers_with_named_axes.py │ │ ├── fp8_ops.py │ │ ├── initializers.py │ │ ├── kw_only_dataclasses.py │ │ ├── linear.py │ │ ├── module.py │ │ ├── normalization.py │ │ ├── partitioning.py │ │ ├── pooling.py │ │ ├── recurrent.py │ │ ├── spmd.py │ │ ├── stochastic.py │ │ ├── summary.py │ │ └── transforms.py │ ├── metrics/ │ │ ├── __init__.py │ │ └── tensorboard.py │ ├── nnx/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── bridge/ │ │ │ ├── __init__.py │ │ │ ├── interop.py │ │ │ ├── module.py │ │ │ ├── variables.py │ │ │ └── wrappers.py │ │ ├── compat.py │ │ ├── extract.py │ │ ├── filterlib.py │ │ ├── graph.py │ │ ├── graphlib.py │ │ ├── helpers.py │ │ ├── ids.py │ │ ├── module.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── attention.py │ │ │ ├── dtypes.py │ │ │ ├── initializers.py │ │ │ ├── linear.py │ │ │ ├── lora.py │ │ │ ├── normalization.py │ │ │ ├── recurrent.py │ │ │ └── stochastic.py │ │ ├── proxy_caller.py │ │ ├── pytreelib.py │ │ ├── reprlib.py │ │ ├── rnglib.py │ │ ├── scripts/ │ │ │ ├── requirements.txt │ │ │ └── run-all-examples.bash │ │ ├── spmd.py │ │ ├── statelib.py │ │ ├── summary.py │ │ ├── tracers.py │ │ ├── training/ │ │ │ ├── __init__.py │ │ │ ├── metrics.py │ │ │ └── optimizer.py │ │ ├── transforms/ │ │ │ ├── __init__.py │ │ │ ├── autodiff.py │ │ │ ├── compilation.py │ │ │ ├── general.py │ │ │ ├── iteration.py │ │ │ └── transforms.py │ │ ├── traversals.py │ │ ├── variablelib.py │ │ └── visualization.py │ ├── oss/ │ │ └── .git-blame-ignore-revs │ ├── py.typed │ ├── serialization.py │ ├── struct.py │ ├── testing/ │ │ ├── __init__.py │ │ └── benchmark.py │ ├── traceback_util.py │ ├── training/ │ │ ├── __init__.py │ │ ├── checkpoints.py │ │ ├── common_utils.py │ │ ├── dynamic_scale.py │ │ ├── early_stopping.py │ │ ├── lr_schedule.py │ │ ├── orbax_utils.py │ │ ├── prefetch_iterator.py │ │ └── train_state.py │ ├── traverse_util.py │ ├── typing.py │ └── version.py ├── flaxlib_src/ │ ├── .gitignore │ ├── CMakeLists.txt │ ├── Cargo.toml │ ├── LICENSE │ ├── README.md │ ├── pyproject.toml │ └── src/ │ ├── flaxlib/ │ │ ├── __init__.py │ │ └── flaxlib_cpp.pyi │ └── lib.cc ├── nnx.py ├── pylintrc ├── pyproject.toml └── tests/ ├── checkpoints_test.py ├── colab_tpu_jax_version.ipynb ├── configurations_test.py ├── core/ │ ├── core_frozen_dict_test.py │ ├── core_lift_test.py │ ├── core_meta_test.py │ ├── core_scope_test.py │ └── design/ │ ├── core_attention_test.py │ ├── core_auto_encoder_test.py │ ├── core_big_resnets_test.py │ ├── core_custom_vjp_test.py │ ├── core_dense_test.py │ ├── core_flow_test.py │ ├── core_resnet_test.py │ ├── core_scan_test.py │ ├── core_tied_autoencoder_test.py │ ├── core_vmap_test.py │ └── core_weight_std_test.py ├── cursor_test.py ├── download_dataset_metadata.sh ├── early_stopping_test.py ├── flaxlib_test.py ├── import_test.ipynb ├── io_test.py ├── jax_utils_test.py ├── linen/ │ ├── initializers_test.py │ ├── kw_only_dataclasses_test.py │ ├── linen_activation_test.py │ ├── linen_attention_test.py │ ├── linen_batch_apply_test.py │ ├── linen_combinators_test.py │ ├── linen_dtypes_test.py │ ├── linen_linear_test.py │ ├── linen_meta_test.py │ ├── linen_module_test.py │ ├── linen_recurrent_test.py │ ├── linen_test.py │ ├── linen_transforms_test.py │ ├── partitioning_test.py │ └── summary_test.py ├── nnx/ │ ├── __init__.py │ ├── bridge/ │ │ ├── module_test.py │ │ └── wrappers_test.py │ ├── containers_test.py │ ├── filters_test.py │ ├── graph_utils_test.py │ ├── helpers_test.py │ ├── ids_test.py │ ├── integration_test.py │ ├── metrics_test.py │ ├── module_test.py │ ├── mutable_array_test.py │ ├── nn/ │ │ ├── attention_test.py │ │ ├── conv_test.py │ │ ├── embed_test.py │ │ ├── linear_test.py │ │ ├── lora_test.py │ │ ├── normalization_test.py │ │ ├── recurrent_test.py │ │ └── stochastic_test.py │ ├── optimizer_test.py │ ├── partitioning_test.py │ ├── rngs_test.py │ ├── spmd_test.py │ ├── state_test.py │ ├── summary_test.py │ ├── test_traversals.py │ ├── transforms_test.py │ └── variable_test.py ├── pickle_test.py ├── run_all_tests.sh ├── serialization_test.py ├── struct_test.py ├── tensorboard_test.py ├── traceback_util_test.py └── traverse_util_test.py