Copy disabled (too large)
Download .txt
Showing preview only (10,449K chars total). Download the full file to get everything.
Repository: google/flax
Branch: main
Commit: 1572c84e34a8
Files: 618
Total size: 9.9 MB
Directory structure:
gitextract_07u909eo/
├── .git-blame-ignore-revs
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ └── bug_report.md
│ ├── analytics/
│ │ ├── README.md
│ │ ├── get_repo_metrics.py
│ │ ├── issue_activity_since_date.gql
│ │ ├── pr_data_query.gql
│ │ └── requirements.txt
│ ├── pull_request_template.md
│ └── workflows/
│ ├── flax_publish.yml
│ ├── flax_test.yml
│ ├── flaxlib_publish.yml
│ └── jax_nightly.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── AUTHORS
├── CHANGELOG.md
├── LICENSE
├── README.md
├── benchmarks/
│ ├── README.md
│ ├── nnx_graph_overhead.py
│ ├── nnx_mlpmixer_training.py
│ ├── nnx_simple_training.py
│ ├── nnx_state_traversal.py
│ └── tracing/
│ ├── README.md
│ ├── __init__.py
│ ├── gemma.py
│ ├── imagenet.py
│ ├── lm1b.py
│ ├── mnist.py
│ ├── nlp_seq.py
│ ├── ogbg_molpcba.py
│ ├── ppo.py
│ ├── requirements.txt
│ ├── run_all_benchmarks.sh
│ ├── seq2seq.py
│ ├── sst2.py
│ ├── tracing_benchmark.py
│ ├── vae.py
│ └── wmt.py
├── contributing.md
├── docs/
│ ├── .gitignore
│ ├── .readthedocs.yaml
│ ├── Makefile
│ ├── README.md
│ ├── _ext/
│ │ ├── codediff.py
│ │ ├── codediff_test.py
│ │ └── flax_module.py
│ ├── _static/
│ │ └── css/
│ │ └── flax_theme.css
│ ├── _templates/
│ │ └── autosummary/
│ │ └── flax_module.rst
│ ├── api_reference/
│ │ ├── flax.core.frozen_dict.rst
│ │ ├── flax.cursor.rst
│ │ ├── flax.errors.rst
│ │ ├── flax.jax_utils.rst
│ │ ├── flax.linen/
│ │ │ ├── activation_functions.rst
│ │ │ ├── decorators.rst
│ │ │ ├── index.rst
│ │ │ ├── init_apply.rst
│ │ │ ├── initializers.rst
│ │ │ ├── inspection.rst
│ │ │ ├── layers.rst
│ │ │ ├── module.rst
│ │ │ ├── profiling.rst
│ │ │ ├── spmd.rst
│ │ │ ├── transformations.rst
│ │ │ └── variable.rst
│ │ ├── flax.serialization.rst
│ │ ├── flax.struct.rst
│ │ ├── flax.traceback_util.rst
│ │ ├── flax.training.rst
│ │ └── index.rst
│ ├── conf.py
│ ├── conf_sphinx_patch.py
│ ├── developer_notes/
│ │ ├── index.rst
│ │ ├── lift.md
│ │ └── module_lifecycle.rst
│ ├── examples/
│ │ ├── community_examples.rst
│ │ ├── core_examples.rst
│ │ ├── google_research_examples.rst
│ │ ├── index.rst
│ │ └── repositories_that_use_flax.rst
│ ├── faq.rst
│ ├── flip/
│ │ ├── 0000-template.md
│ │ ├── 1009-optimizer-api.md
│ │ ├── 1777-default-dtype.md
│ │ ├── 2396-rnn.md
│ │ ├── 2434-general-metadata.md
│ │ ├── 2974-kw-only-dataclasses.md
│ │ ├── 3099-rnnbase-refactor.md
│ │ ├── 4105-jax-style-nnx-transforms.md
│ │ └── README.md
│ ├── glossary.rst
│ ├── guides/
│ │ ├── converting_and_upgrading/
│ │ │ ├── convert_pytorch_to_flax.rst
│ │ │ ├── haiku_migration_guide.rst
│ │ │ ├── index.rst
│ │ │ ├── linen_upgrade_guide.rst
│ │ │ ├── optax_update_guide.rst
│ │ │ ├── orbax_upgrade_guide.rst
│ │ │ ├── regular_dict_upgrade_guide.rst
│ │ │ └── rnncell_upgrade_guide.rst
│ │ ├── data_preprocessing/
│ │ │ ├── full_eval.rst
│ │ │ ├── index.rst
│ │ │ ├── loading_datasets.ipynb
│ │ │ └── loading_datasets.md
│ │ ├── flax_fundamentals/
│ │ │ ├── arguments.md
│ │ │ ├── flax_basics.ipynb
│ │ │ ├── flax_basics.md
│ │ │ ├── index.rst
│ │ │ ├── rng_guide.ipynb
│ │ │ ├── rng_guide.md
│ │ │ ├── setup_or_nncompact.rst
│ │ │ └── state_params.rst
│ │ ├── flax_sharp_bits.ipynb
│ │ ├── flax_sharp_bits.md
│ │ ├── index.rst
│ │ ├── model_inspection/
│ │ │ ├── extracting_intermediates.rst
│ │ │ ├── index.rst
│ │ │ ├── model_surgery.ipynb
│ │ │ └── model_surgery.md
│ │ ├── parallel_training/
│ │ │ ├── ensembling.rst
│ │ │ ├── flax_on_pjit.ipynb
│ │ │ ├── flax_on_pjit.md
│ │ │ └── index.rst
│ │ ├── quantization/
│ │ │ ├── fp8_basics.ipynb
│ │ │ ├── fp8_basics.md
│ │ │ └── index.rst
│ │ └── training_techniques/
│ │ ├── batch_norm.rst
│ │ ├── dropout.rst
│ │ ├── index.rst
│ │ ├── lr_schedule.rst
│ │ ├── transfer_learning.ipynb
│ │ ├── transfer_learning.md
│ │ ├── use_checkpointing.ipynb
│ │ └── use_checkpointing.md
│ ├── index.rst
│ ├── linen_intro.ipynb
│ ├── linen_intro.md
│ ├── quick_start.ipynb
│ ├── quick_start.md
│ └── robots.txt
├── docs_nnx/
│ ├── .gitignore
│ ├── .readthedocs.yaml
│ ├── Makefile
│ ├── README.md
│ ├── _ext/
│ │ ├── codediff.py
│ │ ├── codediff_test.py
│ │ └── flax_module.py
│ ├── _static/
│ │ └── css/
│ │ └── flax_theme.css
│ ├── _templates/
│ │ └── autosummary/
│ │ └── flax_module.rst
│ ├── api_reference/
│ │ ├── flax.config.rst
│ │ ├── flax.core.frozen_dict.rst
│ │ ├── flax.nnx/
│ │ │ ├── bridge.rst
│ │ │ ├── filterlib.rst
│ │ │ ├── graph.rst
│ │ │ ├── helpers.rst
│ │ │ ├── index.rst
│ │ │ ├── module.rst
│ │ │ ├── nn/
│ │ │ │ ├── activations.rst
│ │ │ │ ├── attention.rst
│ │ │ │ ├── dtypes.rst
│ │ │ │ ├── index.rst
│ │ │ │ ├── initializers.rst
│ │ │ │ ├── linear.rst
│ │ │ │ ├── lora.rst
│ │ │ │ ├── normalization.rst
│ │ │ │ ├── recurrent.rst
│ │ │ │ └── stochastic.rst
│ │ │ ├── object.rst
│ │ │ ├── rnglib.rst
│ │ │ ├── spmd.rst
│ │ │ ├── state.rst
│ │ │ ├── summary.rst
│ │ │ ├── training/
│ │ │ │ ├── index.rst
│ │ │ │ ├── metrics.rst
│ │ │ │ └── optimizer.rst
│ │ │ ├── transforms.rst
│ │ │ ├── variables.rst
│ │ │ └── visualization.rst
│ │ ├── flax.struct.rst
│ │ ├── flax.training.rst
│ │ ├── flax.traverse_util.rst
│ │ └── index.rst
│ ├── conf.py
│ ├── conf_sphinx_patch.py
│ ├── contributing.md
│ ├── examples/
│ │ ├── core_examples.rst
│ │ ├── gemma.ipynb
│ │ ├── gemma.md
│ │ └── index.rst
│ ├── faq.rst
│ ├── flip/
│ │ ├── 0000-template.md
│ │ ├── 1009-optimizer-api.md
│ │ ├── 1777-default-dtype.md
│ │ ├── 2396-rnn.md
│ │ ├── 2434-general-metadata.md
│ │ ├── 2974-kw-only-dataclasses.md
│ │ ├── 3099-rnnbase-refactor.md
│ │ ├── 4105-jax-style-nnx-transforms.md
│ │ ├── 4844-var-eager-sharding.md
│ │ ├── 5310-tree-mode-nnx.md
│ │ └── README.md
│ ├── guides/
│ │ ├── blog.md
│ │ ├── bridge_guide.ipynb
│ │ ├── bridge_guide.md
│ │ ├── checkpointing.ipynb
│ │ ├── checkpointing.md
│ │ ├── demo.ipynb
│ │ ├── demo.md
│ │ ├── extracting_intermediates.ipynb
│ │ ├── extracting_intermediates.md
│ │ ├── filters_guide.ipynb
│ │ ├── filters_guide.md
│ │ ├── flax_gspmd.ipynb
│ │ ├── flax_gspmd.md
│ │ ├── index.rst
│ │ ├── jax_and_nnx_transforms.rst
│ │ ├── performance.ipynb
│ │ ├── performance.md
│ │ ├── pytree.ipynb
│ │ ├── pytree.md
│ │ ├── randomness.ipynb
│ │ ├── randomness.md
│ │ ├── surgery.ipynb
│ │ ├── surgery.md
│ │ ├── tiny_nnx.ipynb
│ │ ├── transforms.ipynb
│ │ ├── transforms.md
│ │ ├── view.ipynb
│ │ └── view.md
│ ├── guides_advanced.rst
│ ├── guides_basic.rst
│ ├── hijax/
│ │ ├── hijax.ipynb
│ │ ├── hijax.md
│ │ └── index.rst
│ ├── index.rst
│ ├── key_concepts.ipynb
│ ├── key_concepts.md
│ ├── migrating/
│ │ ├── convert_pytorch_to_flax.rst
│ │ ├── haiku_to_flax.rst
│ │ ├── index.rst
│ │ ├── linen_to_nnx.rst
│ │ └── nnx_010_to_nnx_011.rst
│ ├── mnist_tutorial.ipynb
│ ├── mnist_tutorial.md
│ ├── nnx_basics.ipynb
│ ├── nnx_basics.md
│ ├── nnx_glossary.rst
│ ├── philosophy.md
│ ├── robots.txt
│ └── why.rst
├── examples/
│ ├── README.md
│ ├── __init__.py
│ ├── cloud/
│ │ ├── README.md
│ │ ├── launch_gce.py
│ │ └── startup_script.sh
│ ├── gemma/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── default.py
│ │ │ ├── gemma3_4b.py
│ │ │ ├── small.py
│ │ │ └── tiny.py
│ │ ├── helpers.py
│ │ ├── helpers_test.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── layers.py
│ │ ├── layers_test.py
│ │ ├── main.py
│ │ ├── modules.py
│ │ ├── modules_test.py
│ │ ├── params.py
│ │ ├── positional_embeddings.py
│ │ ├── positional_embeddings_test.py
│ │ ├── requirements.txt
│ │ ├── sampler.py
│ │ ├── sampler_test.py
│ │ ├── sow_lib.py
│ │ ├── tokenizer.py
│ │ ├── train.py
│ │ ├── transformer.py
│ │ ├── transformer_test.py
│ │ └── utils.py
│ ├── imagenet/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── default.py
│ │ │ ├── fake_data_benchmark.py
│ │ │ ├── tpu.py
│ │ │ ├── v100_x8.py
│ │ │ └── v100_x8_mixed_precision.py
│ │ ├── imagenet.ipynb
│ │ ├── imagenet_benchmark.py
│ │ ├── imagenet_fake_data_benchmark.py
│ │ ├── input_pipeline.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── models_test.py
│ │ ├── requirements.txt
│ │ ├── train.py
│ │ └── train_test.py
│ ├── linen_design_test/
│ │ ├── attention_simple.py
│ │ ├── autoencoder.py
│ │ ├── dense.py
│ │ ├── linear_regression.py
│ │ ├── mlp_explicit.py
│ │ ├── mlp_inline.py
│ │ └── mlp_lazy.py
│ ├── lm1b/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── requirements.txt
│ │ ├── temperature_sampler.py
│ │ ├── temperature_sampler_test.py
│ │ ├── tokenizer.py
│ │ ├── train.py
│ │ ├── train_test.py
│ │ └── utils.py
│ ├── mnist/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── main.py
│ │ ├── mnist.ipynb
│ │ ├── mnist_benchmark.py
│ │ ├── requirements.txt
│ │ ├── train.py
│ │ └── train_test.py
│ ├── nlp_seq/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── requirements.txt
│ │ └── train.py
│ ├── nnx_toy_examples/
│ │ ├── 01_functional_api.py
│ │ ├── 02_lifted_transforms.py
│ │ ├── 03_train_state.py
│ │ ├── 04_data_parallel_with_jit.py
│ │ ├── 05_vae.py
│ │ ├── 06_scan_over_layers.py
│ │ ├── 07_array_leaves.py
│ │ ├── 08_save_load_checkpoints.py
│ │ ├── 09_parameter_surgery.py
│ │ ├── 10_fsdp_and_optimizer.py
│ │ ├── hijax_basic.py
│ │ ├── hijax_demo.py
│ │ └── requirements.txt
│ ├── ogbg_molpcba/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── default.py
│ │ │ ├── default_graph_net.py
│ │ │ ├── hparam_sweep.py
│ │ │ └── test.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── models_test.py
│ │ ├── ogbg_molpcba.ipynb
│ │ ├── ogbg_molpcba_benchmark.py
│ │ ├── requirements.txt
│ │ ├── train.py
│ │ └── train_test.py
│ ├── ppo/
│ │ ├── README.md
│ │ ├── agent.py
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── env_utils.py
│ │ ├── models.py
│ │ ├── ppo_lib.py
│ │ ├── ppo_lib_test.py
│ │ ├── ppo_main.py
│ │ ├── requirements.txt
│ │ ├── seed_rl_atari_preprocessing.py
│ │ └── test_episodes.py
│ ├── seq2seq/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── requirements.txt
│ │ ├── seq2seq.ipynb
│ │ ├── train.py
│ │ └── train_test.py
│ ├── sst2/
│ │ ├── README.md
│ │ ├── build_vocabulary.py
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── models_test.py
│ │ ├── requirements.txt
│ │ ├── sst2.ipynb
│ │ ├── train.py
│ │ ├── train_test.py
│ │ ├── vocab.txt
│ │ └── vocabulary.py
│ ├── vae/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── requirements.txt
│ │ ├── results/
│ │ │ └── .gitignore
│ │ ├── train.py
│ │ └── utils.py
│ └── wmt/
│ ├── README.md
│ ├── bleu.py
│ ├── configs/
│ │ └── default.py
│ ├── decode.py
│ ├── input_pipeline.py
│ ├── input_pipeline_test.py
│ ├── main.py
│ ├── models.py
│ ├── requirements.txt
│ ├── tokenizer.py
│ ├── train.py
│ └── train_test.py
├── flax/
│ ├── __init__.py
│ ├── configurations.py
│ ├── core/
│ │ ├── __init__.py
│ │ ├── axes_scan.py
│ │ ├── flax_functional_engine.ipynb
│ │ ├── frozen_dict.py
│ │ ├── lift.py
│ │ ├── meta.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── linear.py
│ │ │ ├── normalization.py
│ │ │ └── stochastic.py
│ │ ├── partial_eval.py
│ │ ├── scope.py
│ │ ├── spmd.py
│ │ ├── tracers.py
│ │ └── variables.py
│ ├── cursor.py
│ ├── errors.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ └── nnx.py
│ ├── ids.py
│ ├── io.py
│ ├── jax_utils.py
│ ├── linen/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── activation.py
│ │ ├── attention.py
│ │ ├── batch_apply.py
│ │ ├── combinators.py
│ │ ├── dtypes.py
│ │ ├── experimental/
│ │ │ └── layers_with_named_axes.py
│ │ ├── fp8_ops.py
│ │ ├── initializers.py
│ │ ├── kw_only_dataclasses.py
│ │ ├── linear.py
│ │ ├── module.py
│ │ ├── normalization.py
│ │ ├── partitioning.py
│ │ ├── pooling.py
│ │ ├── recurrent.py
│ │ ├── spmd.py
│ │ ├── stochastic.py
│ │ ├── summary.py
│ │ └── transforms.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ └── tensorboard.py
│ ├── nnx/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── bridge/
│ │ │ ├── __init__.py
│ │ │ ├── interop.py
│ │ │ ├── module.py
│ │ │ ├── variables.py
│ │ │ └── wrappers.py
│ │ ├── compat.py
│ │ ├── extract.py
│ │ ├── filterlib.py
│ │ ├── graph.py
│ │ ├── graphlib.py
│ │ ├── helpers.py
│ │ ├── ids.py
│ │ ├── module.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ ├── activations.py
│ │ │ ├── attention.py
│ │ │ ├── dtypes.py
│ │ │ ├── initializers.py
│ │ │ ├── linear.py
│ │ │ ├── lora.py
│ │ │ ├── normalization.py
│ │ │ ├── recurrent.py
│ │ │ └── stochastic.py
│ │ ├── proxy_caller.py
│ │ ├── pytreelib.py
│ │ ├── reprlib.py
│ │ ├── rnglib.py
│ │ ├── scripts/
│ │ │ ├── requirements.txt
│ │ │ └── run-all-examples.bash
│ │ ├── spmd.py
│ │ ├── statelib.py
│ │ ├── summary.py
│ │ ├── tracers.py
│ │ ├── training/
│ │ │ ├── __init__.py
│ │ │ ├── metrics.py
│ │ │ └── optimizer.py
│ │ ├── transforms/
│ │ │ ├── __init__.py
│ │ │ ├── autodiff.py
│ │ │ ├── compilation.py
│ │ │ ├── general.py
│ │ │ ├── iteration.py
│ │ │ └── transforms.py
│ │ ├── traversals.py
│ │ ├── variablelib.py
│ │ └── visualization.py
│ ├── oss/
│ │ └── .git-blame-ignore-revs
│ ├── py.typed
│ ├── serialization.py
│ ├── struct.py
│ ├── testing/
│ │ ├── __init__.py
│ │ └── benchmark.py
│ ├── traceback_util.py
│ ├── training/
│ │ ├── __init__.py
│ │ ├── checkpoints.py
│ │ ├── common_utils.py
│ │ ├── dynamic_scale.py
│ │ ├── early_stopping.py
│ │ ├── lr_schedule.py
│ │ ├── orbax_utils.py
│ │ ├── prefetch_iterator.py
│ │ └── train_state.py
│ ├── traverse_util.py
│ ├── typing.py
│ └── version.py
├── flaxlib_src/
│ ├── .gitignore
│ ├── CMakeLists.txt
│ ├── Cargo.toml
│ ├── LICENSE
│ ├── README.md
│ ├── pyproject.toml
│ └── src/
│ ├── flaxlib/
│ │ ├── __init__.py
│ │ └── flaxlib_cpp.pyi
│ └── lib.cc
├── nnx.py
├── pylintrc
├── pyproject.toml
└── tests/
├── checkpoints_test.py
├── colab_tpu_jax_version.ipynb
├── configurations_test.py
├── core/
│ ├── core_frozen_dict_test.py
│ ├── core_lift_test.py
│ ├── core_meta_test.py
│ ├── core_scope_test.py
│ └── design/
│ ├── core_attention_test.py
│ ├── core_auto_encoder_test.py
│ ├── core_big_resnets_test.py
│ ├── core_custom_vjp_test.py
│ ├── core_dense_test.py
│ ├── core_flow_test.py
│ ├── core_resnet_test.py
│ ├── core_scan_test.py
│ ├── core_tied_autoencoder_test.py
│ ├── core_vmap_test.py
│ └── core_weight_std_test.py
├── cursor_test.py
├── download_dataset_metadata.sh
├── early_stopping_test.py
├── flaxlib_test.py
├── import_test.ipynb
├── io_test.py
├── jax_utils_test.py
├── linen/
│ ├── initializers_test.py
│ ├── kw_only_dataclasses_test.py
│ ├── linen_activation_test.py
│ ├── linen_attention_test.py
│ ├── linen_batch_apply_test.py
│ ├── linen_combinators_test.py
│ ├── linen_dtypes_test.py
│ ├── linen_linear_test.py
│ ├── linen_meta_test.py
│ ├── linen_module_test.py
│ ├── linen_recurrent_test.py
│ ├── linen_test.py
│ ├── linen_transforms_test.py
│ ├── partitioning_test.py
│ └── summary_test.py
├── nnx/
│ ├── __init__.py
│ ├── bridge/
│ │ ├── module_test.py
│ │ └── wrappers_test.py
│ ├── containers_test.py
│ ├── filters_test.py
│ ├── graph_utils_test.py
│ ├── helpers_test.py
│ ├── ids_test.py
│ ├── integration_test.py
│ ├── metrics_test.py
│ ├── module_test.py
│ ├── mutable_array_test.py
│ ├── nn/
│ │ ├── attention_test.py
│ │ ├── conv_test.py
│ │ ├── embed_test.py
│ │ ├── linear_test.py
│ │ ├── lora_test.py
│ │ ├── normalization_test.py
│ │ ├── recurrent_test.py
│ │ └── stochastic_test.py
│ ├── optimizer_test.py
│ ├── partitioning_test.py
│ ├── rngs_test.py
│ ├── spmd_test.py
│ ├── state_test.py
│ ├── summary_test.py
│ ├── test_traversals.py
│ ├── transforms_test.py
│ └── variable_test.py
├── pickle_test.py
├── run_all_tests.sh
├── serialization_test.py
├── struct_test.py
├── tensorboard_test.py
├── traceback_util_test.py
└── traverse_util_test.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .git-blame-ignore-revs
================================================
# apply pyink
40a6e074e5224d733f964be00e21e0a1cb98bd2e
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''
---
Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.
### System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
- Python version:
- GPU/TPU model and memory:
- CUDA version (if applicable):
### Problem you have encountered:
### What you expected to happen:
### Logs, error messages, etc:
### Steps to reproduce:
Whenever possible, please provide a *minimal example*. Please consider submitting it as a Colab link.
================================================
FILE: .github/analytics/README.md
================================================
# Repo Analytics
To run the repo analytics follow the steps below:
1. You must have a GitHub token. If you don't have one, you can create one by following [this guide](https://docs.github.com/en/enterprise-server@3.4/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token).
2. Install the requirements:
```bash
pip install -r .github/analytics/requirements.txt
```
3. Run the analytics:
```bash
GITHUB_TOKEN=<token> \
python .github/analytics/get_repo_metrics.py \
--repo_owner google \
--repo_name flax
```
================================================
FILE: .github/analytics/get_repo_metrics.py
================================================
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from datetime import datetime
from collections.abc import Callable
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd
import requests
from absl import app, flags
# Connection settings for the GitHub GraphQL (API v4) endpoint. GITHUB_TOKEN
# is read when this module is imported, so it must already be set in the
# environment (a missing variable raises KeyError at import time).
token = os.environ['GITHUB_TOKEN']
endpoint = 'https://api.github.com/graphql'
headers = dict(Authorization=f'bearer {token}')
# ------------------------------------------------------------------------------
# GraphQL
# ------------------------------------------------------------------------------
# NOTE: This GraphQL logic was ported and adapted from this script:
# https://github.com/scientific-python/devstats-data/blob/4c022961abc4ca6061f8719d9c3387e98734b90c/query.py
# It contains style differences from Google's style guide.
def load_query_from_file(fname, repo_owner, repo_name) -> str:
  """Reads a GraphQL query template and points it at a repository.

  Args:
    fname: path of the `.gql` template file.
    repo_owner: text substituted for every `_REPO_OWNER_` placeholder.
    repo_name: text substituted for every `_REPO_NAME_` placeholder.

  Returns:
    The template text with both placeholders filled in.
  """
  with open(fname) as query_file:
    template = query_file.read()
  # Owner is substituted before name, matching the placeholder order.
  substitutions = (('_REPO_OWNER_', repo_owner), ('_REPO_NAME_', repo_name))
  for placeholder, value in substitutions:
    template = template.replace(placeholder, value)
  return template
def send_query(query, query_type, cursor=None):
"""
Sends a GraphQL to the GitHub API.
No validation is done on the query before sending. GitHub GraphQL is
supported with the `cursor` argument.
Parameters
----------
query : str
The GraphQL query to be sent
query_type : {"issues", "pullRequests"}
The object being queried according to the GitHub GraphQL schema.
Currently only issues and pullRequests are supported
cursor : str, optional
If given, then the cursor is injected into the query to support
GitHub's GraphQL pagination.
Returns
-------
dict
The result of the query (json) parsed by `json.loads`
Notes
-----
This is intended mostly for internal use within `get_all_responses`.
"""
# TODO: Expand this, either by parsing the query type from the query
# directly or manually adding more query_types to the set
if query_type not in {'issues', 'pullRequests'}:
raise ValueError(
"Only 'issues' and 'pullRequests' queries are currently supported"
)
# TODO: Generalize this
# WARNING: The cursor injection depends on the specific structure of the
# query, this is the main reason why query types are limited to issues/PRs
if cursor is not None:
cursor_insertion_key = query_type + '('
cursor_ind = query.find(cursor_insertion_key) + len(cursor_insertion_key)
query = query[:cursor_ind] + f'after:"{cursor}", ' + query[cursor_ind:]
# Build request payload
payload = {'query': query}
response = requests.post(endpoint, json=payload, headers=headers)
return json.loads(response.content)
def get_all_responses(query, query_type):
  """
  Fetches every page of an issue/PR query, working around GitHub's per-request
  node limit.

  The query is re-sent with the cursor of the last edge received until
  `totalCount` edges have been collected, or until GitHub returns an empty
  page (e.g. items were deleted after `totalCount` was read), which would
  otherwise never reach the target count.

  Returns
  -------
  list
      The `edges` of every page, concatenated in the order received.
  """
  first_page = send_query(query, query_type)
  edges, last_cursor, total_count = parse_single_query(first_page, query_type)
  print(f'Retrieving {len(edges)} out of {total_count} values...')
  # Continue requesting data (with pagination) until all are acquired
  while len(edges) < total_count:
    page = send_query(query, query_type, cursor=last_cursor)
    page_edges, last_cursor, _ = parse_single_query(page, query_type)
    if not page_edges:
      # Nothing more to fetch; stop instead of re-requesting forever.
      break
    edges.extend(page_edges)
    print(f'Retrieving {len(edges)} out of {total_count} values...')
  print('Done.')
  return edges


def parse_single_query(data, query_type):
  """
  Parses the data returned by `send_query`.

  Returns
  -------
  tuple
      `(edges, last_cursor, total_count)`, where `last_cursor` is the cursor
      of the final edge, or None when the page contains no edges.

  Raises
  ------
  KeyError, TypeError
      If the response does not have the expected
      `data.repository.<query_type>` structure (for example a GraphQL error
      response whose `data` is null). The full response is printed first.

  .. warning::
     Like `send_query`, the logic here depends on the specific structure
     of the query (e.g. it must be an issue or PR query, and must have a
     total count).
  """
  try:
    connection = data['data']['repository'][query_type]
    total_count = connection['totalCount']
    edges = connection['edges']
    last_cursor = edges[-1]['cursor'] if edges else None
  except (KeyError, TypeError):
    print(data)
    raise
  return edges, last_cursor, total_count
class GithubGrabber:
  """
  Downloads every node matched by a GraphQL query from the GitHub API v4.

  Attributes
  ----------
  query_fname, query_type, repo_owner, repo_name
      The constructor arguments, stored as given.
  query : str
      The query template with the repository placeholders filled in.
  raw_data : list or None
      The edges collected by `get()`; None until `get()` has been called.
  """

  def __init__(self, query_fname, query_type, repo_owner, repo_name):
    """
    Prepares a query against https://github.com/<repo_owner>/<repo_name>.

    Parameters
    ----------
    query_fname : str
        Path to a GraphQL query template conforming to the GitHub GraphQL
        schema.
    query_type : {"issues", "pullRequests"}
        Type of object being queried according to the GitHub GraphQL schema.
        Currently only "issues" and "pullRequests" are supported.
    repo_owner : str
        Repository owner.
    repo_name : str
        Repository name.
    """
    self.query_fname = query_fname
    # TODO: Parse query_type directly from the query instead.
    self.query_type = query_type
    self.repo_owner, self.repo_name = repo_owner, repo_name
    self.raw_data = None
    self.load_query()

  def load_query(self):
    """(Re)builds `self.query` from the template file and repository."""
    self.query = load_query_from_file(
        self.query_fname, self.repo_owner, self.repo_name
    )

  def get(self):
    """Fetches all result pages and stores their edges in `self.raw_data`."""
    self.raw_data = get_all_responses(self.query, self.query_type)
# ------------------------------------------------------------------------------
# metrics helpers
# ------------------------------------------------------------------------------
def _to_datetime(date_str: str) -> datetime:
return datetime.fromisoformat(date_str.replace('Z', ''))
def _get_issues_features(issues):
for issue in issues:
issue = issue['node']
created_at = _to_datetime(issue['createdAt'])
time_labeled_or_converted = None
time_issue_closed = None
for event in issue['timelineItems']['edges']:
event = event['node']
if event['__typename'] in {'LabeledEvent', 'ConvertedToDiscussionEvent'}:
time_labeled_or_converted = _to_datetime(event['createdAt'])
if event['__typename'] == 'ClosedEvent':
time_issue_closed = _to_datetime(event['createdAt'])
yield {
'created_at': created_at,
'time_labeled_or_converted': time_labeled_or_converted,
'time_issue_closed': time_issue_closed,
'issue_closed': issue['state'] == 'CLOSED',
}
def _get_pr_features(prs):
for pr in prs:
pr = pr['node']
created_at = _to_datetime(pr['createdAt'])
ready_for_review_at = _to_datetime(pr['createdAt'])
time_labeled_or_assigned = None
time_merged_or_closed = None
time_review = None
if pr['reviews']['nodes']:
review = pr['reviews']['nodes'][0]
time_review = _to_datetime(review['createdAt'])
for event in pr['timelineItems']['edges']:
event = event['node']
if (
time_labeled_or_assigned is None
and event['__typename'] == 'LabeledEvent'
and 'cla:' not in event['label']['name']
):
time_labeled_or_assigned = _to_datetime(event['createdAt'])
if (
time_labeled_or_assigned is None
and event['__typename'] == 'AssignedEvent'
):
time_labeled_or_assigned = _to_datetime(event['createdAt'])
if event['__typename'] in {'ClosedEvent', 'MergedEvent'}:
time_merged_or_closed = _to_datetime(event['createdAt'])
if event['__typename'] == 'ReadyForReviewEvent':
ready_for_review_at = _to_datetime(event['createdAt'])
yield {
'created_at': created_at,
'ready_for_review_at': ready_for_review_at,
'time_labeled_or_assigned': time_labeled_or_assigned,
'time_merged_or_closed': time_merged_or_closed,
'time_review': time_review,
'pr_closed': pr['state'] != 'OPEN',
}
def _start_of_month(date: datetime) -> datetime:
return date.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
def _shift_n_months(date: datetime, n: int) -> datetime:
month = ((date.month + n - 1) % 12) + 1
# shift to next year if necessary
if date.month > month:
date = date.replace(year=date.year + 1)
date = date.replace(month=month)
return date
def _rolling_window(
df: pd.DataFrame,
f: Callable[[pd.DataFrame], pd.Series],
window_size: int = 6,
step: int = 1,
) -> pd.DataFrame:
# start of month of the first issue
start: datetime = df.iloc[0]['created_at'].replace(
day=1, hour=0, minute=0, second=0, microsecond=0
)
end = _shift_n_months(start, window_size)
last_month = _start_of_month(df.iloc[-1]['created_at'])
last_month = _shift_n_months(last_month, 1)
rows: list[pd.Series] = []
while end < last_month:
row = f(df[(df['created_at'] >= start) & (df['created_at'] < end)])
row['period_start'] = start
row['period_end'] = end
rows.append(row)
start = _shift_n_months(start, step)
end = _shift_n_months(end, step)
df = pd.DataFrame(rows)
df = df[['period_start', 'period_end'] + list(df.columns[:-2])]
return df
def _process_prs(df: pd.DataFrame) -> pd.Series:
return pd.Series(
{
'pr_response_time': df['pr_response_time'].dt.days.mean(),
'pr_resolution_time': df['pr_resolution_time'].dt.days.mean(),
}
)
def _process_issues(df: pd.DataFrame) -> pd.Series:
return pd.Series(
{
'issue_response_time': df['issue_response_time'].dt.days.mean(),
'issue_resolution_time': df['issue_resolution_time'].dt.days.mean(),
}
)
# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
# Command-line flags selecting the repository to analyse
# (pass as --repo_owner / --repo_name).
FLAGS = flags.FLAGS
flags.DEFINE_string(
    name='repo_owner', default='google', help='User name or organization'
)
flags.DEFINE_string(
    name='repo_name', default='flax', help='Name of the repository'
)
def main(_):
repo_owner: str = FLAGS.repo_owner
repo_name: str = FLAGS.repo_name
# Download issue data
issues = GithubGrabber(
'.github/analytics/issue_activity_since_date.gql',
'issues',
repo_owner=repo_owner,
repo_name=repo_name,
)
issues.get()
df_issues = df_issues0 = pd.DataFrame(
list(_get_issues_features(issues.raw_data))
)
df_issues['issue_response_time'] = (
df_issues['time_labeled_or_converted'] - df_issues['created_at']
)
df_issues['issue_resolution_time'] = (
df_issues['time_issue_closed'] - df_issues['created_at']
)
df_issues = _rolling_window(df_issues, _process_issues)
prs = GithubGrabber(
'.github/analytics/pr_data_query.gql',
'pullRequests',
repo_owner=repo_owner,
repo_name=repo_name,
)
prs.get()
df_prs = df_prs0 = pd.DataFrame(list(_get_pr_features(prs.raw_data)))
time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(
axis=1
)
df_prs['pr_response_time'] = time_response - df_prs['ready_for_review_at']
df_prs['pr_resolution_time'] = (
df_prs['time_merged_or_closed'] - df_prs['ready_for_review_at']
)
df_prs = _rolling_window(df_prs, _process_prs)
# get cummulative issues
df_issues0 = df_issues0.copy()
df_issues0['number_of_issues'] = 1
df_issues0['number_of_issues'] = df_issues0['number_of_issues'].cumsum()
# get cummulative prs
df_prs0 = df_prs0.copy()
df_prs0['number_of_prs'] = 1
df_prs0['number_of_prs'] = df_prs0['number_of_prs'].cumsum()
# plot cumulative issues
plt.figure()
plt.plot(df_issues0['created_at'], df_issues0['number_of_issues'])
plt.xlabel('Date')
plt.ylabel('Number of issues')
plt.title('Number of issues')
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
# plot cumulative prs
plt.figure()
plt.plot(df_prs0['created_at'], df_prs0['number_of_prs'])
plt.xlabel('Date')
plt.ylabel('Number of PRs')
plt.title('Number of PRs')
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
# plot for isssue_response_time
plt.figure()
plt.plot(df_issues['period_end'], df_issues['issue_response_time'])
plt.xlabel('Date')
plt.ylabel('Issue Response Time (days)')
plt.title('Issue Response Time')
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
plt.ylim(0)
# plot for issue_resolution_time
plt.figure()
plt.plot(df_issues['period_end'], df_issues['issue_resolution_time'])
plt.xlabel('Date')
plt.ylabel('Issue Resolution Time (days)')
plt.title('Issue Resolution Time')
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
plt.ylim(0)
# plot for pr_response_time
plt.figure()
plt.plot(df_prs['period_end'], df_prs['pr_response_time'])
plt.xlabel('Date')
plt.ylabel('Pull Request Response Time (days)')
plt.title('Pull Request Response Time')
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
plt.ylim(0)
# plot for pr_resolution_time
plt.figure()
plt.plot(df_prs['period_end'], df_prs['pr_resolution_time'])
plt.xlabel('Date')
plt.ylabel('Pull Request Resolution Time (days)')
plt.title('Pull Request Resolution Time')
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
plt.ylim(0)
# show plots
plt.show()
if __name__ == '__main__':
  # absl parses --repo_owner / --repo_name before invoking main().
  app.run(main=main)
================================================
FILE: .github/analytics/issue_activity_since_date.gql
================================================
{
# Queries all the issues in a repo. For each issue, we get some basic data such as
# the number, state, labels, and title. The most important part is 'timelineItems',
# the events that happened to the issue: the timestamps of certain key events are
# used to define the metrics. Note that this fetches more information than is
# probably needed, but that's fine for now.
repository(owner: "_REPO_OWNER_", name: "_REPO_NAME_") {
issues(first: 100) {
totalCount
edges {
cursor
node {
number
title
createdAt
state
closedAt
updatedAt
url
labels(first: 100) {
edges {
node {
name
}
}
}
timelineItems(first: 100, itemTypes: [LABELED_EVENT, CONVERTED_TO_DISCUSSION_EVENT, ISSUE_COMMENT, CLOSED_EVENT]) {
totalCount
edges {
node {
__typename
... on ConvertedToDiscussionEvent {
createdAt
}
... on IssueComment {
author {
login
}
createdAt
}
... on ClosedEvent {
actor {
login
}
createdAt
}
... on LabeledEvent {
label {
name
}
createdAt
}
}
}
}
}
}
}
}
}
================================================
FILE: .github/analytics/pr_data_query.gql
================================================
query {
# Queries all the pull requests in a repo. For each pull request, we get some basic
# data such as the number, state, reviews, and title. The most important part is
# 'timelineItems': the events that happened to the pull request, whose timestamps for
# certain key events we use to define some metrics. We also use the 'reviews' as
# indicators for certain metrics. Note that we are fetching more information than is
# probably needed, but that's fine for now.
repository(owner:"_REPO_OWNER_", name:"_REPO_NAME_") {
pullRequests(first:100) {
totalCount
edges {
cursor
node {
number
state
title
createdAt
author{
login
}
mergedAt
reviews(first: 100){
nodes {
createdAt
}
}
timelineItems(first: 100, itemTypes: [LABELED_EVENT, ASSIGNED_EVENT, MERGED_EVENT, READY_FOR_REVIEW_EVENT, CLOSED_EVENT]) {
edges {
node {
__typename
... on ClosedEvent {
actor {
login
}
createdAt
}
... on LabeledEvent {
label {
name
}
actor {
login
}
createdAt
}
... on MergedEvent {
actor {
login
}
createdAt
}
... on ReadyForReviewEvent {
actor {
login
}
createdAt
}
... on AssignedEvent {
actor {
login
}
createdAt
}
}
}
}
}
}
}
}
}
================================================
FILE: .github/analytics/requirements.txt
================================================
pandas
absl-py
requests
matplotlib
================================================
FILE: .github/pull_request_template.md
================================================
# What does this PR do?
<!--
Great, you are contributing to Flax!
But... please read the following carefully so we can make sure your PR is merged
easily.
Replace this text block with a description of the change and which issue it
fixes (if applicable). Please also include relevant motivation/context.
Once you're done, someone in the Flax team will review your PR shortly. They may
suggest changes to make the code even better. If no one reviewed your PR after a
week has passed, don't hesitate to post a new comment @-mentioning the same
people (sometimes notifications get lost).
-->
Fixes # (issue)
## Checklist
- [ ] This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
- [ ] This change is discussed in a GitHub issue/[discussion](https://github.com/google/flax/discussions) (please add a link).
- [ ] The documentation and docstrings adhere to the [documentation guidelines](https://github.com/google/flax/blob/main/docs/README.md#how-to-write-code-documentation).
- [ ] This change includes necessary high-coverage tests. (No quality testing = no merge!)
================================================
FILE: .github/workflows/flax_publish.yml
================================================
# This workflow builds the Python package and uploads it to PyPI when a release is published
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Flax - Build and upload to PyPI
on:
release:
types: [published]
jobs:
build_package:
name: Build package
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Build package
run: pipx run build
- name: List files
run: ls -l dist/
- uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
with:
name: distribution
path: |
dist/*.tar.gz
dist/*.whl
upload_pypi:
name: Release & Upload to PyPI
# Only publish release to PyPI when a github release is created.
if: github.event_name == 'release' && github.event.action == 'published'
needs: build_package
runs-on: ubuntu-latest
environment: release
permissions:
id-token: write
steps:
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
name: distribution
path: dist
- name: List files
run: ls -l dist/
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
discord_release:
if: github.repository_owner == 'google'
runs-on: ubuntu-latest
steps:
# Expose the releases page URL as a step output (read below as steps.get-release-url.outputs.URL).
# Written to $GITHUB_OUTPUT because the `::set-output` workflow command is deprecated.
- name: Get release URL
id: get-release-url
run: |
URL="https://github.com/google/flax/releases"
echo "URL=$URL" >> "$GITHUB_OUTPUT"
- name: Get content
uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1
id: get-content
with:
stringToTruncate: |
Flax [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!
${{ github.event.release.body }}
maxLength: 2000
truncationSymbol: "..."
- name: Discord Webhook Action
uses: tsickert/discord-webhook@c840d45a03a323fbc3f7507ac7769dbd91bfb164 # v5.3.0
with:
webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
content: ${{ steps.get-content.outputs.string }}
================================================
FILE: .github/workflows/flax_test.yml
================================================
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Flax - Test
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
on:
push:
branches:
- main
- 'test_*'
pull_request:
branches:
- main
jobs:
pre-commit:
name: Test pre-commit hooks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: '3.11'
- run: python -m pip install pre-commit
- uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'pyproject.toml') }}
- run: pre-commit run --show-diff-on-failure --color=always --all-files
commit-count:
name: Check commit count
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# We allow at most 5 commits in a branch to ensure our CI doesn't break.
- name: Check commit count in PR
if: always()
shell: bash
run: |
set -x
# $GITHUB_REF is in format `refs/heads/<branch_name>`. We fetch it under
# the name `commit-count` so we can refer to it below.
# Do an unshallow fetch so we retrieve all commits (this is necessary
# because actions/checkout fetches a shallow copy by default).
git fetch origin --unshallow $GITHUB_REF:commit-count
git fetch origin main
diff=$(git rev-list --count origin/main...commit-count)
# $GITHUB_REF adds an additional commit to the commit tree, so $diff is
# one too high when executing this as a Github Action.
if (( $diff > 6)); then
echo "ERROR! More than 5 commits in PR -- please squash your commits."
url=https://flax.readthedocs.io/en/latest/contributing.html#too-many-commits-in-a-pull-request
echo "See $url for help on how to resolve this."
exit 1
fi
test-import:
name: Test import standalone
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.11', '3.12', '3.13', '3.14']
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
with:
version: "0.8.13"
- name: Install standalone dependencies only
run: |
uv sync
- name: Test importing Flax
run: |
uv run --no-sync python -c "import flax"
tests:
name: Run Tests
needs: [pre-commit, commit-count]
runs-on: ubuntu-24.04-16core
strategy:
# Make sure to change `github_check_runs` in `copy.bara.sky` if you change the tests here.
matrix:
python-version: ['3.11', '3.12', '3.13']
test-type: [doctest, pytest]
jax-version: [newest]
include:
# keep in sync with internal type checking
- python-version: '3.12'
test-type: pytype
jax-version: 'newest'
- python-version: '3.12'
test-type: mypy
jax-version: 'newest'
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Setup uv
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
with:
version: "0.8.13"
python-version: ${{ matrix.python-version }}
enable-cache: true
- name: Install dependencies
run: |
rm -fr .venv
uv sync --extra testing --extra docs
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
uv pip install -U jax jaxlib
else
uv pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
if [[ "${{ matrix.test-type }}" == "pytest" ]]; then
uv pip install -U tensorflow-datasets
fi
- name: Test with ${{ matrix.test-type }}
run: |
if [[ "${{ matrix.test-type }}" == "doctest" ]]; then
uv run --no-sync tests/run_all_tests.sh --only-doctest
elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
uv run --no-sync tests/run_all_tests.sh --only-pytest
elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
uv run --no-sync tests/run_all_tests.sh --only-pytype
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
uv run --no-sync tests/run_all_tests.sh --only-mypy
else
echo "Unknown test type: ${{ matrix.test-type }}"
exit 1
fi
- name: Upload coverage to Codecov
if: matrix.test-type == 'pytest'
uses: codecov/codecov-action@1e68e06f1dbfde0e4cefc87efeba9e4643565303 # v5.1.2
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
file: ./coverage.xml
# The below step just reports the success or failure of tests as a "commit status".
# This is needed for copybara integration.
- name: Report success or failure as github status
if: always()
shell: bash
run: |
status="${{ job.status }}"
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
curl -sS --request POST \
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
--header 'content-type: application/json' \
--data '{
"state": "'$lowercase_status'",
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
"description": "'$status'",
"context": "github-actions/Build"
}'
# This is a temporary job to test Flax on Python 3.14 while
# skipping deps like tensorstore, tensorflow, etc.
tests-python314:
name: Run Tests on Python 3.14
needs: [pre-commit, commit-count]
runs-on: ubuntu-24.04-16core
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Setup uv
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
with:
version: "0.9.2"
python-version: "3.14"
activate-environment: true
enable-cache: true
- name: Install dependencies
run: |
rm -fr .venv
uv sync --extra testing --extra docs
- name: Test with pytest
run: |
export XLA_FLAGS='--xla_force_host_platform_device_count=4'
find tests/ -name "*.py" | grep -vE 'io_test|tensorboard' | xargs pytest -n auto
================================================
FILE: .github/workflows/flaxlib_publish.yml
================================================
name: Flaxlib - Build and upload to PyPI
# for testing only:
on:
push:
branches: [main]
paths: ['flaxlib/**']
release:
types: [published]
jobs:
build_wheels:
if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/')
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
# macos-13 is an intel runner, macos-14 is apple silicon
os: [ubuntu-latest, windows-latest, macos-13, macos-14]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- name: Setup Rust
uses: actions-rust-lang/setup-rust-toolchain@11df97af8e8102fd60b60a77dfbf58d40cd843b8 # v1.10.1
- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.21.0
- name: Build wheels
run: python -m cibuildwheel --output-dir ./flaxlib/wheelhouse ./flaxlib
env:
# rust doesn't seem to be available for musl linux on i686
CIBW_SKIP: "*-musllinux_i686"
CIBW_ENVIRONMENT: 'PATH="$HOME/.cargo/bin:$PATH" CARGO_TERM_COLOR="always"'
CIBW_ENVIRONMENT_WINDOWS: 'PATH="$UserProfile\.cargo\bin;$PATH"'
CIBW_BEFORE_BUILD: rustup show
CIBW_BEFORE_BUILD_LINUX: |
curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=minimal -y &&
rustup show
- uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./flaxlib/wheelhouse/*.whl
build_sdist:
if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/')
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Setup Rust
uses: actions-rust-lang/setup-rust-toolchain@11df97af8e8102fd60b60a77dfbf58d40cd843b8 # v1.10.1
- name: Build sdist
run: pipx run build --sdist flaxlib
- uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
with:
name: cibw-sdist
path: ./flaxlib/dist/*.tar.gz
upload_pypi:
if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/')
name: Upload to PyPI
needs: [build_wheels, build_sdist]
runs-on: ubuntu-latest
permissions:
id-token: write
steps:
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools build wheel twine
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
# unpacks all CIBW artifacts into flaxlib/dist/
pattern: cibw-*
path: ./flaxlib/dist
merge-multiple: true
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.FLAXLIB_PYPI_TOKEN }}
run: |
twine upload flaxlib/dist/*
================================================
FILE: .github/workflows/jax_nightly.yml
================================================
name: CI - with JAX nightly
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
on:
schedule:
- cron: "0 12 * * *" # Daily at 12:00 UTC
workflow_dispatch: # allows triggering the workflow run manually
pull_request: # Automatically trigger on pull requests affecting this file
branches:
- main
paths:
- '**workflows/jax_nightly.yml'
jobs:
jax-nightly:
runs-on: ubuntu-latest
permissions:
contents: read
issues: write # for failed-build-issue
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
id: setup_python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Setup uv
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
with:
version: "0.8.13"
- name: Install dependencies
run: |
uv sync --extra testing --extra docs
- name: Install JAX
run: |
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
- name: Run test suite
if: success()
run: |
uv run tests/run_all_tests.sh --only-pytest
- name: Notify failed build
uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0
if: failure() && github.event.pull_request == null
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
================================================
FILE: .gitignore
================================================
*~
\#*\#
*.pyc
.tfds
.DS_Store
dist/
build/
*.egg-info
*.rej
.pytype
.vscode/*
/.devcontainer
docs*/**/_autosummary
docs*/_build
docs*/**/tmp
flaxlib_src/build
flaxlib_src/builddir
flaxlib_src/dist
flaxlib_src/subprojects
.venv
venv/
venv.bak/
# used by direnv
.envrc
# uv
uv.lock
# custom
/tmp-files
.agent/rules/flax.md
.github/copilot-instructions.md
.env
================================================
FILE: .pre-commit-config.yaml
================================================
# Install the pre-commit hooks below with
# 'pre-commit install'
# Auto-update the version of the hooks with
# 'pre-commit autoupdate'
# Run the hooks on all files with
# 'pre-commit run --all-files'
repos:
- repo: https://github.com/mwouts/jupytext
rev: v1.13.8
hooks:
- id: jupytext
args: [--sync]
# disable pyink for now
# - repo: https://github.com/google/pyink
# rev: 23.5.0
# hooks:
# - id: pyink
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-toml
- id: trailing-whitespace
exclude: ^docs.*\.md$
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
- id: nbstripout
exclude: ^examples/.*
args: [
--keep-output,
--keep-count,
--extra-keys,
"cell.metadata.executionInfo cell.metadata.id metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab",
]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.3
hooks:
# Run the Ruff linter.
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
# Disable Ruff formatter for now
# # Run the Ruff formatter.
# - id: ruff-format
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
hooks:
- id: pyupgrade
args: [--py310-plus]
================================================
FILE: .readthedocs.yml
================================================
# deprecated
================================================
FILE: AUTHORS
================================================
# This is the list of the Flax authors for copyright purposes.
#
# This does not necessarily list everyone who has contributed code, since in
# some cases, their employer may be the copyright holder. To see the full list
# of contributors, see the revision history in source control.
Google LLC
================================================
FILE: CHANGELOG.md
================================================
Changelog
----------
vNext
------
(Add your change to a random empty line to avoid merge conflicts)
-
-
-
-
- removed the simplistic GeGLU activation; it should be implemented manually.
-
-
-
- removed FLAX_LAZY_RNG flag support for old non-lazy PRNG derivation mode
-
-
-
-
-
-
-
-
-
-
0.8.2
-----
- fixed rng guide outputs by @chiamp in https://github.com/google/flax/pull/3685
- enforce mask kwarg in norm layers by @chiamp in https://github.com/google/flax/pull/3663
- added kwargs to self.param and self.variable by @chiamp in https://github.com/google/flax/pull/3675
- added nnx normalization tests by @chiamp in https://github.com/google/flax/pull/3689
- added NNX init_cache docstring example by @chiamp in https://github.com/google/flax/pull/3688
- added nnx attention equivalence test by @chiamp in https://github.com/google/flax/pull/3687
- Fix bug that assumed frozen-dict keys were strings. by @copybara-service in https://github.com/google/flax/pull/3692
- added nnx rmsnorm by @chiamp in https://github.com/google/flax/pull/3691
- updated nnx compute_stats by @chiamp in https://github.com/google/flax/pull/3693
- fixed intercept_methods docstring by @chiamp in https://github.com/google/flax/pull/3694
- [nnx] Add Sphinx Docs by @cgarciae in https://github.com/google/flax/pull/3678
- Fix pointless docstring example of nn.checkpoint / nn.remat. by @levskaya in https://github.com/google/flax/pull/3703
- added default params rng to .apply by @chiamp in https://github.com/google/flax/pull/3698
- [nnx] add partial_init by @cgarciae in https://github.com/google/flax/pull/3674
- make make_rng default to 'params' by @chiamp in https://github.com/google/flax/pull/3699
- Add SimpleCell. by @carlosgmartin in https://github.com/google/flax/pull/3697
- fix Module.module_paths docstring by @cgarciae in https://github.com/google/flax/pull/3709
- Guarantee the latest JAX version on CI by @cgarciae in https://github.com/google/flax/pull/3705
- Replace deprecated API `jax.tree_map` by @copybara-service in https://github.com/google/flax/pull/3715
- Use `jax.tree_util.tree_map` instead of deprecated `jax.tree_map`. by @copybara-service in https://github.com/google/flax/pull/3714
- [nnx] simplify readme by @cgarciae in https://github.com/google/flax/pull/3707
- [nnx] add demo.ipynb by @cgarciae in https://github.com/google/flax/pull/3680
- Fix Tabulate's compute_flops by @cgarciae in https://github.com/google/flax/pull/3721
- [nnx] simplify TraceState by @cgarciae in https://github.com/google/flax/pull/3724
- Add broadcast of `strides` and `kernel_dilation` to `nn.ConvTranspose` by @IvyZX in https://github.com/google/flax/pull/3731
- [nnx] Fix State.__sub__ by @cgarciae in https://github.com/google/flax/pull/3704
- [nnx] always fold_in on fork + new ForkedKeys return type by @cgarciae in https://github.com/google/flax/pull/3722
- [nnx] explicit Variables by @cgarciae in https://github.com/google/flax/pull/3720
- Improves fingerprint definition for Modules in nn.jit. by @copybara-service in https://github.com/google/flax/pull/3736
- Flax: avoid key reuse in tests by @copybara-service in https://github.com/google/flax/pull/3740
- added Einsum layer by @chiamp in https://github.com/google/flax/pull/3710
- nn.jit: automatic fingerprint definition for dataclass attributes by @cgarciae in https://github.com/google/flax/pull/3737
- [NVIDIA] Use custom grad accumulation for FP8 params by @kaixih in https://github.com/google/flax/pull/3623
- removed nnx dataclass by @chiamp in https://github.com/google/flax/pull/3742
- [nnx] cleanup graph_utils by @cgarciae in https://github.com/google/flax/pull/3728
- Fix doctest and unbreak head by @IvyZX in https://github.com/google/flax/pull/3753
- [nnx] add pytree support by @cgarciae in https://github.com/google/flax/pull/3732
- fixed intercept_methods docstring by @chiamp in https://github.com/google/flax/pull/3752
- Add ConvLSTMCell to docs. by @carlosgmartin in https://github.com/google/flax/pull/3712
- [nnx] remove flagslib by @cgarciae in https://github.com/google/flax/pull/3733
- Fix tests after applying JAX key-reuse checker. See: by @copybara-service in https://github.com/google/flax/pull/3748
0.8.1
-----
- Added default collection in `make_rng`.
- Added `InstanceNorm` and renamed `channel_axes` to `feature_axes`.
- Added norm equivalence tests.
- Added `Module.module_paths` and doc.
- make `Sequential.__call__` compact.
- Added `nn.compact_name_scope` v3.
- Add explicit control over frozen/slots setting in `flax.struct.dataclass`.
- Replacing `jax.tree_util.tree_map` with mapping over leaves.
- Fixed docs and docstrings.
0.8.0
-----
- Added [NNX](https://github.com/google/flax/tree/main/flax/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch.
- Added `nn.compact_name_scope` decorator that enables methods to act as compact name scopes as with regular Haiku methods. This makes porting Haiku code easier.
- Add copy() method to Module. This is a user-friendly version of the internal clone() method with better
defaults for common use cases.
- Added [`BatchApply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#batchapply) class.
- Added `sow_weights` option in attention layer.
- Added [`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.MultiHeadAttention.html) alias.
- Added kwargs support for `nn.jit`.
- Deprecated `normalize` activation function, in favor of `standardize`.
- Added `GeGLU` activation function.
- Added `Enum` support for `tabulate` function.
- Added simple argument-only lifted `nn.grad` function.
0.7.5
-----
- Report forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate` (in new `flops` and `vjp_flops` table columns). Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns.
- Re-factored `MultiHeadDotProductAttention`'s call method signature, by adding
`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `deterministic`
to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389).
- Use new typed PRNG keys throughout flax: this essentially involved changing
uses of `jax.random.PRNGKey` to `jax.random.key`.
(See [JEP 9263](https://github.com/jax-ml/jax/pull/17297) for details).
If you notice dispatch performance regressions after this change, be sure
you update `jax` to version 0.4.16 or newer.
- Added `has_improved` field to EarlyStopping and changed the return signature of
`EarlyStopping.update` from returning a tuple to returning just the updated class.
See more details in [#3385](https://github.com/google/flax/pull/3385)
0.7.4
-----
New features:
- Add QK-normalization to MultiHeadDotProductAttention
- Allow apply's method argument to accept submodules
- Add module path to nn.module.
- [JAX] Generate new type of PRNG keys
Bug fixes:
- Directly call original method if method interceptor stack is empty.
- fix stackoverflow when loading pickled module
- Improve kw_only_dataclass.
- Allow pass-through implementation of state dict
- Promote dot_general injections from a function to a module.
0.7.2
-----
New features:
- make `flax.core.copy` `add_or_replace` optional
- Add `use_fast_variance` option to `GroupNorm` and `BatchNorm` to allow disabling it.
Bug fixes:
- Use `field_specifiers` instead of `field_descriptors` in `@dataclass_transform`.
- Fix `nn.Module` typing.
- [JAX] Replace uses of `jax.experimental.pjit.with_sharding_constraint` with `jax.lax.with_sharding_constraint`.
0.7.1
-----
Breaking changes:
- Migrating Flax from returning FrozenDicts to returning regular dicts. More details can be found in this [announcement](https://github.com/google/flax/discussions/3191)
New features:
- Use pyink
- added dict migration guide to index
- add scan over layers section
- Expose options to customize rich.Table
- add support for initializing carry variables in scan
- Let Flax-Orbax not port the shape of `target` arrays when they port the `target` shardings.
Bug fixes:
- Use import `orbax.checkpoint` which is a better import pattern.
- Use import `orbax.checkpoint as ocp` to avoid the verbosity of using `orbax.checkpoint` every time.
- [linen] Add alternative, more numerically stable, variance calculation to `LayerNorm`.
- [linen] Minor cleanup to normalization code.
- Fix norm calculation bug for 0-rank arrays.
- [JAX] Remove references to jax.config.jax_array.
- [linen] Use `stack` instead of `concatenate` in `compute_stats`, to handle scalar stats case.
- [linen] More minor cleanup in normalization `compute_stats`.
- Fix warnings from atari gym.
- Refactor TypeHandler to operate over batches of values, rather than individual ones. This allows more flexibility for implementations that may operate more efficiently on batches.
- Fix carry slice logic
- make flax_basics guide use utility fns
- Fix checkpointing guide error at head
- Improve scan docs
0.7.0
-----
- RNNCellBase refactor.
0.6.11
-----
- Set Orbax-as-backend to be the default checkpointing method.
- Fix setup trigger issue under sharing and transforms.
- Add collection to self.scope.reserve(name, col) so that sow works with the same name in different collections.
- Minor improvements for Sequential.
- Improve the error message in MultiHeadDotProductAttention.
- Allow manually specifying the rng key for Dropout.
- RNN refactor.
- fixed missing separator for rng fold in.
0.6.10
-----
- Rudimentary quantization support: some layers can be parametrized with custom dot_general and conv_general_dilated.
0.6.9
-----
- Depend on `orbax-checkpoint` package instead of `orbax`.
- Refactored setup scripts to `project.toml`.
- Added pretty_repr utility fn.
- Fix get_partition_spec on replicated array.
- Updates imagenet.ipynb to use GPU Colab runtime, and fixed config.
- Upgrade checkpointing code to `jax.sharding`, and with more warnings.
0.6.8
-----
- The automatic checkpoint migration was temporarily rolled back due to legacy compatibility issues.
- We still recommend you to use the [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and migrate completely to the Orbax API to ensure stability.
- Or alternatively, add `flax.config.update('flax_use_orbax_checkpointing', True)` to your project to avoid being impacted by the automatic migration process.
- Added utility functions to frozen_dict api.
- Migrated Flax away from `register_keypaths`.
- Fixes kwargs in convert_to_graphs_tuple_fn.
- Fixed examples in a few ways:
- Bumped the TF version
- Used latest checkpoint formats
- Other misc fixes.
0.6.7
-----
- New checkpoints will be saved using Orbax! Please check out [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and consider migrating completely to the Orbax API.
- You could `flax.config.update('flax_use_orbax_checkpointing', False)` to temporarily disable this migration, but note that Flax legacy checkpointing will be removed 3 months from Mar 10, 2023.
- Migrating `FrozenDict` to regular dict: utility functions now work on both.
- Migrated Flax dataclass and `FrozenDict` to JAX pytree keypath API.
- Fixed pytype and improved typing for `Module`
- Fixed up uses of PyTree and PyTreeDef types.
0.6.6
-----
- 0.6.5 was yanked so this release contains all that was in 0.6.5 as well.
- Migrated regular dict to FrozenDict, currently controlled by a flag.
- Refactored and separate out name relaxation policy changes.
- Added RMS normalization layer.
0.6.5
-----
- Added logical partitioning helpers for using pjit with Flax.
- Add ``Module.lazy_init`` to avoid compute during Module initialization.
0.6.4
-----
New features:
- Our [ReadTheDocs](https://flax.readthedocs.io/en/latest/index.html) site is a lot more organized now! More improvements on the way.
- Flax auto-SPMD parallelism API to work seamlessly with `jax.pjit`: https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html
- Added new `zeros_init` and `ones_init` initializers.
- Adds standardize initializer.
- Allowed specifying method as a string.
- Allowed runtime overwrite of `flax.config` flags.
Bug fixes:
- Added missing `dataclass.fields` from `__repr__`.
- Renamed ConvLSTM to ConvLSTMCell.
- Fix some tiny inconsistencies between scope.py and module.py.
- Improved many docstrings, comments and error messages.
0.6.3
-----
New features:
- Flax checkpointing now uses [Orbax](https://github.com/google/orbax) for more flexibility and features.
- Added support for python 3.10 and removed support for 3.7.
Bug fixes:
- Fixed rng generation in DenseGeneral init.
- Improved support for Mac M1 chip.
- Bumped package versions for a bunch of examples.
- Improved many docstrings and error messages.
0.6.2
-----
New features:
- Add rng_collection argument to Dropout.
- Fix flax.linen.stochastic.Dropout.
- Add flag allow_partial_mpa_restoration in checkpointing.
- Use `gfile.remove` for files because it doesn't work on GCS files.
- Added guides for: Flax the Sharp Bits, Checkpointing, Extracting Gradients
- Improved existing documentation pages.
- Improved errors, error messages and tests.
- Removed trailing whitespace from the codebase.
Bug fixes:
- Fixes launch_gce.sh with imagenet example.
- Properly report AttributeErrors from descriptors.
- Fixes usages of `pmap`.
- Return None if no _parent_ref is set.
- Cap dynamic scale to float32 max.
- no-op when double wrapping with struct.dataclass.
- Allow variable_with_axes to have empty axes when axes is set to an empty tuple.
- Don't create reference cycles among Modules.
0.6.1
-----
- Adds axis_name and axis_index_groups to LayerNorm and GroupNorm. by @copybara-service in [#2402](https://github.com/google/flax/pull/2402)
- Plumb spmd_axis_name through transforms.vmap through to JAX vmap by @copybara-service in [#2398](https://github.com/google/flax/pull/2398)
- Support multiple inputs in flax lifted vjp/custom_vjp by @copybara-service in [#2399](https://github.com/google/flax/pull/2399)
- Improve tabulate by @cgarciae in [#2316](https://github.com/google/flax/pull/2316)
- Add path_aware_map function by @cgarciae in [#2371](https://github.com/google/flax/pull/2371)
- Add static_argnums to nn.checkpoint by @cgarciae in [#2457](https://github.com/google/flax/pull/2457)
- Adding "count_include_pad" argument to flax.linen.pooling.avg_pool by @dslisleedh in [#2451](https://github.com/google/flax/pull/2451)
- Add perturb() to allow capturing intermediate gradients by @IvyZX in [#2476](https://github.com/google/flax/pull/2476)
0.6.0
-----
- Removed deprecated optimizers in `flax.optim` package.
- Moved `flax.optim.dynamic_scale` to `flax.training.dynamic_scale`.
- Switched to using `jax.named_scope` for all profile naming, cut some pointless
stack traces out.
0.5.3
-----
New features:
- Added `nn.switch` as a lifted version of `jax.lax.switch`.
- Added a method for detecting the use of "init" functions.
- Added checkpointing support for `jax.experimental.GlobalDeviceArray`, a useful array type for multiprocess/multihost computing.
- Added an async option to `save_checkpoint()` for single-process scenarios.
- Improved documentation pages.
Bug fixes:
- Fixed variable aliasing in put_variable
- Fixed missing passthrough of nn.scan unroll arg
- Fixed the MNIST example
0.5.2
-----
- Fixes missing PyYAML dependency.
0.5.1
-----
New features:
- Added `nn.tabulate` and `Module.tabulate` to generate rich representations of the network structure.
0.5.0
-----
- Added `flax.jax_utils.ad_shard_unpad()` by @lucasb-eyer
- Implemented [default dtype FLIP](https://github.com/google/flax/blob/main/docs/flip/1777-default-dtype.md).
This means the default dtype is now inferred from inputs and params rather than being hard-coded to float32.
This is especially useful for dealing with complex numbers because the standard Modules will no longer truncate
complex numbers to their real component by default. Instead the complex dtype is preserved by default.
Bug fixes:
- Fix support for JAX's experimental_name_stack.
Breaking changes:
- In rare cases the dtype of a layer can change due to [default dtype FLIP](https://github.com/google/flax/blob/main/docs/flip/1777-default-dtype.md). See the "Backward compatibility" section of the proposal for more information.
0.4.2
-----
New features:
- Add lifted conditional `nn.cond`.
- Improved error messages: parameters not found, loading checkpoints.
- Replace `jax.tree_multimap` (deprecated) with `jax.tree.map`.
- Add the "Module Lifecycle" design note.
- Add support for JAX dynamic stack-based named_call
Bug fixes:
- Handle the `rate==1.0` edge case in Dropout.
- Fix bug where Linen Module state is reused.
- Bug fixes and generalizations of nn.partitioning API.
0.4.1
-----
New features:
- Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`.
- Improved seq2seq example: Factored out model and input pipeline code.
- Added Optax update guide and deprecated `flax.optim`.
- Added `sep` argument to `flax.traverse_util.flatten_dict()`.
- Implemented Sequential module, in `flax.linen.combinators`.
0.4.0
------
Breaking changes:
- flax.deprecated.nn is removed. Please pin to flax==0.3.6 if you are still using it.
- PixelCNN++ example is removed. It was not working well on TPU.
- linen Normalization layers no longer downcast double and complex floats to float32
when computing the mean and variance.
New features:
- Added `flax.linen.custom_vjp` for custom derivatives inside a `Module`.
- Add `param_dtype` attribute to standard Linen Modules for specifying parameter dtypes.
0.3.6
------
Breaking changes:
- Move `flax.nn` to `flax.deprecated.nn`.
New features:
- Add experimental checkpoint policy argument. See `flax.linen.checkpoint`
- Add lifted versions of jvp and vjp.
- Add lifted transformation for mapping variables. See `flax.linen.map_variables`.
0.3.5
------
Breaking changes:
- You can no longer pass an int as the `kernel_size` for a `flax.linen.Conv`.
Instead a type error is raised stating that
a tuple/list should be provided. Stride and dilation arguments do support broadcasting a single int value now because this is not
ambiguous when the kernel rank is known.
- `flax.linen.enable_named_call` and `flax.linen.disable_named_call` now work anywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now `flax.linen.override_named_call` that provides a context manager to locally disable/enable named_call.
- NamedTuples are no longer converted to tuples on assignment to a `linen.Module`.
New features:
- Flax internal stack frames are now removed from exception stack traces.
- Added `flax.linen.nowrap` to decorate methods that should not be transformed
because they are stateful.
- Flax no longer uses implicit rank broadcasting. Thus, you can now use Flax with `--jax_numpy_rank_promotion=raise`.
Bugfixes:
- linen Modules and dataclasses made with `flax.struct.dataclass` or `flax.struct.PyTreeNode` are now correctly recognized as dataclasses by static analysis tools like PyLance. Autocomplete of constructors has been verified to work with VSCode.
- Fixed a bug in FrozenDict which didn't allow copying dicts with reserved names.
- Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https://github.com/google/flax/issues/1429)).
- Mixed precision training with float16 now works correctly with the attention layers.
- auto-generated linen Module `__hash__`, `__eq__`, `__repr__` no longer fail by default on non-init attributes.
0.3.4
------
Possibly breaking changes:
- When calling `init` the 'intermediates' collection is no longer mutable.
Therefore, intermediates will no longer be returned from initialization by default.
- Don't update batch statistics during initialization.
- When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`.
Other changes:
- Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2).
- Added an NLP text classification example (on the SST-2 dataset) to
[`examples/sst2`](https://github.com/google/flax/tree/main/examples/sst2)
that uses a bidirectional LSTM (BiLSTM) to encode the input text.
- Added `flax.training.train_state` to simplify using Optax optimizers.
- `mutable` argument is now available on `Module.init` and `Module.init_with_outputs`
- Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance.
- Expose `dot_product_attention_weights`, allowing access to attention weights.
- `BatchNorm` instances will behave correctly during init when called multiple times.
- Added a more extensive "how to contribute" guide in `contributing.md`.
- Add proper cache behavior for [`lift.jit`](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.jit.html#flax.linen.jit),
fixing cache misses.
- Fix bug in Embed layer: make sure it behaves correctly when embedding is np.array.
- Fix `linen.Module` for deep inheritance chains.
- Fix bug in DenseGeneral: correctly expand bias to account for batch & noncontracting dimensions.
- Allow Flax lifted transforms to work on partially applied Modules.
- Make `MultiOptimizer` use `apply_gradient` instead of `apply_param_gradient`.
0.3.3
------
Possible breaking changes:
- Bug Fix: Disallow modifying attributes in Modules after they are initialized.
- Raise an error when saving a checkpoint which has a smaller step than the
latest checkpoint already saved.
- MultiOptimizer now rejects the case where multiple sub optimizers update the
same parameter.
Other changes:
- Added custom error classes to many Linen errors. See:
https://flax.readthedocs.io/en/latest/flax.errors.html
- Adds `Module.bind` for binding variables and RNGs to an interactive Module.
- Adds `nn.apply` and `nn.init` for transforming arbitrary functions that take a `linen.Module` as their first argument.
- Add option to overwrite existing checkpoints in `save_checkpoint`.
- Remove JAX omnistaging check for forward compatibility.
- Pathlib compatibility for checkpoint paths.
- `is_leaf` argument in `traverse_util.flatten_dict`
0.3.2
------
`flax.nn` deprecation message no longer appears if you import flax directly.
NOTE: You must now explicitly import `flax.nn` if you want to use the old
pre-Linen `flax.nn.Module`.
0.3.1
------
Many improvements to Linen, and the old `flax.nn` is officially deprecated!
Notably, there's a clean API for extracting intermediates from modules
defined using `@nn.compact`, a more ergonomic API for using Batch Norm and Dropout in modules
defined using `setup`, support for `MultiOptimizer` with Linen, and multiple safety, performance
and error message improvements.
Possible breaking changes:
- Call setup lazily. See #938 for motivation and more details.
- Linen `Module` instances are now frozen after `setup` has been called.
Previously mutations after setup could be dropped silently. Now the stateless requirement
is enforced by raising a TypeError in `__setattr__` after `setup`.
- Pytrees of dicts and lists are transformed into FrozenDict and tuples during
attribute assignment.
This avoids undetected submodules and inner state.
- Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple
containing the output and a frozen empty
collection when `mutable` is specified as an empty list.
- `broadcast_dims` is now an attribute to `Dropout` instead of a `__call__`
argument.
- `use_running_average` and `deterministic` no longer have a default. They
should be passed explicitly
- Bug Fix `Scope.variable` mutability check, before a variable could only be
initialized if the 'params' collection was mutable.
Other Improvements:
- Re-introduced the `lm1b` language modeling example
- Recognizes batch free inputs in pooling layers. (for use with vmap)
- Add Adadelta optimizer
- Fully deprecate all "pre-Linen" `flax.nn` classes and methods.
- Some Module arguments can now be passed either as dataclass attribute or
as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/guides/arguments.html)
- Add `sow` method to `Module` and `capture_intermediates` argument to `Module.apply`.
See [howto](https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html) for usage patterns.
- Support passing in modules directly as attributes to other modules, and
deal with them correctly both in top-level modules and in submodules.
- Don't require the `variable` argument to `Module.apply` to be a FrozenDict
- Add support for dict/FrozenDict when using `ModelParamTraversal`
As a result `MultiOptimizer` can be used properly with linen modules.
- Added OptimizedLSTM: ~33% faster than the original LSTM when using <=1024 units
- Fix dtype handling for Adam and LAMB optimizers in 64bit mode.
- Added `is_mutable()` method to `Variable` and `is_mutable_collection()` to `flax.linen.Module`.
- Add `axis_name` arg to `flax.linen.vmap`
- Enable broadcast in `flax.linen.scan`
- Fix behavior when inner module classes were defined in another module
- Add automatic giant array chunking in msgpack checkpoints.
- Log info message when a checkpoint is not found in the directory.
v0.3
-----
Linen is now out of Alpha (flax.nn is being deprecated)!
- `flax.core.apply` and linen `Module.apply` will now only return the variables
collections that were specified as mutable.
- Fixed handling of multiple separate subclasses of a Module.
- We now allow assignment of mixed Module pytrees in setup.
- Refactored collection creation to fail early when modifying an undefined collection, since
previously a non-existing, non-mutable collection would just be silently ignored.
- Added the silu activation function.
- Add offset argument to Adafactor optimizer for fine-tuning schedules.
- Relaxed limit on calling methods on unbound modules.
- Relaxed parameter attribute check
- Added centered version of RMSProp.
- Added GCE getting started kit.
- Renamed -gpu_type to -accelerator_type.
- Fixed bug in MultiOptimizer causing it to throw away empty dictionary
### Improvements
- Made FrozenDict constructor freeze correctly.
- Made freeze a synonym of the FrozenDict constructor
- Optimize freezing FrozenDicts by sharing immutable internal state.
- We simplified `__setattr__` handling of trees with Modules.
- Minor improvements in dtype handling, broadcast option for dropout.
- Added a dtype specification to Embed layer, made Adafactor use float32
state consistently, and added a broadcasting option to the Dropout layer.
- Improved frozen dict performance.
- (Massive) docs improvements
- End to end benchmarks added.
- Examples were updated to Linen.
v0.2.2
----
- Added Reinforcement Learning example (examples/ppo).
- Fix Adafactor bug that prevented factorization.
- Fix scan broadcast issue in functional core.
- Fix initialization RNGs to work with omnistaging for jitted inits.
- Replaced usage of the 'param' kind with the 'params' collection.
- Fix LARS optimizer for zero param initialization.
- Added various examples in Linen API. See [README.md](https://github.com/google/flax/blob/main/flax/linen/README.md) for more information.
- Full JAX omnistaging compatibility.
v0.2
----
- Added JAX trace-level checks for transforms.
- BatchNorm added axis_index_groups for control in parallel training.
- Optimizers broken out into separate directory with base class and implementations.
- traverse_util added flatten_dict and unflatten_dict utility methods for nested dicts.
v0.1
----
### API Changes
- Add ConvTranspose Module to nn.linear
- Rename the following optional arguments to nn.linear.Conv:
`lhs_dilation` -> `input_dilation`,
`rhs_dilation` -> `kernel_dilation`
- Change default layer names from numbers '0', '1', etc. to
include the Module class name, e.g. 'Dense_0', 'LayerNorm_1'.
================================================
FILE: LICENSE
================================================
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<div align="center">
<img src="https://raw.githubusercontent.com/google/flax/main/images/flax_logo_250px.png" alt="logo"></img>
</div>
# Flax: A neural network library and ecosystem for JAX designed for flexibility
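[![Flax - Test](https://github.com/google/flax/actions/workflows/flax_test.yml/badge.svg)](https://github.com/google/flax/actions/workflows/flax_test.yml)
[![PyPI version](https://img.shields.io/pypi/v/flax)](https://pypi.org/project/flax/)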
[**Overview**](#overview)
| [**Quick install**](#quick-install)
| [**What does Flax look like?**](#what-does-flax-look-like)
| [**Documentation**](https://flax.readthedocs.io/)
Released in 2024, Flax NNX is a new simplified Flax API that is designed to make
it easier to create, inspect, debug, and analyze neural networks in
[JAX](https://jax.readthedocs.io/). It achieves this by adding first class support
for Python reference semantics. This allows users to express their models using
regular Python objects, enabling reference sharing and mutability.
Flax NNX evolved from the [Flax Linen API](https://flax-linen.readthedocs.io/), which
was released in 2020 by engineers and researchers at Google Brain in close collaboration
with the JAX team.
You can learn more about Flax NNX on the [dedicated Flax documentation site](https://flax.readthedocs.io/). Make sure you check out:
* [Flax NNX basics](https://flax.readthedocs.io/en/latest/nnx_basics.html)
* [MNIST tutorial](https://flax.readthedocs.io/en/latest/mnist_tutorial.html)
* [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html)
* [Evolution from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html)
**Note:** Flax Linen's [documentation has its own site](https://flax-linen.readthedocs.io/).
The Flax team's mission is to serve the growing JAX neural network
research ecosystem - both within Alphabet and with the broader community,
and to explore the use-cases where JAX shines. We use GitHub for almost
all of our coordination and planning, as well as where we discuss
upcoming design changes. We welcome feedback on any of our discussion,
issue and pull request threads.
You can make feature requests, let us know what you are working on,
report issues, ask questions in our [Flax GitHub discussion
forum](https://github.com/google/flax/discussions).
We expect to improve Flax, but we don't anticipate significant
breaking changes to the core API. We use [Changelog](https://github.com/google/flax/tree/main/CHANGELOG.md)
entries and deprecation warnings when possible.
In case you want to reach us directly, we're at flax-dev@google.com.
## Overview
Flax is a high-performance neural network library and ecosystem for
JAX that is **designed for flexibility**:
Try new forms of training by forking an example and by modifying the training
loop, not adding features to a framework.
Flax is being developed in close collaboration with the JAX team and
comes with everything you need to start your research, including:
* **Neural network API** (`flax.nnx`): Including [`Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear), [`Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv), [`BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm), [`LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm), [`GroupNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.GroupNorm), [Attention](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html) ([`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html#flax.nnx.MultiHeadAttention)), [`LSTMCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.LSTMCell), [`GRUCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.GRUCell), [`Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout).
* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device.
* **Educational examples**: [MNIST](https://flax.readthedocs.io/en/latest/mnist_tutorial.html), [Inference/sampling with the Gemma language model (transformer)](https://github.com/google/flax/tree/main/examples/gemma).
## Quick install
Flax uses JAX, so do check out [JAX installation instructions on CPUs, GPUs and TPUs](https://jax.readthedocs.io/en/latest/installation.html).
You will need Python 3.8 or later. Install Flax from PyPi:
```
pip install flax
```
To install the latest development version of Flax directly from GitHub, you can use:
```
pip install --upgrade git+https://github.com/google/flax.git
```
To install additional optional dependencies (like `matplotlib`) that some features require
but that are not installed by default, you can use:
```bash
pip install "flax[all]"
```
## What does Flax look like?
We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.
To learn more about the `Module` abstraction, check out our [docs](https://flax.readthedocs.io/), our [broad intro to the Module abstraction](https://github.com/google/flax/blob/main/docs/linen_intro.ipynb). For additional concrete demonstrations of best practices, refer to our
[guides](https://flax.readthedocs.io/en/latest/guides/index.html) and
[developer notes](https://flax.readthedocs.io/en/latest/developer_notes/index.html).
Example of an MLP:
```py
class MLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x: jax.Array):
x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
return self.linear2(x)
```
Example of a CNN:
```py
class CNN(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
self.linear2 = nnx.Linear(256, 10, rngs=rngs)
def __call__(self, x):
x = self.avg_pool(nnx.relu(self.conv1(x)))
x = self.avg_pool(nnx.relu(self.conv2(x)))
x = x.reshape(x.shape[0], -1) # flatten
x = nnx.relu(self.linear1(x))
x = self.linear2(x)
return x
```
Example of an autoencoder:
```py
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)
class AutoEncoder(nnx.Module):
def __init__(self, rngs):
self.encoder = Encoder(rngs)
self.decoder = Decoder(rngs)
def __call__(self, x) -> jax.Array:
return self.decoder(self.encoder(x))
def encode(self, x) -> jax.Array:
return self.encoder(x)
```
## Citing Flax
To cite this repository:
```
@software{flax2020github,
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.12.6},
year = {2020},
}
```
In the above bibtex entry, names are in alphabetical order, the version number
is intended to be that from [flax/version.py](https://github.com/google/flax/blob/main/flax/version.py), and the year corresponds to the project's open-source release.
## Note
Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.
================================================
FILE: benchmarks/README.md
================================================
# Benchmarks
These are mini benchmarks to measure the performance of NNX operations.
Sample profile command:
```shell
python -m cProfile -o ~/tmp/overhead.prof benchmarks/nnx_graph_overhead.py --mode=nnx --depth=100 --total_steps=1000
```
To inspect the resulting profile (requires `snakeviz`, e.g. `pip install snakeviz`):
```shell
snakeviz ~/tmp/overhead.prof
```
================================================
FILE: benchmarks/nnx_graph_overhead.py
================================================
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np
import optax
from time import time
from flax import nnx
from absl import flags
from absl import app
# Command-line flags for the graph-overhead benchmark.
FLAGS = flags.FLAGS
# Which implementation(s) to time: 'nnx', 'jax', or 'all' (both, NNX first).
flags.DEFINE_enum(
'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
# Number of timed calls to the (empty) jitted step function.
flags.DEFINE_integer('total_steps', 100, 'Total number of training steps')
# Hidden size passed to MLP as `dhidden`.
flags.DEFINE_integer('width', 32, 'Hidden layer size')
# Number of Blocks in the MLP: one input block, `depth - 2` intermediate
# blocks and one output block. Also the divisor for "time per layer".
flags.DEFINE_integer('depth', 5, 'Depth of the model')
class Linear(nnx.Module):
  """Affine layer computing ``x @ w + b`` with explicitly created parameters.

  ``w`` has shape ``(din, dout)`` and is sampled with ``jax.random.uniform``
  from a key drawn from ``rngs.params()``; ``b`` has shape ``(dout,)`` and
  starts at zero.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    weight_key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(weight_key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    projected = x @ self.w
    return projected + self.b
class Block(nnx.Module):
  """``Linear`` followed by ``BatchNorm`` and a ReLU activation."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.linear = Linear(din, dout, rngs=rngs)
    self.bn = nnx.BatchNorm(dout, rngs=rngs)

  def __call__(self, x):
    hidden = self.linear(x)
    hidden = self.bn(hidden)
    return nnx.relu(hidden)
class Count(nnx.Variable):
  """Custom Variable type; MLP stores its call counter in one of these."""
class MLP(nnx.Module):
  """Stack of Blocks: an input Block, ``depth - 2`` hidden Blocks, an output Block.

  Every forward pass increments ``self.count``.
  """

  def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
    self.count = Count(jnp.array(0))
    self.linear_in = Block(din, dhidden, rngs=rngs)
    hidden_blocks = []
    for _ in range(depth - 2):
      hidden_blocks.append(Block(dhidden, dhidden, rngs=rngs))
    self.intermediates = nnx.List(hidden_blocks)
    self.linear_out = Block(dhidden, dout, rngs=rngs)

  def __call__(self, x):
    self.count.value += 1
    h = nnx.relu(self.linear_in(x))
    for hidden_block in self.intermediates:
      h = nnx.relu(hidden_block(h))
    return self.linear_out(h)
def _report(title: str, total_time: float, total_steps: int, depth: int):
  """Prints total, per-step and per-layer wall time for one benchmark mode."""
  time_per_step = total_time / total_steps
  time_per_layer = time_per_step / depth
  print(f'### {title} ###')
  print('total time:', total_time)
  print(f'time per step: {time_per_step * 1e6:.2f} µs')
  print(f'time per layer: {time_per_layer * 1e6:.2f} µs')


def main(argv):
  """Times the per-call overhead of empty jitted steps in NNX and plain JAX.

  The step functions do no computation, so the timings measure the cost of
  passing ``(model, optimizer)`` through ``nnx.jit`` (NNX mode) or their
  split ``(graphdef, state)`` through ``jax.jit`` (JAX mode). Each step
  function is called once before timing so tracing/compilation is excluded.
  """
  print(argv)
  mode: str = FLAGS.mode
  total_steps: int = FLAGS.total_steps
  width: int = FLAGS.width
  depth: int = FLAGS.depth

  print(f'{mode=}, {total_steps=}, {width=}')

  # ------------------------------------------------------------
  # NNX
  # ------------------------------------------------------------
  if mode in ['all', 'nnx']:
    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
    tx = optax.sgd(1e-3)
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

    @nnx.jit
    def step_nnx(model: MLP, optimizer: nnx.Optimizer):
      # Intentionally empty: only the graph traversal overhead is measured.
      pass

    # Warm-up: trace and compile outside of the timed region.
    step_nnx(model, optimizer)

    t0 = time()
    for _ in range(total_steps):
      step_nnx(model, optimizer)
    _report('NNX', time() - t0, total_steps, depth)

  # ------------------------------------------------------------
  # JAX
  # ------------------------------------------------------------
  if mode in ['all', 'jax']:
    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
    tx = optax.sgd(1e-3)
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

    @jax.jit
    def step_jax(graphdef, state):
      # Identity step: only the pytree flatten/unflatten overhead is measured.
      return graphdef, state

    graphdef, state = nnx.split((model, optimizer))
    # Warm-up: trace and compile outside of the timed region.
    graphdef, state = step_jax(graphdef, state)

    t0 = time()
    for _ in range(total_steps):
      graphdef, state = step_jax(graphdef, state)
    _report('JAX', time() - t0, total_steps, depth)

  print()


if __name__ == '__main__':
  app.run(main)
================================================
FILE: benchmarks/nnx_mlpmixer_training.py
================================================
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# %%
from functools import partial
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np
from einop import einop
from time import time
from tqdm import tqdm
from flax import nnx
from absl import flags
from absl import app
# Command-line flags for the MLP-Mixer training benchmark.
FLAGS = flags.FLAGS
# Which training loop(s) to time: 'nnx', 'jax', or 'all' (both, NNX first).
flags.DEFINE_enum(
'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
# Number of train-step calls per mode.
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
# Leading dimension of the random (batch_size, 28, 28, 1) input batch.
flags.DEFINE_integer('batch_size', 32, 'Batch size')
# Hidden/channel size of the MlpMixer (see main for which modes read it).
flags.DEFINE_integer('width', 32, 'Hidden layer size')
# Number of MixerBlocks (see main for which modes read it).
flags.DEFINE_integer('depth', 4, 'Depth of the model')
class MlpBlock(nnx.Module):
  """Two-layer MLP mapping the last axis ``din -> mlp_dim -> din`` with GELU."""

  def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs):
    self.din = din
    self.mlp_dim = mlp_dim
    self.linear_in = nnx.Linear(din, mlp_dim, rngs=rngs)
    self.linear_out = nnx.Linear(mlp_dim, din, rngs=rngs)

  def __call__(self, x):
    expanded = nnx.gelu(self.linear_in(x))
    return self.linear_out(expanded)
class MixerBlock(nnx.Module):
  """MLP-Mixer block: a token-mixing MLP, then a channel-mixing MLP.

  Each MLP is preceded by a LayerNorm over the last axis and wrapped in a
  residual connection. Token mixing swaps axes 1 and 2, so after the swap the
  last axis must have size ``tokens_mlp_dim``. Before the swap it must have
  size ``channels_mlp_dim``. NOTE(review): this presumably means
  ``(batch, tokens, channels)`` inputs.
  """

  def __init__(
      self,
      tokens_mlp_dim: int,
      channels_mlp_dim: int,
      hidden_dim: int,
      rngs: nnx.Rngs,
  ):
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.hidden_dim = hidden_dim
    self.token_mixing = MlpBlock(tokens_mlp_dim, hidden_dim, rngs=rngs)
    self.channel_mixing = MlpBlock(channels_mlp_dim, hidden_dim, rngs=rngs)
    self.ln1 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.ln2 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)

  def __call__(self, x):
    # Token mixing: move axis 1 to the end, mix it, then move it back.
    token_mixed = self.token_mixing(self.ln1(x).swapaxes(1, 2)).swapaxes(1, 2)
    x = x + token_mixed
    # Channel mixing over the last axis.
    return x + self.channel_mixing(self.ln2(x))
class MlpMixer(nnx.Module):
  """Time-conditioned MLP-Mixer mapping images to images of ``din`` channels.

  ``__call__`` does the following:

  1. Appends ``t`` as one extra input channel.
  2. Patchifies with a strided ``Conv``.
  3. Applies ``num_blocks`` MixerBlocks to the flattened patch sequence.
  4. Maps back to ``din`` channels with a ``ConvTranspose``.
  """

  def __init__(
      self,
      din: int,
      kernel_size: tuple[int, int],
      strides: tuple[int, int],
      num_blocks: int,
      hidden_dim: int,
      tokens_mlp_dim: int,
      channels_mlp_dim: int,
      rngs: nnx.Rngs,
  ):
    self.din = din
    self.kernel_size = kernel_size
    self.num_blocks = num_blocks
    self.hidden_dim = hidden_dim
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    # din + 1 input channels: the extra one carries the time feature.
    self.stem = nnx.Conv(
        din + 1,
        channels_mlp_dim,
        kernel_size=kernel_size,
        strides=strides,
        rngs=rngs,
    )
    # nnx.List (as used by nnx_graph_overhead.py for its layers) registers the
    # blocks as NNX data instead of a plain Python list attribute.
    # NOTE(review): Flax 0.12+ expects nnx.List/nnx.data for lists of
    # Modules; confirm against the installed Flax version.
    self.blocks = nnx.List([
        MixerBlock(tokens_mlp_dim, channels_mlp_dim, hidden_dim, rngs=rngs)
        for _ in range(num_blocks)
    ])
    self.pre_head_layer_norm = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.conv_t = nnx.ConvTranspose(
        channels_mlp_dim, din, kernel_size=kernel_size, strides=strides, rngs=rngs
    )

  def __call__(self, *, x, t):
    # Add the time feature to the input: broadcast t of shape (n,) to an
    # (n, h, w, 1) channel and concatenate it along the last axis.
    t = einop(t, 'n -> n h w c', h=x.shape[1], w=x.shape[2], c=1)
    x = jnp.concatenate([x, t], axis=-1)
    # Create patches and flatten the patch grid into a sequence.
    x = self.stem(x)
    h, w = x.shape[1], x.shape[2]
    x = einop(x, 'n h w c -> n (h w) c')
    # Apply the mixer blocks.
    for block in self.blocks:
      x = block(x)
    x = self.pre_head_layer_norm(x)
    # Restore the patch grid, then upsample back to the input resolution.
    x = einop(x, 'n (h w) c -> n h w c', h=h, w=w)
    return self.conv_t(x)
def _build_flow(num_blocks: int, width: int, rngs: nnx.Rngs) -> MlpMixer:
  """Builds the MlpMixer shared by both modes (28x28x1 inputs, 2x2 patches)."""
  return MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=num_blocks,
      hidden_dim=width,
      tokens_mlp_dim=196,  # (28 // 2) * (28 // 2) patches per image
      channels_mlp_dim=width,
      rngs=rngs,
  )


def _mse(a, b):
  """Mean squared error between ``a`` and ``b``."""
  return jnp.mean((a - b) ** 2)


def _report(title: str, losses, total_time: float, total_steps: int):
  """Prints the final loss and the timing summary for one mode."""
  print(f'### {title} ###')
  print(f'final loss: {losses[-1]}')
  print('total time:', total_time)
  print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')


def main(argv):
  """Benchmarks flow-matching training of an MlpMixer: nnx.jit vs. jax.jit.

  Both modes build the same model from --depth (number of mixer blocks) and
  --width (hidden/channel size), so their timings are comparable. The clock
  is stopped only after the last step's result is ready, because JAX
  dispatches work asynchronously.
  """
  print(argv)
  mode: str = FLAGS.mode
  total_steps: int = FLAGS.total_steps
  batch_size: int = FLAGS.batch_size
  width: int = FLAGS.width
  depth: int = FLAGS.depth

  print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

  X = np.random.uniform(size=(batch_size, 28, 28, 1))

  if mode == 'nnx' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = _build_flow(depth, width, rngs)
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4), wrt=nnx.Param)

    @nnx.jit(donate_argnums=(0, 1, 2))
    def train_step_nnx(flow, optimizer, rngs, x_1):
      print('JITTING NNX')
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))
      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0
      loss, grads = nnx.value_and_grad(
          lambda flow: _mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(flow, grads)
      return loss

    losses = []
    t0 = time()
    for _ in tqdm(range(total_steps), desc='NNX'):
      losses.append(train_step_nnx(flow, optimizer, rngs, X))
    jax.block_until_ready(losses[-1])
    total_time = time() - t0
    _report('NNX', losses, total_time, total_steps)

  if mode == 'jax' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = _build_flow(depth, width, rngs)
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4), wrt=nnx.Param)
    graphdef, state = nnx.split((flow, optimizer, rngs))

    # Plain jax.jit: this mode measures the explicit split/merge pattern
    # without nnx.jit's graph traversal of the inputs.
    @partial(jax.jit, donate_argnums=0)
    def train_step_jax(state, x_1):
      print('JITTING JAX')
      flow, optimizer, rngs = nnx.merge(graphdef, state)
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))
      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0
      loss, grads = nnx.value_and_grad(
          lambda flow: _mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(flow, grads)
      state = nnx.state((flow, optimizer, rngs))
      return loss, state

    losses = []
    t0 = time()
    for _ in tqdm(range(total_steps), desc='JAX'):
      loss, state = train_step_jax(state, X)
      losses.append(loss)
    jax.block_until_ready(losses[-1])
    total_time = time() - t0
    # The original buffers were donated; write the final state back.
    nnx.update((flow, optimizer, rngs), state)
    _report('JAX', losses, total_time, total_steps)


if __name__ == '__main__':
  app.run(main)
================================================
FILE: benchmarks/nnx_simple_training.py
================================================
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# %%
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import optax
from time import time
from flax import nnx
from absl import flags
from absl import app
# Command-line flags for the simple MLP training benchmark.
FLAGS = flags.FLAGS
# Which training loop(s) to run: 'nnx', 'jax', or 'all' (both, NNX first).
flags.DEFINE_enum(
'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
# Number of training steps per mode; also the divisor for "time per step".
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
# Number of examples sampled (with replacement) per training batch.
flags.DEFINE_integer('batch_size', 32, 'Batch size')
# Hidden size passed to MLP as `dhidden`.
flags.DEFINE_integer('width', 32, 'Hidden layer size')
# Number of Blocks in the MLP (input + `depth - 2` intermediate + output).
flags.DEFINE_integer('depth', 5, 'Depth of the model')
def dataset(X, Y, batch_size):
  """Yields an endless stream of ``(X[idx], Y[idx])`` batches.

  Indices are drawn with replacement from ``np.random`` (global RNG state),
  so consecutive batches may repeat examples.
  """
  num_examples = len(X)
  while True:
    batch_idx = np.random.choice(num_examples, size=batch_size)
    x_batch, y_batch = X[batch_idx], Y[batch_idx]
    yield x_batch, y_batch
class Linear(nnx.Module):
  """Dense layer ``x @ w + b``.

  ``w`` has shape ``(din, dout)`` and is sampled with ``jax.random.uniform``
  using a key from ``rngs.params()``; the bias ``b`` starts at zero.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    param_key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(param_key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    out = x @ self.w
    return out + self.b
class Block(nnx.Module):
  """Linear layer, then BatchNorm, then ReLU."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.linear = Linear(din, dout, rngs=rngs)
    self.bn = nnx.BatchNorm(dout, rngs=rngs)

  def __call__(self, x):
    normalized = self.bn(self.linear(x))
    return nnx.relu(normalized)
class Count(nnx.Variable):
  """Variable subtype holding MLP's forward-pass counter."""
class MLP(nnx.Module):
  """Stack of Blocks: an input Block, ``depth - 2`` hidden Blocks, an output Block.

  Each forward pass increments ``self.count``.
  """

  def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
    self.count = Count(jnp.array(0))
    self.linear_in = Block(din, dhidden, rngs=rngs)
    # nnx.List (as in nnx_graph_overhead.py) registers the hidden blocks as
    # NNX data; a plain Python list of Modules is not a registered container.
    # NOTE(review): Flax 0.12+ expects nnx.List/nnx.data here; confirm against
    # the installed version.
    self.intermediates = nnx.List([
        Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
    ])
    self.linear_out = Block(dhidden, dout, rngs=rngs)

  def __call__(self, x):
    self.count.value += 1
    x = nnx.relu(self.linear_in(x))
    for layer in self.intermediates:
      x = nnx.relu(layer(x))
    return self.linear_out(x)
def main(argv):
  """Trains a small MLP regressor with nnx.jit and with jax.jit + split/merge.

  Each mode prints the loss from its most recent evaluation (run every 1000
  steps), the wall time and how many times the model was called. Because JAX
  dispatches work asynchronously, the clock is read only after all
  outstanding device work has finished.
  """
  print(argv)
  mode: str = FLAGS.mode
  total_steps: int = FLAGS.total_steps
  batch_size: int = FLAGS.batch_size
  width: int = FLAGS.width
  depth: int = FLAGS.depth

  print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

  X = np.linspace(0, 1, 100)[:, None]
  Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

  if mode == 'nnx' or mode == 'all':
    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
    tx = optax.sgd(1e-3)
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
    t0 = time()

    @nnx.jit(donate_argnums=(0, 1))
    def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
      x, y = batch

      def loss_fn(model: MLP):
        y_pred = model(x)
        return jnp.mean((y - y_pred) ** 2)

      grads: nnx.State = nnx.grad(loss_fn)(model)
      optimizer.update(model, grads)

    @nnx.jit(donate_argnums=0)
    def test_step_nnx(model: MLP, batch):
      x, y = batch
      y_pred = model(x)
      loss = jnp.mean((y - y_pred) ** 2)
      return {'loss': loss}

    for step, batch in enumerate(dataset(X, Y, batch_size)):
      train_step_nnx(model, optimizer, batch)
      if step % 1000 == 0:
        logs = test_step_nnx(model, (X, Y))
      if step >= total_steps - 1:
        break

    # Wait for all dispatched updates before stopping the clock.
    jax.block_until_ready(nnx.state(model))
    total_time = time() - t0
    print('### NNX ###')
    print(f'final loss: {logs["loss"]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
    print('times called:', model.count.value)

  if mode == 'jax' or mode == 'all':
    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
    tx = optax.sgd(1e-3)
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
    t0 = time()

    @partial(jax.jit, donate_argnums=0)
    def train_step_jax(state, batch):
      model, optimizer = nnx.merge(graphdef, state)
      x, y = batch

      def loss_fn(model: MLP):
        y_pred = model(x)
        return jnp.mean((y - y_pred) ** 2)

      grads = nnx.grad(loss_fn)(model)
      optimizer.update(model, grads)
      return nnx.state((model, optimizer))

    @partial(jax.jit, donate_argnums=0)
    def test_step_jax(state, batch):
      model, optimizer = nnx.merge(graphdef, state)
      x, y = batch
      y_pred = model(x)
      loss = jnp.mean((y - y_pred) ** 2)
      state = nnx.state((model, optimizer))
      return state, {'loss': loss}

    graphdef, state = nnx.split((model, optimizer))
    for step, batch in enumerate(dataset(X, Y, batch_size)):
      state = train_step_jax(state, batch)
      if step % 1000 == 0:
        state, logs = test_step_jax(state, (X, Y))
      if step >= total_steps - 1:
        break

    # Wait for all dispatched updates before stopping the clock.
    jax.block_until_ready(state)
    total_time = time() - t0
    model, optimizer = nnx.merge(graphdef, state)
    print('### JAX ###')
    print(f'final loss: {logs["loss"]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
    print('times called:', model.count.value)


if __name__ == '__main__':
  app.run(main)
================================================
FILE: benchmarks/nnx_state_traversal.py
================================================
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Example profile command:
#   python -m cProfile -o ~/tmp/traversal.prof benchmarks/nnx_state_traversal.py --width=4 --depth=4 --total_steps=1000
# View profile (need to install snakeviz):
#   snakeviz ~/tmp/traversal.prof
import jax
from time import time
from flax import nnx
from absl import flags
from absl import app
# Command-line flags for the state-traversal benchmark.
FLAGS = flags.FLAGS
# Number of repetitions of each timed tree operation.
flags.DEFINE_integer('total_steps', 1000, 'Total number of training steps')
# Number of children per NestedClass node.
flags.DEFINE_integer('width', 4, 'Width of each level')
# Depth of the NestedClass tree; also the divisor for "time per layer".
flags.DEFINE_integer('depth', 4, 'Depth of the model')
class NestedClass(nnx.Module):
  """Tree of modules used to build a large nested nnx.State.

  Every node holds a Variable of ones with shape ``(depth + 1,)``. Nodes with
  ``depth > 0`` also get ``width`` children named ``child0``, ``child1``, ...,
  each built with ``depth - 1``.
  """

  def __init__(self, width, depth):
    self.x = nnx.Variable(jax.numpy.ones((depth + 1,)))
    if depth <= 0:
      return
    for child_idx in range(width):
      setattr(self, f'child{child_idx}', NestedClass(width, depth - 1))
def _time_op(title: str, op, args: tuple, total_steps: int, depth: int):
  """Calls ``op(*args)`` ``total_steps`` times and prints wall-clock timings.

  The per-layer time divides the per-step time by ``depth``. Depth is clamped
  to 1 so a depth-0 tree (a single node) does not divide by zero.
  """
  t0 = time()
  for _ in range(total_steps):
    op(*args)
  total_time = time() - t0
  time_per_step = total_time / total_steps
  time_per_layer = time_per_step / max(depth, 1)
  print(f'### {title} ###')
  print('total time:', total_time)
  print(f'time per step: {time_per_step * 1e6:.2f} µs')
  print(f'time per layer: {time_per_layer * 1e6:.2f} µs')


def main(argv):
  """Times JAX tree utilities over the nnx.State of a NestedClass tree."""
  print(argv)
  total_steps: int = FLAGS.total_steps
  width: int = FLAGS.width
  depth: int = FLAGS.depth

  model = NestedClass(width, depth)
  to_test = nnx.state(model)

  print(f'{total_steps=}, {width=}')

  _time_op(
      'tree_flatten_with_path',
      jax.tree_util.tree_flatten_with_path,
      (to_test,),
      total_steps,
      depth,
  )
  _time_op(
      'tree_map_with_path',
      jax.tree_util.tree_map_with_path,
      (lambda _, x: x, to_test),
      total_steps,
      depth,
  )
  _time_op(
      'tree_flatten',
      jax.tree_util.tree_flatten,
      (to_test,),
      total_steps,
      depth,
  )


if __name__ == '__main__':
  app.run(main)
================================================
FILE: benchmarks/tracing/README.md
================================================
# Tracing and lowering benchmarks for Flax examples
See the Flax
[documentation](https://flax.readthedocs.io/en/latest/examples/index.html) for
details on the examples benchmarked here.
## Getting started
```bash
pip install -r benchmarks/tracing/requirements.txt
# Benchmark trace and lower timing for all workloads.
python tracing_benchmark.py
# Profile a single example.
python tracing_benchmark.py --example=wmt
# Profile just tracing for a single example.
python tracing_benchmark.py --example=wmt --mode=trace
```
================================================
FILE: benchmarks/tracing/__init__.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: benchmarks/tracing/gemma.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma helper functions."""
from typing import Any
from flax import nnx
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.gemma import transformer as transformer_lib
from flax.examples.gemma import utils
from flax.examples.gemma.configs import default as gemma_config
from flax.training import common_utils
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
def rsqrt_schedule(init_value: float, shift: int = 0):
  """Returns ``count -> init_value * (count + shift) ** -0.5 * shift ** 0.5``.

  With ``shift > 0`` the schedule equals ``init_value`` at ``count == 0``.
  It then decays as the inverse square root of ``count + shift``.
  """

  def rsqrt_at(count):
    decay = (count + shift) ** -0.5
    return init_value * decay * shift**0.5

  return rsqrt_at
def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
  """Linear warm-up from 0 to ``learning_rate``, then inverse-sqrt decay.

  The two pieces are joined with ``optax.join_schedules`` at ``warmup_steps``.
  """
  warmup = optax.linear_schedule(
      init_value=0,
      end_value=learning_rate,
      transition_steps=warmup_steps,
  )
  decay = rsqrt_schedule(init_value=learning_rate, shift=warmup_steps)
  return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])
def compute_weighted_cross_entropy(
    logits, targets, weights=None, label_smoothing=0.0
):
  """Label-smoothed cross entropy, summed over positions.

  Args:
    logits: scores whose last axis is the vocabulary; must have exactly one
      more dimension than ``targets``.
    targets: integer class ids.
    weights: optional per-position weights multiplied into the loss.
    label_smoothing: probability mass spread uniformly over the other classes.

  Returns:
    ``(loss_sum, normalizing_factor)``. The normalizer is ``weights.sum()``
    when weights are given, otherwise the number of target elements.

  Raises:
    ValueError: if the ranks of ``logits`` and ``targets`` are inconsistent.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets'
        % (str(logits.shape), str(targets.shape))
    )
  num_classes = logits.shape[-1]
  on_value = 1.0 - label_smoothing
  off_value = (1.0 - on_value) / (num_classes - 1)
  # Entropy of the smoothed target distribution; subtracting it turns the
  # loss into a KL divergence with minimum 0. 1e-20 guards log(0).
  normalizing_constant = -(
      on_value * jnp.log(on_value)
      + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
  )
  soft_targets = common_utils.onehot(
      targets, num_classes, on_value=on_value, off_value=off_value
  )
  per_position = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1)
  per_position = per_position - normalizing_constant

  if weights is None:
    return per_position.sum(), np.prod(targets.shape)
  weighted = per_position * weights
  return weighted.sum(), weights.sum()
def compute_weighted_accuracy(logits, targets, weights=None):
  """Counts argmax hits against ``targets``, optionally weighted.

  Returns:
    ``(hits_sum, normalizing_factor)``. The normalizer is ``weights.sum()``
    when weights are given, otherwise the number of positions
    (``prod(logits.shape[:-1])``).

  Raises:
    ValueError: if ``logits`` does not have exactly one more dimension than
      ``targets``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets'
        % (str(logits.shape), str(targets.shape))
    )
  hits = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  if weights is None:
    return hits.sum(), np.prod(logits.shape[:-1])
  weighted_hits = hits * weights
  return weighted_hits.sum(), weights.sum()
def compute_metrics(logits, labels, weights, label_smoothing=0.0):
  """Returns the summed loss and accuracy plus their shared normalizer.

  The ``'denominator'`` entry is the normalizer from the cross-entropy
  computation.
  """
  loss_sum, denominator = compute_weighted_cross_entropy(
      logits, labels, weights, label_smoothing
  )
  hits, _ = compute_weighted_accuracy(logits, labels, weights)
  return {'loss': loss_sum, 'accuracy': hits, 'denominator': denominator}
def train_step(
    state: utils.TrainState,
    batch,
    learning_rate_fn,
    label_smoothing=0.0,
):
  """Runs one training step of the Gemma transformer.

  Args:
    state: train state providing ``graphdef``, ``params`` and ``step``.
      Gradients are applied with ``state.apply_gradients``.
    batch: dict of token arrays.
      - ``'inputs'``, ``'inputs_position'`` and ``'targets'`` are used.
      - ``'inputs_segmentation'`` is optional. When present, attention is
        restricted to tokens of the same segment.
      - Token id 0 is treated as padding.
    learning_rate_fn: schedule evaluated at ``state.step``. Its value is only
      reported in the metrics.
    label_smoothing: smoothing applied to the training loss. Metrics are
      computed without smoothing.

  Returns:
    ``(new_state, metrics)``. ``metrics`` holds the summed loss and accuracy,
    the ``denominator`` and the ``learning_rate``.
  """
  train_keys = ['inputs', 'inputs_position', 'inputs_segmentation', 'targets']
  (inputs, inputs_positions, inputs_segmentation, targets) = (
      batch.get(k, None) for k in train_keys
  )

  pad_id = 0
  weights = jnp.where(inputs > pad_id, 1, 0).astype(jnp.float32)
  input_mask = inputs > pad_id
  attention_mask = transformer_lib.make_causal_attn_mask(input_mask)
  # batch.get returns None for unpacked batches; only then skip the
  # same-segment restriction (indexing None would raise a TypeError).
  if inputs_segmentation is not None:
    mask = (
        inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :]
    )
    attention_mask = jnp.logical_and(mask, attention_mask)

  def loss_fn(params):
    module = nnx.merge(state.graphdef, params)
    logits, _ = module(
        inputs,
        positions=inputs_positions,
        attention_mask=attention_mask,
        cache=None,
    )
    loss, weight_sum = compute_weighted_cross_entropy(
        logits, targets, weights, label_smoothing
    )
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = state.step
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  new_state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits, targets, weights)
  metrics['learning_rate'] = lr
  return new_state, metrics
def get_fake_batch(batch_size: int) -> Any:
  """Returns a dict of random int32 arrays of shape ``(batch_size, 128)``.

  Every field is drawn from the same fixed key, so all fields hold identical
  values in ``[0, 9999999)``.
  """
  key = jax.random.PRNGKey(0)
  field_names = (
      'inputs',
      'inputs_position',
      'inputs_segmentation',
      'targets',
      'targets_position',
      'targets_segmentation',
  )
  return {
      name: jax.random.randint(key, (batch_size, 128), 0, 9999999, jnp.int32)
      for name in field_names
  }
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
    vocab_size: int | None = None,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Builds the sharded, jitted Gemma train step plus example arguments.

  Returns:
    ``(jit_train_step, (state, batch, learning_rate_fn, 0.0), {})``.
    ``batch`` is a fake batch placed on the data sharding.
  """
  if vocab_size is None:
    vocab_size = config.vocab_size

  model_dtype = jnp.bfloat16 if config.use_bfloat16 else jnp.float32
  shared_kwargs = dict(
      num_embed=vocab_size,
      dtype=model_dtype,
      axis_rules=config.axis_rules,
  )
  if config.transformer_name is None:
    assert config.transformer_params is not None
    model_config = transformer_lib.TransformerConfig.from_dict(
        **config.transformer_params, **shared_kwargs
    )
  else:
    model_config = transformer_lib.TransformerConfig.from_version_name(
        config.transformer_name, **shared_kwargs
    )

  mesh = jax.sharding.Mesh(utils.create_device_mesh(config), config.mesh_axes)
  _, init_rng = jax.random.split(jax.random.PRNGKey(config.seed))

  def constructor(config: transformer_lib.TransformerConfig, key: jax.Array):
    return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key))

  learning_rate_fn = create_learning_rate_schedule(
      learning_rate=config.learning_rate, warmup_steps=config.warmup_steps
  )
  optimizer = optax.adamw(
      learning_rate_fn,
      b1=0.9,
      b2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay,
  )
  state, state_sharding = utils.setup_initial_state(
      constructor, optimizer, model_config, init_rng, mesh
  )
  data_sharding = jax.NamedSharding(mesh, jax.P(config.data_sharding))

  jit_train_step = jax.jit(
      train_step,
      in_shardings=(state_sharding, data_sharding),  # type: ignore
      out_shardings=(state_sharding, None),  # type: ignore
      static_argnames=('learning_rate_fn', 'label_smoothing'),
      donate_argnums=0,
  )

  batch = jax.tree.map(
      lambda x: jnp.asarray(x, device=data_sharding),
      get_fake_batch(config.per_device_batch_size),
  )
  return jit_train_step, (state, batch, learning_rate_fn, 0.0), {}
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_gemma_trace(state):
  """google_benchmark case: ``tracing_benchmark.benchmark_tracing`` on Gemma."""
  make_config = gemma_config.get_config
  tracing_benchmark.benchmark_tracing(get_apply_fn_and_args, make_config, state)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_gemma_lower(state):
  """google_benchmark case: ``tracing_benchmark.benchmark_lowering`` on Gemma."""
  make_config = gemma_config.get_config
  tracing_benchmark.benchmark_lowering(get_apply_fn_and_args, make_config, state)


if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/imagenet.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ImageNet helper functions for benchmarking."""
import functools
from typing import Any
from flax import linen as nn
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.imagenet import models
from flax.examples.imagenet.configs import default as imagenet_config
from flax.training import common_utils
from flax.training import dynamic_scale as dynamic_scale_lib
from flax.training import train_state
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import optax
# Number of ImageNet output classes; used for the model head (create_model)
# and for one-hot labels (cross_entropy_loss).
NUM_CLASSES = 1000
# flax TrainState extended with BatchNorm statistics and an optional dynamic
# loss-scaling helper. create_train_state sets dynamic_scale to None unless
# half precision is used on GPU.
class TrainState(train_state.TrainState):
# Mutable BatchNorm statistics (the 'batch_stats' collection).
batch_stats: Any
# DynamicScale used for mixed-precision loss scaling, or None.
dynamic_scale: dynamic_scale_lib.DynamicScale
def create_model(*, model_cls, half_precision, **kwargs):
  """Instantiates ``model_cls`` with ``NUM_CLASSES`` outputs and a platform dtype.

  With ``half_precision`` the dtype is bfloat16 on TPU and float16 on other
  platforms; otherwise it is float32. Extra ``kwargs`` are forwarded.
  """
  platform = jax.local_devices()[0].platform
  model_dtype = jnp.float32
  if half_precision:
    model_dtype = jnp.bfloat16 if platform == 'tpu' else jnp.float16
  return model_cls(num_classes=NUM_CLASSES, dtype=model_dtype, **kwargs)
def initialized(key, image_size, model):
  """Runs a jitted ``model.init`` on a dummy batch of ones.

  The dummy batch has shape ``(1, image_size, image_size, 3)`` and dtype
  ``model.dtype``.

  Returns:
    ``(params, batch_stats)`` taken from the initialized variables.
  """

  @jax.jit
  def _jitted_init(*init_args):
    return model.init(*init_args)

  dummy_images = jnp.ones((1, image_size, image_size, 3), model.dtype)
  variables = _jitted_init({'params': key}, dummy_images)
  return variables['params'], variables['batch_stats']
def cross_entropy_loss(logits, labels):
  """Mean softmax cross entropy of ``logits`` against integer ``labels``.

  Labels are one-hot encoded over ``NUM_CLASSES`` classes.
  """
  target_dist = common_utils.onehot(labels, num_classes=NUM_CLASSES)
  per_example = optax.softmax_cross_entropy(logits=logits, labels=target_dist)
  return jnp.mean(per_example)
def create_train_state(
    rng, config: ml_collections.ConfigDict, model, image_size, learning_rate_fn
):
  """Builds a TrainState with freshly initialized variables.

  The optimizer is SGD with Nesterov momentum (``config.momentum``). A
  ``DynamicScale`` is attached only when ``config.half_precision`` is set and
  the first local device is a GPU; otherwise ``dynamic_scale`` is None.
  """
  on_gpu = jax.local_devices()[0].platform == 'gpu'
  dynamic_scale = (
      dynamic_scale_lib.DynamicScale()
      if config.half_precision and on_gpu
      else None
  )
  params, batch_stats = initialized(rng, image_size, model)
  sgd = optax.sgd(
      learning_rate=learning_rate_fn,
      momentum=config.momentum,
      nesterov=True,
  )
  return TrainState.create(
      apply_fn=model.apply,
      params=params,
      tx=sgd,
      batch_stats=batch_stats,
      dynamic_scale=dynamic_scale,
  )
def get_fake_batch(batch_size: int = 128) -> dict[str, jnp.ndarray]:
images = jax.random.uniform(
jax.random.key(0), (batch_size, 224, 224, 3), dtype=jnp.float32
)
labels = jax.random.randint(
jax.random.key(1), (batch_size,), minval=0, maxval=1000, dtype=jnp.int32
)
return {'image': images, 'label': labels}
# ResNet variant used by the benchmark: __call__ is overridden so BatchNorm is
# built with axis_name=None. NOTE(review): presumably so the train step can be
# traced without a pmap/named axis; confirm intent.
# Under nn.compact, unnamed submodules are auto-named in creation order, so
# keep the statement order below stable.
class BenchmarkResNet(models.ResNet):
@nn.compact
def __call__(self, x, train: bool = True):
# Bias-free convolutions in the model dtype; also passed to every block.
conv = functools.partial(self.conv, use_bias=False, dtype=self.dtype)
# BatchNorm uses batch statistics when train=True and running averages
# otherwise; axis_name=None means no cross-device reduction of statistics.
norm = functools.partial(
nn.BatchNorm,
use_running_average=not train,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype,
axis_name=None,
)
# Stem: 7x7 stride-2 conv, BatchNorm, ReLU, then 3x3 stride-2 max pool.
x = conv(
self.num_filters,
(7, 7),
(2, 2),
padding=[(3, 3), (3, 3)],
name='conv_init',
)(x)
x = norm(name='bn_init')(x)
x = nn.relu(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
# Residual stages: stage i has stage_sizes[i] blocks with
# num_filters * 2**i filters; the first block of each stage after the
# first uses stride 2.
for i, block_size in enumerate(self.stage_sizes):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_cls(
self.num_filters * 2**i,
strides=strides,
conv=conv,
norm=norm,
act=self.act,
)(x)
# Average over axes (1, 2), then the classification head.
# NOTE(review): assumes NHWC input so these are the spatial axes.
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
# Cast the logits to the model dtype.
x = jnp.asarray(x, self.dtype)
return x
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Prepares the jitted ImageNet train step together with its arguments.

  ``config.model == 'ResNet50'`` selects the local ``BenchmarkResNet`` with
  bottleneck blocks laid out as [3, 4, 6, 3]. Any other value is looked up by
  name in ``models``.

  Args:
    config: ImageNet config (reads ``model``, ``half_precision``,
      ``batch_size`` and whatever ``create_train_state`` reads).

  Returns:
    ``(bench_train_step, (state, batch, learning_rate_fn), {})``.
  """
  if config.model != 'ResNet50':
    model_cls = getattr(models, config.model)
  else:
    model_cls = functools.partial(
        BenchmarkResNet,
        stage_sizes=[3, 4, 6, 3],
        block_cls=models.BottleneckResNetBlock,
    )
  model = create_model(model_cls=model_cls, half_precision=config.half_precision)

  # Constant schedule; the step only reports it.
  learning_rate_fn = lambda step: 0.1
  state = create_train_state(
      jax.random.key(0), config, model, 224, learning_rate_fn
  )
  batch = get_fake_batch(config.batch_size)
  return bench_train_step, (state, batch, learning_rate_fn), {}
@functools.partial(jax.jit, static_argnums=(2,))
def bench_train_step(state, batch, learning_rate_fn):
  """One jitted ImageNet training step (forward, backward, optimizer update).

  Args:
    state: ``TrainState`` carrying ``batch_stats`` and an optional
      ``dynamic_scale``.
    batch: dict with ``'image'`` and ``'label'`` arrays.
    learning_rate_fn: schedule (static under jit); used only to report the
      learning rate in the metrics.

  Returns:
    ``(new_state, metrics)``. ``metrics`` holds ``loss``, ``accuracy`` and
    ``learning_rate``, plus ``scale`` when dynamic scaling is active.
  """
  weight_decay = 0.0001

  def loss_fn(params):
    logits, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        batch['image'],
        mutable=['batch_stats'],
    )
    xent = cross_entropy_loss(logits, batch['label'])
    # L2 penalty only on parameters with more than one dimension.
    leaves = jax.tree_util.tree_leaves(params)
    weight_l2 = sum(jnp.sum(w**2) for w in leaves if w.ndim > 1)
    return xent + weight_decay * 0.5 * weight_l2, (new_model_state, logits)

  dynamic_scale = state.dynamic_scale
  lr = learning_rate_fn(state.step)
  if dynamic_scale:
    scaled_grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
    dynamic_scale, is_fin, aux, grads = scaled_grad_fn(state.params)
  else:
    aux, grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  new_model_state, logits = aux[1]

  metrics = {
      'loss': cross_entropy_loss(logits, batch['label']),
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  metrics['learning_rate'] = lr

  new_state = state.apply_gradients(
      grads=grads, batch_stats=new_model_state['batch_stats']
  )
  if dynamic_scale:
    # Keep the previous optimizer state and params wherever the gradients
    # were non-finite.
    keep_if_finite = functools.partial(jnp.where, is_fin)
    new_state = new_state.replace(
        opt_state=jax.tree_util.tree_map(
            keep_if_finite, new_state.opt_state, state.opt_state
        ),
        params=jax.tree_util.tree_map(
            keep_if_finite, new_state.params, state.params
        ),
        dynamic_scale=dynamic_scale,
    )
    metrics['scale'] = dynamic_scale.scale
  return new_state, metrics
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_imagenet_trace(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_tracing``."""
  config_fn = imagenet_config.get_config
  tracing_benchmark.benchmark_tracing(get_apply_fn_and_args, config_fn, state)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_imagenet_lower(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_lowering``."""
  config_fn = imagenet_config.get_config
  tracing_benchmark.benchmark_lowering(get_apply_fn_and_args, config_fn, state)


if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/lm1b.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LM1B helper functions for benchmarking."""
import functools
from typing import Any
from flax import linen as nn
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.lm1b import models
from flax.examples.lm1b.configs import default as lm1b_config
from flax.training import common_utils
from flax.training import train_state
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
def rsqrt_schedule(init_value: float, shift: int = 0):
  """Builds an inverse-square-root decay schedule.

  The schedule computes ``init_value * (count + shift) ** -0.5 * shift**0.5``,
  which equals ``init_value`` at ``count == 0`` when ``shift > 0``.

  Args:
    init_value: learning rate at ``count == 0``.
    shift: offset added to ``count``; presumably expected to be positive
      (with 0 the schedule is 0 for every positive count).

  Returns:
    A function mapping a step count to a learning rate.
  """
  sqrt_shift = shift**0.5

  def _rsqrt(count):
    return init_value * (count + shift) ** -0.5 * sqrt_shift

  return _rsqrt
def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
  """Linear warmup to ``learning_rate`` followed by rsqrt decay.

  Args:
    learning_rate: peak learning rate, reached at the end of warmup.
    warmup_steps: warmup length. It is also the boundary at which the
      schedule switches to ``rsqrt_schedule(shift=warmup_steps)``.

  Returns:
    An ``optax`` schedule that joins the two phases at ``warmup_steps``.
  """
  warmup = optax.linear_schedule(
      init_value=0,
      end_value=learning_rate,
      transition_steps=warmup_steps,
  )
  decay = rsqrt_schedule(init_value=learning_rate, shift=warmup_steps)
  return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])
def compute_weighted_cross_entropy(
    logits, targets, weights=None, label_smoothing=0.0
):
  """Summed, optionally label-smoothed and weighted, cross-entropy.

  Args:
    logits: array with exactly one more dimension than ``targets``; the last
      axis indexes the vocabulary.
    targets: integer token ids.
    weights: optional per-position weights, presumably shaped like
      ``targets`` (TODO confirm).
    label_smoothing: probability mass spread over the non-target classes.

  Returns:
    ``(loss_sum, normalizing_factor)``. The factor is ``weights.sum()`` when
    weights are given, otherwise the number of target positions.

  Raises:
    ValueError: if ``logits.ndim != targets.ndim + 1``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets'
        % (str(logits.shape), str(targets.shape))
    )
  num_classes = logits.shape[-1]
  on_value = 1.0 - label_smoothing
  off_value = (1.0 - on_value) / (num_classes - 1)
  # Entropy of the smoothed target distribution. Subtracting it makes the
  # loss zero when the predictions equal the smoothed targets.
  target_entropy = -(
      on_value * jnp.log(on_value)
      + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
  )
  smoothed = common_utils.onehot(
      targets, num_classes, on_value=on_value, off_value=off_value
  )
  per_token = -jnp.sum(smoothed * nn.log_softmax(logits), axis=-1) - target_entropy
  if weights is None:
    return per_token.sum(), np.prod(targets.shape)
  per_token = per_token * weights
  denominator = weights.sum()
  return per_token.sum(), denominator
def compute_weighted_accuracy(logits, targets, weights=None):
  """Counts (optionally weighted) positions where the argmax equals the target.

  Args:
    logits: array with exactly one more dimension than ``targets``.
    targets: integer class ids.
    weights: optional per-position weights.

  Returns:
    ``(hits_sum, normalizing_factor)``. The factor is ``weights.sum()`` when
    weights are given, otherwise the number of positions.

  Raises:
    ValueError: if ``logits.ndim != targets.ndim + 1``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets'
        % (str(logits.shape), str(targets.shape))
    )
  hits = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  if weights is None:
    return hits.sum(), np.prod(logits.shape[:-1])
  hits = hits * weights
  denominator = weights.sum()
  return hits.sum(), denominator
def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, jnp.ndarray]:
  """Synthetic LM1B batch in the packed-sequence input format.

  Args:
    config: reads ``per_device_batch_size``, ``max_target_length`` and
      ``vocab_size``.

  Returns:
    Dict of int32 ``(batch, max_len)`` arrays: random ``'inputs'`` token ids
    in ``[0, vocab_size)``, ``'inputs_position'`` counting 0..max_len-1 on
    every row, and all-ones ``'inputs_segmentation'``.
  """
  rows = config.per_device_batch_size
  seq_len = config.max_target_length
  tokens = jax.random.randint(
      jax.random.key(0),
      (rows, seq_len),
      minval=0,
      maxval=config.vocab_size,
      dtype=jnp.int32,
  )
  positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (rows, 1))
  segments = jnp.ones((rows, seq_len), dtype=jnp.int32)
  return dict(
      inputs=tokens,
      inputs_position=positions,
      inputs_segmentation=segments,
  )
@functools.partial(jax.jit, static_argnums=(2, 3))
def bench_train_step(state, batch, config, learning_rate_fn):
  """One jitted LM1B training step.

  The model's own ``inputs`` serve as the cross-entropy targets.

  Args:
    state: ``TrainState`` holding params and optimizer state.
    batch: dict with ``'inputs'``, ``'inputs_position'`` and
      ``'inputs_segmentation'``.
    config: ``models.TransformerConfig`` (static under jit).
    learning_rate_fn: schedule (static under jit); used only to report the
      learning rate.

  Returns:
    ``(new_state, metrics)`` with summed ``loss`` and ``accuracy``, their
    ``denominator`` and ``learning_rate``.
  """
  inputs = batch['inputs']
  positions = batch['inputs_position']
  segments = batch['inputs_segmentation']
  # Token id 0 gets weight 0 (presumably padding).
  weights = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)
  dropout_rng = jax.random.fold_in(jax.random.key(0), state.step)

  def loss_fn(params):
    logits = models.TransformerLM(config).apply(
        {'params': params},
        inputs,
        inputs_positions=positions,
        inputs_segmentation=segments,
        rngs={'dropout': dropout_rng},
    )
    loss_sum, weight_sum = compute_weighted_cross_entropy(
        logits, inputs, weights, 0.0
    )
    return loss_sum / weight_sum, logits

  lr = learning_rate_fn(state.step)
  (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(grads=grads)

  loss_sum, weight_sum = compute_weighted_cross_entropy(
      logits, inputs, weights, 0.0
  )
  correct, _ = compute_weighted_accuracy(logits, inputs, weights)
  metrics = {'loss': loss_sum, 'accuracy': correct, 'denominator': weight_sum}
  metrics['learning_rate'] = lr
  return new_state, metrics
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Prepares the jitted LM1B train step together with its arguments.

  Args:
    config: LM1B config supplying the Transformer hyperparameters, the
      schedule settings, ``weight_decay`` and the batch geometry.

  Returns:
    ``(bench_train_step, (state, batch, train_config, learning_rate_fn), {})``.
  """
  train_config = models.TransformerConfig(
      vocab_size=config.vocab_size,
      output_vocab_size=config.vocab_size,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=jax.nn.initializers.xavier_uniform(),
      bias_init=jax.nn.initializers.normal(stddev=1e-6),
  )
  model = models.TransformerLM(train_config)
  learning_rate_fn = create_learning_rate_schedule(
      learning_rate=config.learning_rate, warmup_steps=config.warmup_steps
  )
  optimizer = optax.adamw(
      learning_rate_fn,
      b1=0.9,
      b2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay,
  )

  init_rng, _ = jax.random.split(jax.random.key(0))
  # One all-ones int32 array is used for all three positional model inputs.
  dummy = jnp.ones(
      (config.per_device_batch_size, config.max_target_length), jnp.int32
  )
  initial_variables = model.init(init_rng, dummy, dummy, dummy)
  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=initial_variables['params'],
      tx=optimizer,
  )
  args = (state, get_fake_batch(config), train_config, learning_rate_fn)
  return bench_train_step, args, {}
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_lm1b_trace(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_tracing``."""
  config_fn = lm1b_config.get_config
  tracing_benchmark.benchmark_tracing(get_apply_fn_and_args, config_fn, state)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_lm1b_lower(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_lowering``."""
  config_fn = lm1b_config.get_config
  tracing_benchmark.benchmark_lowering(get_apply_fn_and_args, config_fn, state)


if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/mnist.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST helper functions."""
from functools import partial
from typing import Any
from flax import nnx
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.mnist.configs import default as mnist_config
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import optax
class CNN(nnx.Module):
  """Small convolutional classifier used by the MNIST tracing benchmark.

  Two stages of conv -> BatchNorm -> ReLU -> 2x2 average pool, followed by
  two Linear layers with dropout between them. ``conv1`` takes 1-channel
  input. ``linear1``'s 3136 (= 7 * 7 * 64) input features presumably match
  28x28 images after two poolings (TODO confirm if the input size changes).
  """

  def __init__(self, rngs: nnx.Rngs):
    # Construction order fixes how `rngs` is consumed; keep it unchanged.
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs)
    self.dropout1 = nnx.Dropout(rate=0.025)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.dropout2 = nnx.Dropout(rate=0.025)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x, rngs: nnx.Rngs):
    """Returns 10 logits per example.

    Args:
      x: input images; presumably ``(batch, 28, 28, 1)`` (see class note).
      rngs: RNG streams consumed by the two Dropout layers.
    """
    # Stage 1: dropout sits directly after conv1, before BatchNorm.
    h = self.dropout1(self.conv1(x), rngs=rngs)
    h = self.avg_pool(nnx.relu(self.batch_norm1(h)))
    # Stage 2.
    h = self.conv2(h)
    h = self.avg_pool(nnx.relu(self.batch_norm2(h)))
    # Classifier head on the flattened features.
    h = h.reshape(h.shape[0], -1)
    h = nnx.relu(self.dropout2(self.linear1(h), rngs=rngs))
    return self.linear2(h)
def loss_fn(model: CNN, batch, rngs):
  """Mean integer-label softmax cross-entropy of ``model`` on ``batch``.

  Args:
    model: the CNN to evaluate.
    batch: dict with ``'image'`` inputs and integer ``'label'`` targets.
    rngs: RNG streams forwarded to the model (used by Dropout).

  Returns:
    ``(loss, logits)``.
  """
  logits = model(batch['image'], rngs)
  per_example = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']
  )
  return per_example.mean(), logits
def get_fake_batch(batch_size: int) -> dict[str, Any]:
  """Returns a synthetic MNIST-shaped batch.

  Args:
    batch_size: number of examples.

  Returns:
    Dict with float32 ``'image'`` of shape ``(batch_size, 28, 28, 1)`` and
    int32 ``'label'`` of shape ``(batch_size,)`` with values in ``[0, 10)``.
  """
  key = jax.random.PRNGKey(0)
  # NOTE(review): the same key feeds both draws, so images and labels are not
  # independent samples.
  fake_images = jax.random.normal(key, (batch_size, 28, 28, 1), jnp.float32)
  fake_labels = jax.random.randint(key, (batch_size,), 0, 10, jnp.int32)
  return dict(image=fake_images, label=fake_labels)
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Prepares a jitted MNIST loss function together with its arguments.

  Args:
    config: MNIST config; only ``batch_size`` is read.

  Returns:
    ``(jax.jit(loss_fn), (model, batch, rngs), {})``.
  """
  model = CNN(rngs=nnx.Rngs(0))
  batch = get_fake_batch(config.batch_size)
  # A separate Rngs (same seed) for the forward pass's Dropout layers.
  forward_rngs = nnx.Rngs(0)
  jitted = jax.jit(loss_fn)
  return jitted, (model, batch, forward_rngs), {}
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_mnist_trace(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_tracing``."""
  config_fn = mnist_config.get_config
  tracing_benchmark.benchmark_tracing(get_apply_fn_and_args, config_fn, state)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_mnist_lower(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_lowering``."""
  config_fn = mnist_config.get_config
  tracing_benchmark.benchmark_lowering(get_apply_fn_and_args, config_fn, state)


if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/nlp_seq.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NLP Sequence Tagging helper functions for benchmarking."""
import functools
from typing import Any
from flax import linen as nn
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.nlp_seq import models
from flax.examples.nlp_seq.configs import default as nlp_seq_config
from flax.training import common_utils
from flax.training import train_state
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
def create_learning_rate_scheduler(
    factors='constant * linear_warmup * rsqrt_decay',
    base_learning_rate=0.5,
    warmup_steps=8000,
    decay_factor=0.5,
    steps_per_decay=20000,
    steps_per_cycle=100000,
):
  """Builds a learning-rate function as a product of named factors.

  Args:
    factors: '*'-separated factor names, applied left to right. Supported:
      ``constant``, ``linear_warmup``, ``rsqrt_decay``,
      ``rsqrt_normalized_decay``, ``decay_every`` and ``cosine_decay``.
    base_learning_rate: value multiplied in by ``constant``.
    warmup_steps: warmup length used by the warmup, rsqrt and cosine factors.
    decay_factor: multiplier applied once per ``steps_per_decay`` steps by
      ``decay_every``.
    steps_per_decay: period of ``decay_every``.
    steps_per_cycle: cycle length of ``cosine_decay``.

  Returns:
    A function ``step -> float32 learning rate``. An unknown factor name
    raises ``ValueError`` when that function is called, not when it is built.
  """
  names = [part.strip() for part in factors.split('*')]

  def _cosine(lr, step):
    progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle))
    return lr * jnp.maximum(
        0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))
    )

  # Each entry maps (current lr, step) -> updated lr.
  updaters = {
      'constant': lambda lr, step: lr * base_learning_rate,
      'linear_warmup': (
          lambda lr, step: lr * jnp.minimum(1.0, step / warmup_steps)
      ),
      'rsqrt_decay': (
          lambda lr, step: lr / jnp.sqrt(jnp.maximum(step, warmup_steps))
      ),
      'rsqrt_normalized_decay': (
          lambda lr, step: (lr * jnp.sqrt(warmup_steps))
          / jnp.sqrt(jnp.maximum(step, warmup_steps))
      ),
      'decay_every': (
          lambda lr, step: lr * decay_factor ** (step // steps_per_decay)
      ),
      'cosine_decay': _cosine,
  }

  def step_fn(step):
    lr = 1.0
    for name in names:
      update = updaters.get(name)
      if update is None:
        raise ValueError('Unknown factor %s.' % name)
      lr = update(lr, step)
    return jnp.asarray(lr, dtype=jnp.float32)

  return step_fn
def compute_weighted_cross_entropy(logits, targets, weights=None):
  """Summed (optionally weighted) softmax cross-entropy.

  Args:
    logits: array with exactly one more dimension than ``targets``; the last
      axis indexes the classes.
    targets: integer class ids.
    weights: optional per-position weights.

  Returns:
    ``(loss_sum, normalizing_factor)``. The factor is ``weights.sum()`` when
    weights are given, otherwise the sum of the one-hot targets.

  Raises:
    ValueError: if ``logits.ndim != targets.ndim + 1``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets'
        % (str(logits.shape), str(targets.shape))
    )
  one_hot = common_utils.onehot(targets, logits.shape[-1])
  per_token = -jnp.sum(one_hot * nn.log_softmax(logits), axis=-1)
  denominator = one_hot.sum()
  if weights is not None:
    per_token = per_token * weights
    denominator = weights.sum()
  return per_token.sum(), denominator
def compute_weighted_accuracy(logits, targets, weights=None):
  """Counts (optionally weighted) positions where the argmax equals the target.

  Args:
    logits: array with exactly one more dimension than ``targets``.
    targets: integer class ids.
    weights: optional per-position weights.

  Returns:
    ``(hits_sum, normalizing_factor)``. The factor is ``weights.sum()`` when
    weights are given, otherwise the number of positions.

  Raises:
    ValueError: if ``logits.ndim != targets.ndim + 1``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets'
        % (str(logits.shape), str(targets.shape))
    )
  hits = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  if weights is None:
    return hits.sum(), np.prod(logits.shape[:-1])
  hits = hits * weights
  denominator = weights.sum()
  return hits.sum(), denominator
def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, jnp.ndarray]:
  """Synthetic sequence-tagging batch.

  Args:
    config: reads ``batch_size``, ``max_length``, ``vocab_size`` and
      ``output_vocab_size``.

  Returns:
    Dict of int32 ``(batch_size, max_length)`` arrays: ``'inputs'`` in
    ``[0, vocab_size)`` and ``'targets'`` in ``[0, output_vocab_size)``.
  """
  shape = (config.batch_size, config.max_length)
  vocab_size = config.vocab_size
  token_ids = jax.random.randint(
      jax.random.key(0), shape, minval=0, maxval=vocab_size, dtype=jnp.int32
  )
  tag_ids = jax.random.randint(
      jax.random.key(1),
      shape,
      minval=0,
      maxval=config.output_vocab_size,
      dtype=jnp.int32,
  )
  return dict(inputs=token_ids, targets=tag_ids)
@functools.partial(jax.jit, static_argnums=(2, 3))
def bench_train_step(state, batch, config, learning_rate_fn):
  """One jitted sequence-tagging training step.

  Args:
    state: ``TrainState`` holding params and optimizer state.
    batch: dict with integer ``'inputs'`` and ``'targets'`` arrays.
    config: ``models.TransformerConfig`` (static under jit).
    learning_rate_fn: schedule (static under jit); used only to report the
      learning rate.

  Returns:
    ``(new_state, metrics)`` with summed ``loss`` and ``accuracy``, their
    ``denominator`` and ``learning_rate``.
  """
  inputs = batch['inputs']
  targets = batch['targets']
  # Target id 0 gets weight 0 (presumably padding).
  weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32)
  dropout_rng = jax.random.fold_in(jax.random.key(0), state.step)

  def loss_fn(params):
    logits = models.Transformer(config).apply(
        {'params': params},
        inputs=inputs,
        train=True,
        rngs={'dropout': dropout_rng},
    )
    loss_sum, weight_sum = compute_weighted_cross_entropy(
        logits, targets, weights
    )
    return loss_sum / weight_sum, logits

  lr = learning_rate_fn(state.step)
  (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(grads=grads)

  loss_sum, weight_sum = compute_weighted_cross_entropy(
      logits, targets, weights
  )
  correct, _ = compute_weighted_accuracy(logits, targets, weights)
  metrics = {'loss': loss_sum, 'accuracy': correct, 'denominator': weight_sum}
  metrics['learning_rate'] = lr
  return new_state, metrics
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Prepares the jitted sequence-tagging train step with its arguments.

  Side effect: if ``config`` lacks ``vocab_size`` or ``output_vocab_size``,
  they are set on it to 30000 and 50 respectively.

  Args:
    config: nlp_seq config (reads ``max_length``, ``learning_rate``,
      ``weight_decay`` and ``batch_size``).

  Returns:
    ``(bench_train_step, (state, batch, model_config, learning_rate_fn), {})``.
  """
  for field, default in (('vocab_size', 30000), ('output_vocab_size', 50)):
    if not hasattr(config, field):
      setattr(config, field, default)

  model_config = models.TransformerConfig(
      vocab_size=config.vocab_size,
      output_vocab_size=config.output_vocab_size,
      max_len=config.max_length,
  )
  model = models.Transformer(model_config)
  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate
  )
  optimizer = optax.adamw(
      learning_rate_fn,
      b1=0.9,
      b2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay,
  )

  init_rng, _ = jax.random.split(jax.random.key(0))
  dummy_inputs = jnp.ones((config.batch_size, config.max_length), jnp.int32)
  variables = model.init(init_rng, inputs=dummy_inputs, train=False)
  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optimizer,
  )
  args = (state, get_fake_batch(config), model_config, learning_rate_fn)
  return bench_train_step, args, {}
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_nlp_seq_trace(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_tracing``."""
  config_fn = nlp_seq_config.get_config
  tracing_benchmark.benchmark_tracing(get_apply_fn_and_args, config_fn, state)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_nlp_seq_lower(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_lowering``."""
  config_fn = nlp_seq_config.get_config
  tracing_benchmark.benchmark_lowering(get_apply_fn_and_args, config_fn, state)


if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/ogbg_molpcba.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OGBG-MolPCBA helper functions for benchmarking."""
from typing import Any
from clu import metrics
import flax
import flax.linen as nn
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.ogbg_molpcba import models
from flax.examples.ogbg_molpcba.configs import default as ogbg_config
from flax.training import train_state
import google_benchmark
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import optax
def create_model(
    config: ml_collections.ConfigDict, deterministic: bool
) -> nn.Module:
  """Instantiates the graph network selected by ``config.model``.

  Args:
    config: reads ``model`` ('GraphNet' or 'GraphConvNet') and the network
      hyperparameters. ``use_edge_model`` is read only for 'GraphNet'.
    deterministic: forwarded to the model constructor.

  Returns:
    The constructed (unbound) Flax module.

  Raises:
    ValueError: if ``config.model`` is not one of the supported names.
  """
  if config.model not in ('GraphNet', 'GraphConvNet'):
    raise ValueError(f'Unsupported model: {config.model}.')

  shared_kwargs = dict(
      latent_size=config.latent_size,
      num_mlp_layers=config.num_mlp_layers,
      message_passing_steps=config.message_passing_steps,
      output_globals_size=config.num_classes,
      dropout_rate=config.dropout_rate,
      skip_connections=config.skip_connections,
      layer_norm=config.layer_norm,
      deterministic=deterministic,
  )
  if config.model == 'GraphNet':
    return models.GraphNet(
        use_edge_model=config.use_edge_model, **shared_kwargs
    )
  return models.GraphConvNet(**shared_kwargs)
def create_optimizer(
    config: ml_collections.ConfigDict,
) -> optax.GradientTransformation:
  """Builds the optax optimizer named by ``config.optimizer``.

  Args:
    config: reads ``optimizer`` ('adam' or 'sgd'), ``learning_rate`` and,
      for SGD, ``momentum``.

  Returns:
    The optax gradient transformation.

  Raises:
    ValueError: for any other optimizer name.
  """
  name = config.optimizer
  if name == 'sgd':
    return optax.sgd(
        learning_rate=config.learning_rate, momentum=config.momentum
    )
  if name == 'adam':
    return optax.adam(learning_rate=config.learning_rate)
  raise ValueError(f'Unsupported optimizer: {config.optimizer}.')
def binary_cross_entropy_with_mask(
    *, logits: jnp.ndarray, labels: jnp.ndarray, mask: jnp.ndarray
):
  """Element-wise sigmoid binary cross-entropy computed from logits.

  Uses the stable form ``max(x, 0) - x * z + log(1 + exp(-|x|))``. At
  positions where ``mask`` is False, the label is replaced by -1 before the
  computation. The result itself is not masked; callers apply ``mask``.

  Args:
    logits: 2-D array of logits.
    labels: array with the same shape as ``logits``.
    mask: array with the same shape as ``logits``.

  Returns:
    Per-element loss with the shape of ``logits``.
  """
  assert logits.shape == labels.shape == mask.shape
  assert len(logits.shape) == 2
  z = jnp.where(mask, labels, -1)
  nonneg = logits >= 0
  relu_part = jnp.where(nonneg, logits, 0)
  magnitude = jnp.where(nonneg, logits, -logits)
  return relu_part - (logits * z) + (jnp.log(1 + jnp.exp(-magnitude)))
def predictions_match_labels(
    *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs
) -> jnp.ndarray:
  """Returns 1.0 where the prediction ``logits > 0`` equals the label, else 0.0.

  Any extra keyword arguments (e.g. other model outputs such as ``loss`` or
  ``mask``) are accepted and ignored.
  """
  del kwargs  # unused
  predicted = logits > 0
  agree = predicted == labels
  return agree.astype(jnp.float32)
def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Returns ``graphs`` with ``globals`` replaced by ones of shape (num_graphs, 1).

  ``num_graphs`` is ``graphs.n_node.shape[0]``. The train step uses this to
  keep the labels, which are stored in ``globals``, away from the model.
  """
  num_graphs = graphs.n_node.shape[0]
  placeholder = jnp.ones([num_graphs, 1])
  return graphs._replace(globals=placeholder)
def get_predicted_logits(
    state: train_state.TrainState,
    graphs: jraph.GraphsTuple,
    rngs: dict[str, jnp.ndarray] | None,
) -> jnp.ndarray:
  """Applies the model in ``state`` to ``graphs`` and returns the output globals.

  Args:
    state: train state whose ``apply_fn`` is called with ``state.params``.
    graphs: input graphs.
    rngs: optional RNG dict forwarded to ``apply_fn`` (e.g. for dropout).

  Returns:
    The ``globals`` field of the predicted graphs, used as logits.
  """
  return state.apply_fn(state.params, graphs, rngs=rngs).globals
def get_valid_mask(
    labels: jnp.ndarray, graphs: jraph.GraphsTuple
) -> jnp.ndarray:
  """Boolean mask of label entries that should count toward the loss.

  An entry is valid when its label is not NaN and its graph is kept by
  ``jraph.get_graph_padding_mask`` (presumably True for non-padding graphs;
  see the jraph docs).

  Args:
    labels: per-graph labels, presumably ``(num_graphs, num_tasks)``.
    graphs: the graphs the labels belong to.

  Returns:
    Boolean array with the shape of ``labels``.
  """
  not_nan = ~jnp.isnan(labels)
  real_graph = jraph.get_graph_padding_mask(graphs)[:, None]
  return not_nan & real_graph
@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
  """CLU metric collection reported by ``ogbg_train_step``."""

  # Average of ``predictions_match_labels`` over the model outputs.
  accuracy: metrics.Average.from_fun(predictions_match_labels)
  # Average of the ``loss`` model output. NOTE(review): the caller also
  # passes ``mask``; confirm how CLU applies it to this average.
  loss: metrics.Average.from_output('loss')
@jax.jit
def ogbg_train_step(
    state: train_state.TrainState,
    graphs: jraph.GraphsTuple,
    rngs: dict[str, jnp.ndarray],
) -> tuple[train_state.TrainState, metrics.Collection]:
  """One jitted OGBG-MolPCBA training step.

  The labels are read from ``graphs.globals``, which are then replaced by
  ones before the model runs.

  Args:
    state: train state holding the model variables and optimizer state.
    graphs: batch of graphs whose ``globals`` hold the labels.
    rngs: RNG dict forwarded to the model (e.g. ``{'dropout': key}``).

  Returns:
    ``(new_state, TrainMetrics)`` for this batch.
  """

  def loss_fn(params, batch_graphs):
    candidate = state.replace(params=params)
    targets = batch_graphs.globals
    model_input = replace_globals(batch_graphs)
    logits = get_predicted_logits(candidate, model_input, rngs)
    mask = get_valid_mask(targets, model_input)
    per_task = binary_cross_entropy_with_mask(
        logits=logits, labels=targets, mask=mask
    )
    # Mean over the valid entries only.
    mean_loss = jnp.sum(jnp.where(mask, per_task, 0)) / jnp.sum(mask)
    return mean_loss, (per_task, logits, targets, mask)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, aux), grads = grad_fn(state.params, graphs)
  per_task, logits, targets, mask = aux
  new_state = state.apply_gradients(grads=grads)
  batch_metrics = TrainMetrics.single_from_model_output(
      loss=per_task, logits=logits, labels=targets, mask=mask
  )
  return new_state, batch_metrics
def get_fake_graphs(config: ml_collections.ConfigDict) -> jraph.GraphsTuple:
  """Builds a synthetic batch of ``config.batch_size`` graphs.

  Every graph has 20 nodes and 40 edges. Node features are 9-dimensional and
  edge features 3-dimensional standard-normal samples. Edge endpoints are
  drawn uniformly over all nodes of the batch, so edges may cross graph
  boundaries. ``globals`` holds random 0/1 float32 labels with
  ``config.num_classes`` columns.

  Args:
    config: reads ``batch_size`` and ``num_classes``.

  Returns:
    A ``jraph.GraphsTuple`` with the fields described above.
  """
  nodes_per_graph = 20
  edges_per_graph = 40
  n_graphs = config.batch_size
  n_total_nodes = n_graphs * nodes_per_graph
  n_total_edges = n_graphs * edges_per_graph

  # One independent key per random field. Reusing a single key made
  # ``senders`` and ``receivers`` identical (every edge a self-loop) and
  # correlated the remaining draws.
  node_key, edge_key, sender_key, receiver_key, label_key = jax.random.split(
      jax.random.key(0), 5
  )
  nodes = jax.random.normal(node_key, (n_total_nodes, 9))
  edges = jax.random.normal(edge_key, (n_total_edges, 3))
  senders = jax.random.randint(sender_key, (n_total_edges,), 0, n_total_nodes)
  receivers = jax.random.randint(
      receiver_key, (n_total_edges,), 0, n_total_nodes
  )
  n_node = jnp.full((n_graphs,), nodes_per_graph, dtype=jnp.int32)
  n_edge = jnp.full((n_graphs,), edges_per_graph, dtype=jnp.int32)
  globals_ = jax.random.bernoulli(
      label_key, shape=(n_graphs, config.num_classes)
  ).astype(jnp.float32)
  return jraph.GraphsTuple(
      nodes=nodes,
      edges=edges,
      senders=senders,
      receivers=receivers,
      n_node=n_node,
      n_edge=n_edge,
      globals=globals_,
  )
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Prepares the jitted OGBG-MolPCBA train step with its arguments.

  Args:
    config: ogbg_molpcba config (model, optimizer and batch settings).

  Returns:
    ``(ogbg_train_step, (state, graphs, {'dropout': dropout_rng}), {})``.
  """
  _, init_rng, dropout_rng = jax.random.split(jax.random.key(0), 3)
  graphs = get_fake_graphs(config)

  # Initialize with a deterministic=True model. The TrainState's apply_fn
  # comes from a deterministic=False model.
  init_net = create_model(config, deterministic=True)
  # ``init`` returns the full variables dict; it is stored as ``params`` and
  # passed whole to ``apply_fn`` by ``get_predicted_logits``.
  variables = jax.jit(init_net.init)(init_rng, replace_globals(graphs))
  tx = create_optimizer(config)
  train_net = create_model(config, deterministic=False)
  state = train_state.TrainState.create(
      apply_fn=train_net.apply, params=variables, tx=tx
  )
  return ogbg_train_step, (state, graphs, {'dropout': dropout_rng}), {}
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_ogbg_molpcba_trace(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_tracing``."""
  config_fn = ogbg_config.get_config
  tracing_benchmark.benchmark_tracing(get_apply_fn_and_args, config_fn, state)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_ogbg_molpcba_lower(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_lowering``."""
  config_fn = ogbg_config.get_config
  tracing_benchmark.benchmark_lowering(get_apply_fn_and_args, config_fn, state)


if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/ppo.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PPO helper functions for RL benchmarking."""
from typing import Any
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.ppo import models
from flax.examples.ppo import ppo_lib
from flax.examples.ppo.configs import default as ppo_config
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
def get_fake_batch(batch_size: int = 256) -> tuple[jnp.ndarray, ...]:
  """Generate a batch of fake Atari observations and trajectories.

  Args:
    batch_size: Size of the minibatch.

  Returns:
    A tuple of (states, actions, old_log_probs, returns, advantages):
    int32 observations of shape ``(batch_size, 84, 84, 4)`` with values in
    ``[0, 256)``, int32 actions in ``[0, 6)``, and three float32
    standard-normal vectors of shape ``(batch_size,)``.
  """
  keys = [jax.random.key(seed) for seed in range(5)]
  # Observations: (batch, height, width, stacked frames), pixel values.
  states = jax.random.randint(
      keys[0], (batch_size, 84, 84, 4), minval=0, maxval=256, dtype=jnp.int32
  )
  # Discrete actions; 6 matches the Pong action space used by the model.
  actions = jax.random.randint(
      keys[1], (batch_size,), minval=0, maxval=6, dtype=jnp.int32
  )
  # Behavior-policy log-probs, discounted returns and GAE advantages.
  old_log_probs, returns, advantages = (
      jax.random.normal(k, (batch_size,), dtype=jnp.float32) for k in keys[2:]
  )
  return (states, actions, old_log_probs, returns, advantages)
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Returns the apply function and args for the given config.

  Args:
    config: The training configuration (reads ``batch_size``,
      ``clip_param``, ``vf_coeff``, ``entropy_coeff`` and whatever
      ``ppo_lib.create_train_state`` reads).

  Returns:
    A tuple of the apply function, args, and kwargs:
    ``(ppo_lib.train_step, (state, trajectories, batch_size), hyperparams)``.
  """
  # Six discrete actions (Pong).
  model = models.ActorCritic(num_outputs=6)
  dummy_obs = jnp.ones((1, 84, 84, 4), jnp.float32)
  initial_params = model.init(jax.random.key(0), dummy_obs)['params']
  # A placeholder total step count; only one train step is benchmarked.
  state = ppo_lib.create_train_state(initial_params, model, config, 1000)
  trajectories = get_fake_batch(config.batch_size)
  hyperparams = dict(
      clip_param=config.clip_param,
      vf_coeff=config.vf_coeff,
      entropy_coeff=config.entropy_coeff,
  )
  # ppo_lib.train_step is already JIT-compiled with static batch_size.
  return (
      ppo_lib.train_step,
      (state, trajectories, config.batch_size),
      hyperparams,
  )
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_ppo_trace(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_tracing``."""
  config_fn = ppo_config.get_config
  tracing_benchmark.benchmark_tracing(get_apply_fn_and_args, config_fn, state)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_ppo_lower(state):
  """Benchmark entry delegating to ``tracing_benchmark.benchmark_lowering``."""
  config_fn = ppo_config.get_config
  tracing_benchmark.benchmark_lowering(get_apply_fn_and_args, config_fn, state)


if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/requirements.txt
================================================
absl-py
clu
flax
google-benchmark
jax
jraph
ml_collections
numpy
optax
================================================
FILE: benchmarks/tracing/run_all_benchmarks.sh
================================================
#!/bin/bash
# Runs every Flax tracing benchmark target in turn; stops at the first failure.
set -e

# Ask XLA to expose 8 host (CPU) devices.
export XLA_FLAGS=--xla_force_host_platform_device_count=8

TARGETS=(mnist vae sst2 gemma imagenet seq2seq lm1b nlp_seq ogbg_molpcba wmt ppo)

banner_line="============================================"

for target in "${TARGETS[@]}"; do
  echo "${banner_line}"
  echo "Running benchmark: ${target}"
  echo "${banner_line}"
  benchy "third_party/py/flax/benchmarks/tracing:${target}"
  echo ""
done
================================================
FILE: benchmarks/tracing/seq2seq.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Seq2Seq helper functions."""
from typing import Any
from flax import linen as nn
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.seq2seq import models
from flax.examples.seq2seq.configs import default as seq2seq_config
from flax.examples.seq2seq.input_pipeline import CharacterTable
from flax.examples.seq2seq.input_pipeline import get_sequence_lengths
from flax.examples.seq2seq.input_pipeline import mask_sequences
from flax.training import train_state
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import optax
def cross_entropy_loss(logits, labels, lengths):
  """Negative mean of the length-masked per-token log-likelihood.

  Args:
    logits: unnormalized scores; the last axis indexes the vocabulary.
    labels: one-hot targets, presumably shaped like ``logits``.
    lengths: per-sequence lengths passed to ``mask_sequences``.

  Returns:
    Scalar loss.
  """
  token_log_likelihood = jnp.sum(nn.log_softmax(logits) * labels, axis=-1)
  masked = mask_sequences(token_log_likelihood, lengths)
  return -jnp.mean(masked)
def compute_metrics(logits, labels, eos_id):
  """Loss and whole-sequence accuracy for a batch of decoded sequences.

  Args:
    logits: model outputs; the last axis indexes the vocabulary.
    labels: one-hot targets.
    eos_id: end-of-sequence id used to derive the sequence lengths.

  Returns:
    Dict with ``'loss'`` and ``'accuracy'``. A sequence counts as correct
    only when every token within its length matches.
  """
  lengths = get_sequence_lengths(labels, eos_id)
  loss = cross_entropy_loss(logits, labels, lengths)
  token_hits = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
  hits_per_sequence = jnp.sum(mask_sequences(token_hits, lengths), axis=-1)
  fully_correct = hits_per_sequence == lengths
  return {'loss': loss, 'accuracy': jnp.mean(fully_correct)}
@jax.jit
def seq2seq_train_step(state, batch, lstm_rng, eos_id):
  """One jitted seq2seq training step.

  Args:
    state: train state holding params and optimizer state.
    batch: dict with ``'query'`` and ``'answer'`` arrays.
    lstm_rng: base key; folded with ``state.step`` to get the ``'lstm'`` RNG.
    eos_id: end-of-sequence id (traced, not static).

  Returns:
    ``(new_state, metrics)``.
  """
  # Targets are the answers without their first time step (presumably the
  # decoder start token).
  labels = batch['answer'][:, 1:]
  step_key = jax.random.fold_in(lstm_rng, state.step)

  def loss_fn(params):
    logits, _ = state.apply_fn(
        {'params': params},
        batch['query'],
        batch['answer'],
        rngs={'lstm': step_key},
    )
    lengths = get_sequence_lengths(labels, eos_id)
    return cross_entropy_loss(logits, labels, lengths), logits

  (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(grads=grads)
  return new_state, compute_metrics(logits, labels, eos_id)
def get_fake_batch(batch_size: int, ctable: CharacterTable) -> dict[str, Any]:
  """Returns ``batch_size`` examples generated by ``ctable.get_batch``."""
  fake_batch = ctable.get_batch(batch_size)
  return fake_batch
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Builds the seq2seq train step and the arguments to trace it with.

  Args:
    config: seq2seq example config; this reads ``max_len_query_digit``,
      ``hidden_size``, ``learning_rate`` and ``batch_size``.

  Returns:
    ``(seq2seq_train_step, (state, batch, lstm_rng, eos_id), {})``.
  """
  # Derive three independent keys in one split: params init, the 'lstm'
  # stream during init, and the base 'lstm' key for the train step. This
  # avoids reusing the root key after it has already been split.
  params_rng, init_lstm_rng, lstm_rng = jax.random.split(jax.random.key(0), 3)
  ctable = CharacterTable("0123456789+= ", config.max_len_query_digit)
  model = models.Seq2seq(
      teacher_force=False,
      hidden_size=config.hidden_size,
      eos_id=ctable.eos_id,
      vocab_size=ctable.vocab_size,
  )
  params = model.init(
      {'params': params_rng, 'lstm': init_lstm_rng},
      jnp.ones(ctable.encoder_input_shape, jnp.float32),
      jnp.ones(ctable.decoder_input_shape, jnp.float32),
  )['params']
  state = train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=optax.adam(config.learning_rate)
  )
  batch = get_fake_batch(config.batch_size, ctable)
  return seq2seq_train_step, (state, batch, lstm_rng, ctable.eos_id), {}
# google_benchmark entry points for the seq2seq example. Each function is
# registered as a benchmark reported in milliseconds; ``state`` is the
# benchmark state object whose loop tracing_benchmark drives.
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_seq2seq_trace(state):
  # Measures tracing of seq2seq_train_step (only when --mode includes trace).
  tracing_benchmark.benchmark_tracing(
      get_apply_fn_and_args, seq2seq_config.get_config, state
  )
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_seq2seq_lower(state):
  # Measures lowering of a pre-traced seq2seq_train_step (only when --mode
  # includes lower).
  tracing_benchmark.benchmark_lowering(
      get_apply_fn_and_args, seq2seq_config.get_config, state
  )
# Parses this module's flags and runs every registered benchmark.
if __name__ == "__main__":
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/sst2.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SST2 helper functions."""
from typing import Any
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.sst2 import models
from flax.examples.sst2.configs import default as sst2_config
from flax.training import train_state as train_state_lib
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import optax
# Short aliases used in this module's type annotations.
Array = jnp.ndarray
TrainState = train_state_lib.TrainState


@jax.vmap
def sigmoid_cross_entropy_with_logits(*, labels: Array, logits: Array) -> Array:
  """Numerically stable sigmoid cross-entropy, vectorized with ``jax.vmap``.

  Evaluates ``max(x, 0) - x * z + log1p(exp(-|x|))`` for logits ``x`` and
  labels ``z``. This avoids overflow in ``exp`` for large ``|x|``.
  """
  zero = jnp.zeros_like(logits, dtype=logits.dtype)
  is_non_negative = logits >= zero
  positive_part = jnp.where(is_non_negative, logits, zero)
  minus_abs = jnp.where(is_non_negative, -logits, logits)
  return positive_part - logits * labels + jnp.log1p(jnp.exp(minus_abs))
def model_from_config(config: ml_collections.ConfigDict):
  """Returns a ``models.TextClassifier`` configured from ``config`` fields."""
  return models.TextClassifier(
      embedding_size=config.embedding_size,
      hidden_size=config.hidden_size,
      vocab_size=config.vocab_size,
      output_size=config.output_size,
      dropout_rate=config.dropout_rate,
      word_dropout_rate=config.word_dropout_rate,
      unk_idx=config.unk_idx,
  )
def get_initial_params(rng, model):
  """Initializes ``model`` on a dummy ``(2, 3)`` int32 batch; returns params."""
  dummy_tokens = jnp.ones((2, 3), jnp.int32)
  dummy_lengths = jnp.ones((2,), dtype=jnp.int32)
  initialized = model.init(rng, dummy_tokens, dummy_lengths, deterministic=True)
  return initialized['params']
def create_train_state(rng, config: ml_collections.ConfigDict, model):
  """Creates a ``TrainState`` using SGD with momentum plus weight decay.

  Args:
    rng: PRNG key used to initialize ``model``.
    config: reads ``learning_rate``, ``momentum`` and ``weight_decay``.
    model: Flax module to initialize.
  """
  initial_params = get_initial_params(rng, model)
  optimizer = optax.chain(
      optax.sgd(learning_rate=config.learning_rate, momentum=config.momentum),
      optax.add_decayed_weights(weight_decay=config.weight_decay),
  )
  return TrainState.create(
      apply_fn=model.apply, params=initial_params, tx=optimizer
  )
def compute_metrics(*, labels: Array, logits: Array):
  """Returns summed loss, summed accuracy and example count for a batch.

  A prediction is positive when its logit is ``>= 0``. It counts as accurate
  when that boolean equals the label. 1-D labels are first expanded to
  ``(N, 1)``.
  """
  if labels.ndim == 1:
    labels = jnp.expand_dims(labels, axis=1)
  per_example_loss = sigmoid_cross_entropy_with_logits(
      labels=labels, logits=logits
  )
  hits = jnp.equal(logits >= 0.0, labels)
  return {
      'loss': jnp.sum(per_example_loss),
      'accuracy': jnp.sum(hits),
      'count': logits.shape[0],
  }
def train_step(
    state: TrainState,
    batch: dict[str, Array],
    rngs: dict[str, Any],
) -> tuple[TrainState, Any]:
  """Performs one optimization step with step-dependent RNG streams.

  Args:
    state: current train state.
    batch: dict with ``'token_ids'``, ``'length'`` and ``'label'``.
    rngs: named PRNG keys; each is folded with ``state.step``.

  Returns:
    ``(new_state, metrics)``.
  """
  step_rngs = {}
  for stream, key in rngs.items():
    step_rngs[stream] = jax.random.fold_in(key, state.step)

  targets = batch['label']
  if targets.ndim == 1:
    targets = jnp.expand_dims(targets, 1)

  def loss_fn(params):
    logits = state.apply_fn(
        {'params': params},
        batch['token_ids'],
        batch['length'],
        deterministic=False,
        rngs=step_rngs,
    )
    per_example = sigmoid_cross_entropy_with_logits(
        labels=targets, logits=logits
    )
    return jnp.mean(per_example), logits

  (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  updated = state.apply_gradients(grads=grads)
  return updated, compute_metrics(labels=batch['label'], logits=logits)
def get_fake_batch(batch_size: int) -> dict[str, Any]:
  """Returns a synthetic SST-2 batch of fixed-length sequences.

  Args:
    batch_size: number of examples.

  Returns:
    Dict with:
      ``'token_ids'``: ``(batch_size, 60)`` int32 ids in ``[0, 1000)``.
      ``'length'``: ``(batch_size,)`` int32, all equal to 60.
      ``'label'``: ``(batch_size,)`` float32, each 0.0 or 1.0.
  """
  max_length = 60
  # Independent keys so the token ids and labels are not drawn from the same
  # random bits.
  token_key, label_key = jax.random.split(jax.random.key(0))
  token_ids = jax.random.randint(
      token_key, (batch_size, max_length), 0, 1000, jnp.int32
  )
  lengths = jnp.full((batch_size,), max_length, jnp.int32)
  # Binary labels keep the equality-based accuracy in compute_metrics
  # meaningful. Shape and dtype (float32) are unchanged.
  labels = jax.random.bernoulli(label_key, 0.5, (batch_size,)).astype(
      jnp.float32
  )
  return {
      'token_ids': token_ids,
      'length': lengths,
      'label': labels,
  }
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Builds the jitted SST-2 train step and the arguments to trace it with.

  Args:
    config: SST-2 config. References are resolved on a copy, and
      ``vocab_size`` falls back to 1000 when it is ``None``.

  Returns:
    ``(jit(train_step), (state, batch, {'dropout': key}), {})``.
  """
  root_key = jax.random.key(0)
  cfg = config.copy_and_resolve_references()
  if cfg.vocab_size is None:
    cfg.vocab_size = 1000
  model = model_from_config(cfg)
  state = create_train_state(root_key, cfg, model)
  batch = get_fake_batch(cfg.batch_size)
  dropout_key = jax.random.split(root_key)[1]
  return jax.jit(train_step), (state, batch, {'dropout': dropout_key}), {}
# google_benchmark entry points for the SST-2 example, reported in
# milliseconds; ``state`` is the benchmark state object whose loop
# tracing_benchmark drives.
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_sst2_trace(state):
  # Measures tracing of the jitted train_step (only when --mode includes trace).
  tracing_benchmark.benchmark_tracing(
      get_apply_fn_and_args, sst2_config.get_config, state
  )
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_sst2_lower(state):
  # Measures lowering of a pre-traced train_step (only when --mode includes
  # lower).
  tracing_benchmark.benchmark_lowering(
      get_apply_fn_and_args, sst2_config.get_config, state
  )
# Parses this module's flags and runs every registered benchmark.
if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/tracing_benchmark.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared benchmark utilities for Jax tracing flax examples."""
from collections.abc import Callable
import sys
from typing import Any
from absl import app
from absl import flags
from absl import logging
import google_benchmark
import jax
# --mode selects which phase(s) the registered benchmarks actually measure.
flags.DEFINE_enum(
    name='mode',
    default='trace_and_lower',
    enum_values=['trace', 'lower', 'trace_and_lower'],
    help='Measure trace, lower, or trace_and_lower.',
)
def clear_caches(state):
  """Calls ``jax.clear_caches()`` while the benchmark timer is paused.

  Args:
    state: benchmark state exposing ``pause_timing`` / ``resume_timing``.
  """
  pause, resume = state.pause_timing, state.resume_timing
  pause()
  # Presumably ensures the next timed iteration cannot reuse cached
  # tracing/lowering results.
  jax.clear_caches()
  resume()
def benchmark_tracing(
    get_apply_fn_and_args: Callable[..., Any],
    get_config: Callable[[], Any],
    state: Any,
) -> None:
  """Times ``apply_fn.trace`` for one Flax example.

  The function and its arguments are built once, before the benchmark loop.
  Each iteration then traces the function, but only when ``--mode`` is
  ``trace`` or ``trace_and_lower``. It also clears JAX caches with the timer
  paused.
  """
  fn, fn_args, fn_kwargs = get_apply_fn_and_args(get_config())
  while state:
    if flags.FLAGS.mode in ('trace', 'trace_and_lower'):
      fn.trace(*fn_args, **fn_kwargs)
    clear_caches(state)
def benchmark_lowering(
    get_apply_fn_and_args: Callable[..., Any],
    get_config: Callable[[], Any],
    state: Any,
    platform: str = 'tpu',
) -> None:
  """Times lowering of an already-traced Flax example.

  Tracing happens once, before the benchmark loop. Each iteration then lowers
  the traced function for ``platform``, but only when ``--mode`` is ``lower``
  or ``trace_and_lower``. It also clears JAX caches with the timer paused.
  """
  fn, fn_args, fn_kwargs = get_apply_fn_and_args(get_config())
  traced = fn.trace(*fn_args, **fn_kwargs)
  target_platforms = (platform,)
  while state:
    if flags.FLAGS.mode in ('lower', 'trace_and_lower'):
      traced.lower(lowering_platforms=target_platforms)
    clear_caches(state)
def run_single_example(
    get_apply_fn_and_args: Callable[..., Any],
    get_config: Callable[[], Any],
    platform: str = 'tpu',
) -> None:
  """Traces and lowers one example once and logs the HLO, for profiling.

  Runs under ``absl.app.run``, so command-line flags are parsed first.

  Args:
    get_apply_fn_and_args: returns ``(apply_fn, args, kwargs, ...)`` for a
      config; extra trailing items are ignored.
    get_config: returns the example config.
    platform: lowering target passed as ``lowering_platforms=(platform,)``.
      Defaults to ``'tpu'``, the same default as ``benchmark_lowering``.

  Raises:
    ValueError: if ``--mode=lower``, since tracing is always performed here.
  """
  def main(argv):
    del argv
    if flags.FLAGS.mode == 'lower':
      raise ValueError(
          '`--mode=lower` is not supported when profiling a single example.'
      )
    config = get_config()
    apply_fn, args, kwargs, *_ = get_apply_fn_and_args(config)
    traced = apply_fn.trace(*args, **kwargs)
    lowered = traced.lower(lowering_platforms=(platform,))
    logging.info('lowered: %s', lowered.as_text('hlo'))

  app.run(main)
def run_benchmarks() -> None:
  """Parses the flags this module knows about, then runs google_benchmark."""
  argv = sys.argv
  # known_only=True parses the known flags and leaves the rest untouched,
  # presumably so google_benchmark's own --benchmark_* flags pass through.
  flags.FLAGS(argv, known_only=True)
  flags.FLAGS.mark_as_parsed()
  google_benchmark.main()
================================================
FILE: benchmarks/tracing/vae.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VAE helper functions."""
from typing import Any
from flax import linen as nn
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.vae import models
from flax.examples.vae.configs import default as vae_config
from flax.training import train_state
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import optax
@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
  """Per-example binary cross-entropy from logits, vectorized with vmap.

  ``log(1 - sigmoid(x))`` is computed as ``log(-expm1(log_sigmoid(x)))``.
  """
  log_p = nn.log_sigmoid(logits)
  log_not_p = jnp.log(-jnp.expm1(log_p))
  return -jnp.sum(labels * log_p + (1.0 - labels) * log_not_p)
@jax.vmap
def kl_divergence(mean, logvar):
  """Closed-form KL(N(mean, exp(logvar)) || N(0, I)) per example (vmapped)."""
  per_dim = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
  return -0.5 * jnp.sum(per_dim)
def train_step(state, batch, z_rng, latents):
  """Applies one gradient update to the VAE.

  The loss is the batch-mean reconstruction BCE plus the batch-mean KL term.

  Args:
    state: train state with ``params`` and ``apply_gradients``.
    batch: input batch fed to the model.
    z_rng: PRNG key passed to the model's ``apply``.
    latents: latent size used to build ``models.model(latents)``.

  Returns:
    The updated train state.
  """
  vae = models.model(latents)

  def loss_fn(params):
    recon_x, mean, logvar = vae.apply({'params': params}, batch, z_rng)
    reconstruction = binary_cross_entropy_with_logits(recon_x, batch).mean()
    kl = kl_divergence(mean, logvar).mean()
    return reconstruction + kl

  return state.apply_gradients(grads=jax.grad(loss_fn)(state.params))
def get_fake_batch(batch_size: int) -> Any:
  """Returns an all-ones float32 batch of shape ``(batch_size, 784)``."""
  # 784 features per example (presumably flattened 28x28 images).
  shape = (batch_size, 784)
  return jnp.ones(shape, dtype=jnp.float32)
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Builds the jitted VAE train step and the arguments to trace it with.

  Args:
    config: reads ``batch_size``, ``latents`` and ``learning_rate``.

  Returns:
    ``(jitted train_step, (state, batch, z_rng, config.latents), {})``, where
    ``latents`` is a static argument.
  """
  z_rng, init_key = jax.random.split(jax.random.key(0))
  batch = get_fake_batch(config.batch_size)
  vae = models.model(config.latents)
  params = vae.init(init_key, batch, z_rng)['params']
  state = train_state.TrainState.create(
      apply_fn=vae.apply,
      params=params,
      tx=optax.adam(config.learning_rate),
  )
  step_fn = jax.jit(train_step, static_argnames=('latents',))
  return step_fn, (state, batch, z_rng, config.latents), {}
# google_benchmark entry points for the VAE example, reported in
# milliseconds; ``state`` is the benchmark state object whose loop
# tracing_benchmark drives.
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_vae_trace(state):
  # Measures tracing of the jitted train_step (only when --mode includes trace).
  tracing_benchmark.benchmark_tracing(
      get_apply_fn_and_args, vae_config.get_config, state
  )
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_vae_lower(state):
  # Measures lowering of a pre-traced train_step (only when --mode includes
  # lower).
  tracing_benchmark.benchmark_lowering(
      get_apply_fn_and_args, vae_config.get_config, state
  )
# Parses this module's flags and runs every registered benchmark.
if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: benchmarks/tracing/wmt.py
================================================
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""WMT helper functions for benchmarking."""
import functools
from typing import Any
from flax import linen as nn
from flax.benchmarks.tracing import tracing_benchmark
from flax.examples.wmt import models
from flax.examples.wmt.configs import default as wmt_config
from flax.training import common_utils
from flax.training import dynamic_scale as dynamic_scale_lib
from flax.training import train_state
import google_benchmark
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
class TrainState(train_state.TrainState):
  """Flax ``TrainState`` extended with a ``dynamic_scale`` field."""
  # Despite the annotation, this may be ``None``. get_apply_fn_and_args only
  # sets a DynamicScale for mixed precision on GPU, and wmt_train_step takes
  # the plain value_and_grad path when it is falsy.
  dynamic_scale: dynamic_scale_lib.DynamicScale
def rsqrt_schedule(init_value: float, shift: int = 0):
  """Returns ``count -> init_value * sqrt(shift) / sqrt(count + shift)``.

  For ``shift > 0`` the schedule starts at ``init_value`` when ``count == 0``
  and then decays as ``1 / sqrt(count + shift)``.
  """
  scale = shift**0.5

  def schedule(count):
    return init_value * (count + shift) ** -0.5 * scale

  return schedule
def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
  """Linear warmup from 0 to ``learning_rate``, then reciprocal-sqrt decay.

  The two pieces are joined at ``warmup_steps`` with ``optax.join_schedules``.
  """
  warmup = optax.linear_schedule(
      init_value=0,
      end_value=learning_rate,
      transition_steps=warmup_steps,
  )
  decay = rsqrt_schedule(init_value=learning_rate, shift=warmup_steps)
  return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])
def preferred_dtype(config):
  """Chooses the compute dtype from the first local device's platform.

  With ``config.use_mixed_precision`` set, this returns bfloat16 on TPU and
  float16 on GPU. Every other case returns float32.
  """
  platform = jax.local_devices()[0].platform
  if not config.use_mixed_precision:
    return jnp.float32
  return {'tpu': jnp.bfloat16, 'gpu': jnp.float16}.get(platform, jnp.float32)
def compute_weighted_cross_entropy(
    logits, targets, weights=None, label_smoothing=0.0
):
  """Summed, optionally weighted, label-smoothed cross-entropy.

  The entropy of the smoothed target distribution is subtracted per token.
  The resulting loss is therefore the KL divergence, and it is zero when
  predictions match the soft targets.

  Args:
    logits: ``[..., vocab]`` scores.
    targets: integer ids with one fewer dimension than ``logits``.
    weights: optional per-token weights broadcastable to ``targets``.
    label_smoothing: probability mass spread over the non-target classes.

  Returns:
    ``(loss_sum, normalizing_factor)``. The factor is ``weights.sum()`` when
    weights are given, otherwise the number of target elements.

  Raises:
    ValueError: if ``logits.ndim != targets.ndim + 1``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets'
        % (str(logits.shape), str(targets.shape))
    )
  num_classes = logits.shape[-1]
  on_value = 1.0 - label_smoothing
  off_value = (1.0 - on_value) / (num_classes - 1)
  target_entropy = -(
      on_value * jnp.log(on_value)
      + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
  )
  soft_targets = common_utils.onehot(
      targets, num_classes, on_value=on_value, off_value=off_value
  )
  per_token = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  per_token = per_token - target_entropy
  if weights is None:
    return per_token.sum(), np.prod(targets.shape)
  weighted = per_token * weights
  return weighted.sum(), weights.sum()
def compute_weighted_accuracy(logits, targets, weights=None):
  """Counts, optionally weighted, positions where ``argmax(logits) == targets``.

  Returns:
    ``(hit_sum, normalizing_factor)``. The factor is ``weights.sum()`` when
    weights are given, otherwise the number of positions.

  Raises:
    ValueError: if ``logits.ndim != targets.ndim + 1``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets'
        % (str(logits.shape), str(targets.shape))
    )
  hits = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  if weights is None:
    return hits.sum(), np.prod(logits.shape[:-1])
  return (hits * weights).sum(), weights.sum()
def compute_metrics(logits, labels, weights, label_smoothing=0.0):
  """Returns summed loss, summed accuracy and the loss's weight denominator."""
  loss_sum, denominator = compute_weighted_cross_entropy(
      logits, labels, weights, label_smoothing
  )
  accuracy_sum = compute_weighted_accuracy(logits, labels, weights)[0]
  return {
      'loss': loss_sum,
      'accuracy': accuracy_sum,
      'denominator': denominator,
  }
def wmt_train_step(
    state,
    batch,
    config,
    learning_rate_fn,
    label_smoothing=0.0,
    dropout_rng=None,
):
  """Performs one optimization step of the WMT Transformer.

  Args:
    state: this module's ``TrainState``; ``dynamic_scale`` may be ``None``.
    batch: dict keyed by the names in ``train_keys``; missing keys read as
      ``None``.
    config: ``models.TransformerConfig`` used to build the model.
    learning_rate_fn: schedule evaluated at the pre-update step for metrics.
    label_smoothing: label smoothing for the training loss. It is also used
      for the reported ``'loss'`` metric so that metric matches the objective.
    dropout_rng: PRNG key, folded with ``state.step`` each step. Required.

  Returns:
    ``(new_state, metrics)``.

  Raises:
    ValueError: if ``dropout_rng`` is ``None``.
  """
  if dropout_rng is None:
    raise ValueError('wmt_train_step requires a dropout_rng.')
  train_keys = [
      'inputs',
      'targets',
      'inputs_position',
      'targets_position',
      'inputs_segmentation',
      'targets_segmentation',
  ]
  (
      inputs,
      targets,
      inputs_positions,
      targets_positions,
      inputs_segmentation,
      targets_segmentation,
  ) = (batch.get(k, None) for k in train_keys)
  # Positions whose target id is 0 (presumably padding) get zero weight.
  weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32)
  dropout_rng = jax.random.fold_in(dropout_rng, state.step)

  def loss_fn(params):
    logits = models.Transformer(config).apply(
        {'params': params},
        inputs,
        targets,
        inputs_positions=inputs_positions,
        targets_positions=targets_positions,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation,
        rngs={'dropout': dropout_rng},
    )
    loss, weight_sum = compute_weighted_cross_entropy(
        logits, targets, weights, label_smoothing
    )
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = state.step
  if state.dynamic_scale:
    grad_fn = state.dynamic_scale.value_and_grad(loss_fn, has_aux=True)
    dynamic_scale, is_fin, (_, logits), grads = grad_fn(state.params)
    state = state.replace(dynamic_scale=dynamic_scale)
  else:
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)

  new_state = state.apply_gradients(grads=grads)
  # Forward label_smoothing so the reported loss is the optimized loss.
  metrics = compute_metrics(logits, targets, weights, label_smoothing)
  metrics['learning_rate'] = learning_rate_fn(step)

  if state.dynamic_scale:
    # When gradients are not all finite, keep the previous params and
    # optimizer state (jnp.where selects old values where is_fin is False).
    select_fn = functools.partial(jnp.where, is_fin)
    new_state = new_state.replace(
        opt_state=jax.tree_util.tree_map(
            select_fn, new_state.opt_state, state.opt_state
        ),
        params=jax.tree_util.tree_map(
            select_fn, new_state.params, state.params
        ),
    )
    metrics['loss_scale'] = dynamic_scale.scale * metrics['denominator']

  return new_state, metrics
def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, Any]:
  """Returns a synthetic WMT batch of random token ids.

  Args:
    config: reads ``per_device_batch_size``, ``max_target_length`` and
      ``vocab_size``.

  Returns:
    Dict with int32 ``'inputs'`` and ``'targets'`` of shape
    ``(per_device_batch_size, max_target_length)`` drawn from
    ``[0, vocab_size)``. The position and segmentation entries are ``None``.
  """
  batch_size = config.per_device_batch_size
  max_len = config.max_target_length
  shape = (batch_size, max_len)
  # Separate keys. Drawing both arrays from one key with identical shape and
  # bounds would make targets an exact copy of inputs.
  inputs_key, targets_key = jax.random.split(jax.random.key(0))
  inputs = jax.random.randint(
      inputs_key, shape, 0, config.vocab_size, jnp.int32
  )
  targets = jax.random.randint(
      targets_key, shape, 0, config.vocab_size, jnp.int32
  )
  return {
      'inputs': inputs,
      'targets': targets,
      'inputs_position': None,
      'targets_position': None,
      'inputs_segmentation': None,
      'targets_segmentation': None,
  }
def get_apply_fn_and_args(
    config: ml_collections.ConfigDict,
) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
  """Builds the jitted WMT train step and the arguments to trace it with.

  Args:
    config: WMT example config.

  Returns:
    ``(jit_train_step, (state, batch, train_config, learning_rate_fn, 0.0,
    dropout_rng), {})``. Positional arguments 2-4 (config, schedule, label
    smoothing) are static.
  """
  compute_dtype = preferred_dtype(config)
  train_config = models.TransformerConfig(
      vocab_size=config.vocab_size,
      output_vocab_size=config.vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=compute_dtype,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=jax.nn.initializers.xavier_uniform(),
      bias_init=jax.nn.initializers.normal(stddev=1e-6),
  )
  transformer = models.Transformer(train_config)
  lr_schedule = create_learning_rate_schedule(
      learning_rate=config.learning_rate, warmup_steps=config.warmup_steps
  )
  tx = optax.adamw(
      lr_schedule,
      b1=0.9,
      b2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay,
  )

  init_key, dropout_key = jax.random.split(jax.random.key(0))
  batch = get_fake_batch(config)
  variables = transformer.init(init_key, batch['inputs'], batch['targets'])

  # Dynamic loss scaling is only enabled for mixed precision on GPU.
  platform = jax.local_devices()[0].platform
  scaler = (
      dynamic_scale_lib.DynamicScale()
      if config.use_mixed_precision and platform == 'gpu'
      else None
  )
  state = TrainState.create(
      apply_fn=transformer.apply,
      params=variables['params'],
      tx=tx,
      dynamic_scale=scaler,
  )
  step_fn = jax.jit(wmt_train_step, static_argnums=(2, 3, 4))
  step_args = (state, batch, train_config, lr_schedule, 0.0, dropout_key)
  return step_fn, step_args, {}
# google_benchmark entry points for the WMT example, reported in
# milliseconds; ``state`` is the benchmark state object whose loop
# tracing_benchmark drives.
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_wmt_trace(state):
  # Measures tracing of the jitted wmt_train_step (only when --mode includes
  # trace).
  tracing_benchmark.benchmark_tracing(
      get_apply_fn_and_args, wmt_config.get_config, state
  )
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_flax_wmt_lower(state):
  # Measures lowering of a pre-traced wmt_train_step (only when --mode
  # includes lower).
  tracing_benchmark.benchmark_lowering(
      get_apply_fn_and_args, wmt_config.get_config, state
  )
# Parses this module's flags and runs every registered benchmark.
if __name__ == '__main__':
  tracing_benchmark.run_benchmarks()
================================================
FILE: contributing.md
================================================
# How to Contribute
Please see https://flax.readthedocs.io/en/latest/contributing.html for more information.
================================================
FILE: docs/.gitignore
================================================
_formatted_howtos
================================================
FILE: docs/.readthedocs.yaml
================================================
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
formats:
- htmlzip
- epub
# - pdf
# Optionally declare the Python requirements needed to build your docs (the Python version itself is set under build.tools above)
python:
install:
- method: pip
path: .
extra_requirements:
- all
- testing
- docs
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
# (They use ?=, so a value already set in the environment is kept.)
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
# Makefile is listed as phony (and as the catch-all's prerequisite),
# presumably so make never tries to rebuild the Makefile via the % rule.
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
# For example, `make html` and `make doctest` are both handled here.
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/README.md
================================================
# Deprecation
This folder contains the deprecated Flax Linen documentation. For the latest Flax NNX docs, check out the `docs_nnx` folder.
# Where to find the docs
The Flax Linen documentation can be found here: https://flax-linen.readthedocs.io/en/latest/
# How to build the docs
1. Clone the `flax` repository with `git clone https://github.com/google/flax.git`.
1. In the main `flax` folder, install the required dependencies using `uv pip install -e .[docs]`.
1. [Optional] If you need to make any local changes to the docs, create and switch to a branch. Make your changes to the docs in that branch.
1. To build the docs, in the `flax/docs` folder run the make script: `make html`. Alternatively, install [`entr`](https://github.com/eradman/entr/), which helps run arbitrary commands when files change. Then run `find ../ ! -regex '.*/[\.|\_].*' | entr -s 'make html'`.
1. If the build is successful, you should get the `The HTML pages are in _build/html.` message. You can preview the docs in `flax/docs/_build/html`.
# How to run embedded code tests
We use `doctest` blocks for embedded code in documents, that are also
tested. Learn more at https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html
To run tests locally, run `make doctest`
# How to write code documentation
Our documentation is written in reStructuredText for Sphinx. It is a
meta-language that is compiled into online documentation. For more details,
check out
[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html).
As a result, our docstrings adhere to a specific syntax that has to be kept in
mind. Below we provide some guidelines.
To learn how to contribute to Jupyter Notebooks or other formats in Flax docs,
refer to the dedicated
[Contributing](https://flax.readthedocs.io/en/latest/contributing.html) page.
## How much information to put in a docstring
Docstrings should be informative. We prefer to err on the side of too much
documentation rather than too little. For instance, providing a one-line explanation
to a new `Module` which implements new functionality is not sufficient.
Furthermore, we highly encourage adding examples to your docstrings, so users
can directly see how code can be used.
## How to write inline tested code
We use [doctest](https://docs.python.org/3/library/doctest.html) syntax for
writing examples in documentation. These examples are run as tests as part of
our CI process. In order to write `doctest` code in your documentation, please
use the following notation:
```bash
# Example code::
#
# def sum(a, b):
# return a + b
#
# sum(0, 1)
```
The `Example code` string at the beginning can be replaced by anything as long
as it is followed by two colons (`::`) and a newline, and the code is
indented.
## How to use "code font"
When writing code font in a docstring, please use double backticks. Example:
```bash
# This returns a ``str`` object.
```
Note that argument names and objects like True, None or any strings should
usually be put in `code`.
## How to create cross-references/links
It is possible to create cross-references to other classes, functions, and
methods. In the following, `obj_type` is either `class`, `func`, or `meth`.
```bash
# First method:
# <obj_type>:`path_to_obj`
# Second method:
# :<obj_type>:`description <path_to_obj>`
```
You can use the second method if the `path_to_obj` is very long. Some examples:
```bash
# Create a reference to class flax.linen.Module.
# :class:`flax.linen.Module`
# Create a reference to local function my_func.
# :func:`my_func`
# Create a reference "Module.apply()" to method flax.linen.Module.apply.
# :meth:`Module.apply() <flax.linen.Module.apply>` #
```
To create a hyperlink, use the following syntax:
```bash
# Note the double underscore at the end:
# `Link to Google <http://www.google.com>`__
```
### How to specify arguments for classes and methods
* Class attributes should be specified using the `Attributes:` tag.
* Method arguments should be specified using the `Args:` tag.
* All attributes and arguments should have types.
Here is an example from our library:
```python
class DenseGeneral(Module):
"""A linear transformation with flexible axes.
Attributes:
features: int or tuple with number of output features.
axis: int or tuple with axes to apply the transformation on. For instance,
(-2, -1) will apply the transformation to the last two axes.
batch_dims: tuple with batch axes.
use_bias: whether to add a bias to the output (default: True).
dtype: the dtype of the computation (default: float32).
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
"""
features: Union[int, Iterable[int]]
axis: Union[int, Iterable[int]] = -1
batch_dims: Iterable[int] = ()
use_bias: bool = True
dtype: Dtype = jnp.float32
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
precision: Any = None
@compact
def __call__(self, inputs: Array) -> Array:
"""Applies a linear transformation to the inputs along multiple dimensions.
Args:
inputs: The nd-array to be transformed.
Returns:
The transformed input.
"""
...
```
================================================
FILE: docs/_ext/codediff.py
================================================
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sphinx directive for creating code diff tables.
Use directive as follows:
.. codediff::
:title: <LEFT_CODE_BLOCK_TITLE>, <RIGHT_CODE_BLOCK_TITLE>
<CODE_BLOCK_LEFT>
---
<CODE_BLOCK_RIGHT>
In order to highlight a line of code, append "#!" to it.
"""
import sphinx
from docutils import nodes
from docutils.parsers.rst import directives
from docutils.statemachine import ViewList
from sphinx.util.docutils import SphinxDirective
MISSING = object()
class CodeDiffParser:
def parse(
self,
lines: list[str],
title: str,
groups: list[str] | None = None,
skip_test: str | None = None,
code_sep: str = '---',
sync: object = MISSING,
):
"""Parse the code diff block and format it so that it
renders in different tabs and is tested by doctest.
For example:
.. testcode:: tab0, tab2, tab3
<CODE_BLOCK_A>
.. codediff::
:title: Tab 0, Tab 1, Tab 2, Tab 3
:groups: tab0, tab1, tab2, tab3
:skip_test: tab1, tab3
<CODE_BLOCK_B0>
---
<CODE_BLOCK_B1>
---
<CODE_BLOCK_B2>
---
<CODE_BLOCK_B3>
For group tab0: <CODE_BLOCK_A> and <CODE_BLOCK_B0> are executed.
For group tab1: Nothing is executed.
For group tab2: <CODE_BLOCK_A> and <CODE_BLOCK_B2> are executed.
For group tab3: <CODE_BLOCK_A> is executed.
Arguments:
lines: a string list, where each element is a single string code line
title: a single string that contains the titles of each tab (they should
be separated by commas)
groups: a single string that contains the group of each tab (they should
be separated by commas). Code snippets that are part of the same group
will be executed together. If groups=None, then the group names will
default to the tab title names.
skip_test: a single string denoting which group(s) to skip testing (they
should be separated by commas). This is useful for legacy code snippets
that no longer run correctly anymore. If skip_test=None, then no tests
are skipped.
code_sep: the separator character(s) used to denote a separate code block
for a new tab. The default code separator is '---'.
sync: an option for Sphinx directives, that will sync all tabs together.
This means that if the user clicks to switch to another tab, all tabs
will switch to the new tab.
"""
titles = [t.strip() for t in title.split(',')]
num_tabs = len(titles)
sync = sync is not MISSING
# skip legacy code snippets in upgrade guides
if skip_test is not None:
skip_tests = {index.strip() for index in skip_test.split(',')}
else:
skip_tests = set()
code_blocks = '\n'.join(lines)
if code_blocks.count(code_sep) != num_tabs - 1:
raise ValueError(
f'Expected {num_tabs-1} code separator(s) for {num_tabs} tab(s), but got {code_blocks.count(code_sep)} code separator(s) instead.'
)
code_blocks = [
code_block.split('\n')
for code_block in code_blocks.split(code_sep + '\n')
] # list[code_tab_list1[string_line1, ...], ...]
# by default, put each code snippet in a different group denoted by an index number, to be executed separately
if groups is not None:
groups = [group_name.strip() for group_name in groups.split(',')]
else:
groups = titles
if len(groups) != num_tabs:
raise ValueError(
f'Expected {num_tabs} group assignment(s) for {num_tabs} tab(s), but got {len(groups)} group assignment(s) instead.'
)
tabs = []
test_codes = []
for i, code_block in enumerate(code_blocks):
if groups[i] not in skip_tests:
test_codes.append((code_block, groups[i]))
tabs.append((titles[i], self._code_block(code_block)))
output = self._tabs(*tabs, sync=sync)
return output, test_codes
def _code_block(self, lines):
"""Creates a codeblock."""
# Remove right trailing whitespace so we can detect the comments.
lines = [x.rstrip() for x in lines]
highlight = lambda x: x.endswith('#!')
code = map(lambda x: x[:-2].rstrip() if highlight(x) else x, lines)
highlights = [i + 1 for i in range(len(lines)) if highlight(lines[i])]
highlights = ','.join(str(i) for i in highlights)
directive = ['.. code-block:: python']
if highlights:
directive += [f' :emphasize-lines: {highlights}']
# Indent code and add empty line so the code is picked up by the directive.
return directive + [''] + list(map(lambda x: ' ' + x, code))
def _tabs(self, *contents: tuple[str, list[str]], sync):
output = ['.. tab-set::'] + [' ']
for title, content in contents:
output += [f' .. tab-item:: {title}']
if sync:
key = title.strip()
output += [f' :sync: {key}']
output += [' ']
output += [' ' + line for line in content]
return output
class CodeDiffDirective(SphinxDirective):
  """The ``codediff`` directive: shows code snippets side by side as tabs.

  Each snippet whose group is not skipped is also emitted as a hidden
  ``testcode`` node, so that Sphinx's doctest builder executes it.
  """

  has_content = True
  option_spec = {
      'title': directives.unchanged,
      'groups': directives.unchanged,
      'skip_test': directives.unchanged,
      'code_sep': directives.unchanged,
      'sync': directives.flag,
  }

  def _make_test_node(self, code_lines, group):
    """Builds a hidden doctest node for one snippet.

    A comment node is used so nothing shows up in the rendered page. The
    ``testnodetype`` attribute is what makes the doctest builder pick it up.
    This is not officially documented, but can be found in the source:
    https://github.com/sphinx-doc/sphinx/blob/master/sphinx/ext/doctest.py
    (search for 'testnodetype').
    """
    source = '\n'.join(code_lines)
    node = nodes.comment(
        source,
        source,
        testnodetype='testcode',
        groups=[group],
    )
    self.set_source_info(node)
    node['options'] = {}
    node['language'] = 'python3'
    return node

  def run(self):
    """Parses the directive body into rendered tabs plus hidden test nodes.

    Returns:
      A list whose first element is a paragraph node holding the tab set,
      followed by one hidden test node per non-skipped snippet.
    """
    rendered_rst, snippets = CodeDiffParser().parse(
        list(self.content), **self.options
    )
    hidden_tests = [
        self._make_test_node(snippet, group) for snippet, group in snippets
    ]
    # The tab set is the view readers see in the built documentation.
    visible = nodes.paragraph()
    self.content = ViewList(rendered_rst, self.content.parent)
    self.state.nested_parse(self.content, self.content_offset, visible)
    return [visible, *hidden_tests]
def setup(app):
  """Sphinx extension entry point: registers the ``codediff`` directive.

  Returns:
    Extension metadata declaring the extension safe for parallel reading
    and writing.
  """
  app.add_directive('codediff', CodeDiffDirective)
  metadata = dict(
      version=sphinx.__display_version__,
      parallel_read_safe=True,
      parallel_write_safe=True,
  )
  return metadata
================================================
FILE: docs/_ext/codediff_test.py
================================================
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for codediff Sphinx extension."""
from absl.testing import parameterized
from codediff import CodeDiffParser
class CodeDiffTest(parameterized.TestCase):
  """Unit tests for ``CodeDiffParser.parse``."""

  def test_parse(self):
    """Checks the rendered tab set, the extracted snippets and their groups."""
    input_text = r"""@jax.jit #!
def get_initial_params(key): #!
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = CNN().init(key, init_val)['params']
extra_line
return initial_params
---
@jax.pmap #!
def get_initial_params(key):
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = CNN().init(key, init_val)['params']
return initial_params"""
    expected_table = """.. tab-set::\n \n .. tab-item:: Single device\n \n .. code-block:: python\n :emphasize-lines: 1,2\n \n @jax.jit\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n extra_line\n return initial_params\n \n .. tab-item:: Ensembling on multiple devices\n \n .. code-block:: python\n :emphasize-lines: 1\n \n @jax.pmap\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n return initial_params"""
    expected_testcodes = [
        r"""@jax.jit #!
def get_initial_params(key): #!
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = CNN().init(key, init_val)['params']
extra_line
return initial_params
""",
        r"""@jax.pmap #!
def get_initial_params(key):
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = CNN().init(key, init_val)['params']
return initial_params""",
    ]
    title_left = 'Single device'
    title_right = 'Ensembling on multiple devices'
    actual_table, actual_testcodes = CodeDiffParser().parse(
        lines=input_text.split('\n'),
        title=f'{title_left}, {title_right}',
    )
    actual_table = '\n'.join(actual_table)
    actual_groups = [group for _, group in actual_testcodes]
    actual_testcodes = ['\n'.join(testcode) for testcode, _ in actual_testcodes]
    self.assertEqual(expected_table, actual_table)
    # Compare whole lists so that extra or missing snippets are also caught.
    self.assertEqual(expected_testcodes, actual_testcodes)
    # Without an explicit `groups` option, each snippet's group is its tab title.
    self.assertEqual([title_left, title_right], actual_groups)

  def test_parse_skip_test(self):
    """Snippets whose group is listed in `skip_test` are not returned as tests."""
    _, test_codes = CodeDiffParser().parse(
        lines=['x = 1', '---', 'x = 2'],
        title='Tab A, Tab B',
        skip_test='Tab B',
    )
    self.assertEqual([(['x = 1', ''], 'Tab A')], test_codes)

  @parameterized.parameters(
      {
          'input_text': r"""x = 1
---
x = 2
""",
          'title': 'Tab 0, Tab1, Tab2',
          'groups': None,
          'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 1 code separator\\(s\\) instead.',
      },
      {
          'input_text': r"""x = 1
---
x = 2
---
x = 3
---
x = 4
""",
          'title': 'Tab 0, Tab1, Tab2',
          'groups': None,
          'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 3 code separator\\(s\\) instead.',
      },
      {
          'input_text': r"""x = 1
---
x = 2
---
x = 3
""",
          'title': 'Tab 0, Tab1, Tab2',
          'groups': 'tab0, tab2',
          'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 2 group assignment\\(s\\) instead.',
      },
      {
          'input_text': r"""x = 1
---
x = 2
---
x = 3
""",
          'title': 'Tab 0, Tab1, Tab2',
          'groups': 'tab0, tab1, tab2, tab3',
          'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 4 group assignment\\(s\\) instead.',
      },
  )
  def test_parse_errors(self, input_text, title, groups, error_msg):
    """Mismatched separator or group counts raise a descriptive ValueError."""
    with self.assertRaisesRegex(ValueError, error_msg):
      _, _ = CodeDiffParser().parse(
          lines=input_text.split('\n'),
          title=title,
          groups=groups,
      )
================================================
FILE: docs/_ext/flax_module.py
================================================
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sphinx directive for visualizing Flax modules.
Use directive as follows:
.. flax_module::
:module: flax.linen
:class: Dense
"""
import importlib
import sphinx
import sphinx.ext.autosummary.generate as ag
from docutils import nodes
from docutils.parsers.rst import directives
from docutils.statemachine import ViewList
from sphinx.util.docutils import SphinxDirective
from docs.conf_sphinx_patch import generate_autosummary_content
def render_module(modname: str, qualname: str, app):
  """Renders the ``flax_module`` autosummary template for one object.

  Args:
    modname: dotted name of the module to import.
    qualname: attribute name of the object to document within that module.
    app: the running Sphinx application.

  Returns:
    The output of ``generate_autosummary_content``. The directive calls
    ``.splitlines()`` on it, so it is the rendered RST text.
  """
  module = importlib.import_module(modname)
  target = getattr(module, qualname)
  renderer = ag.AutosummaryRenderer(app)
  return generate_autosummary_content(
      qualname,
      target,
      module,
      renderer,
      'flax_module',  # template_name
      False,  # imported_members
      app,
      False,  # recursive
      {},  # context
      modname,
      qualname,
  )
class FlaxModuleDirective(SphinxDirective):
  """The ``flax_module`` directive.

  Renders the ``flax_module`` autosummary template for the object named by
  the ``:class:`` option inside the ``:module:`` option, then parses the
  resulting RST into a container node.
  """

  has_content = True
  option_spec = {
      'module': directives.unchanged,
      'class': directives.unchanged,
  }

  def run(self):
    """Returns a single container node holding the parsed template output."""
    rst_text = render_module(
        self.options['module'], self.options['class'], self.env.app
    )
    rst_lines = rst_text.splitlines()
    container_node = nodes.container()
    self.content = ViewList(rst_lines, self.content.parent)
    self.state.nested_parse(self.content, self.content_offset, container_node)
    return [container_node]
def setup(app):
  """Registers the ``flax_module`` directive with Sphinx.

  Returns:
    Extension metadata; the extension declares itself safe for parallel
    reading and writing.
  """
  app.add_directive('flax_module', FlaxModuleDirective)
  return dict(
      version=sphinx.__display_version__,
      parallel_read_safe=True,
      parallel_write_safe=True,
  )
================================================
FILE: docs/_static/css/flax_theme.css
================================================
/* Flax documentation style overrides, layered on the base "theme.css". */
@import url("theme.css");
/* Cap the main content column at 1290px wide. */
.wy-nav-content {
  max-width: 1290px;
}
/* Docutils tables span the full content width
   (presumably for side-by-side code views; confirm). */
.rst-content table.docutils {
  width: 100%;
}
/* Top-align table cells and remove their padding... */
.rst-content table.docutils td {
  vertical-align: top;
  padding: 0;
}
/* ...then pad paragraphs inside cells instead. */
.rst-content table.docutils td p {
  padding: 8px;
}
/* Remove the border and margin from highlighted code blocks. */
.rst-content div[class^=highlight] {
  border: 0;
  margin: 0;
}
================================================
FILE: docs/_
gitextract_07u909eo/
├── .git-blame-ignore-revs
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ └── bug_report.md
│ ├── analytics/
│ │ ├── README.md
│ │ ├── get_repo_metrics.py
│ │ ├── issue_activity_since_date.gql
│ │ ├── pr_data_query.gql
│ │ └── requirements.txt
│ ├── pull_request_template.md
│ └── workflows/
│ ├── flax_publish.yml
│ ├── flax_test.yml
│ ├── flaxlib_publish.yml
│ └── jax_nightly.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── AUTHORS
├── CHANGELOG.md
├── LICENSE
├── README.md
├── benchmarks/
│ ├── README.md
│ ├── nnx_graph_overhead.py
│ ├── nnx_mlpmixer_training.py
│ ├── nnx_simple_training.py
│ ├── nnx_state_traversal.py
│ └── tracing/
│ ├── README.md
│ ├── __init__.py
│ ├── gemma.py
│ ├── imagenet.py
│ ├── lm1b.py
│ ├── mnist.py
│ ├── nlp_seq.py
│ ├── ogbg_molpcba.py
│ ├── ppo.py
│ ├── requirements.txt
│ ├── run_all_benchmarks.sh
│ ├── seq2seq.py
│ ├── sst2.py
│ ├── tracing_benchmark.py
│ ├── vae.py
│ └── wmt.py
├── contributing.md
├── docs/
│ ├── .gitignore
│ ├── .readthedocs.yaml
│ ├── Makefile
│ ├── README.md
│ ├── _ext/
│ │ ├── codediff.py
│ │ ├── codediff_test.py
│ │ └── flax_module.py
│ ├── _static/
│ │ └── css/
│ │ └── flax_theme.css
│ ├── _templates/
│ │ └── autosummary/
│ │ └── flax_module.rst
│ ├── api_reference/
│ │ ├── flax.core.frozen_dict.rst
│ │ ├── flax.cursor.rst
│ │ ├── flax.errors.rst
│ │ ├── flax.jax_utils.rst
│ │ ├── flax.linen/
│ │ │ ├── activation_functions.rst
│ │ │ ├── decorators.rst
│ │ │ ├── index.rst
│ │ │ ├── init_apply.rst
│ │ │ ├── initializers.rst
│ │ │ ├── inspection.rst
│ │ │ ├── layers.rst
│ │ │ ├── module.rst
│ │ │ ├── profiling.rst
│ │ │ ├── spmd.rst
│ │ │ ├── transformations.rst
│ │ │ └── variable.rst
│ │ ├── flax.serialization.rst
│ │ ├── flax.struct.rst
│ │ ├── flax.traceback_util.rst
│ │ ├── flax.training.rst
│ │ └── index.rst
│ ├── conf.py
│ ├── conf_sphinx_patch.py
│ ├── developer_notes/
│ │ ├── index.rst
│ │ ├── lift.md
│ │ └── module_lifecycle.rst
│ ├── examples/
│ │ ├── community_examples.rst
│ │ ├── core_examples.rst
│ │ ├── google_research_examples.rst
│ │ ├── index.rst
│ │ └── repositories_that_use_flax.rst
│ ├── faq.rst
│ ├── flip/
│ │ ├── 0000-template.md
│ │ ├── 1009-optimizer-api.md
│ │ ├── 1777-default-dtype.md
│ │ ├── 2396-rnn.md
│ │ ├── 2434-general-metadata.md
│ │ ├── 2974-kw-only-dataclasses.md
│ │ ├── 3099-rnnbase-refactor.md
│ │ ├── 4105-jax-style-nnx-transforms.md
│ │ └── README.md
│ ├── glossary.rst
│ ├── guides/
│ │ ├── converting_and_upgrading/
│ │ │ ├── convert_pytorch_to_flax.rst
│ │ │ ├── haiku_migration_guide.rst
│ │ │ ├── index.rst
│ │ │ ├── linen_upgrade_guide.rst
│ │ │ ├── optax_update_guide.rst
│ │ │ ├── orbax_upgrade_guide.rst
│ │ │ ├── regular_dict_upgrade_guide.rst
│ │ │ └── rnncell_upgrade_guide.rst
│ │ ├── data_preprocessing/
│ │ │ ├── full_eval.rst
│ │ │ ├── index.rst
│ │ │ ├── loading_datasets.ipynb
│ │ │ └── loading_datasets.md
│ │ ├── flax_fundamentals/
│ │ │ ├── arguments.md
│ │ │ ├── flax_basics.ipynb
│ │ │ ├── flax_basics.md
│ │ │ ├── index.rst
│ │ │ ├── rng_guide.ipynb
│ │ │ ├── rng_guide.md
│ │ │ ├── setup_or_nncompact.rst
│ │ │ └── state_params.rst
│ │ ├── flax_sharp_bits.ipynb
│ │ ├── flax_sharp_bits.md
│ │ ├── index.rst
│ │ ├── model_inspection/
│ │ │ ├── extracting_intermediates.rst
│ │ │ ├── index.rst
│ │ │ ├── model_surgery.ipynb
│ │ │ └── model_surgery.md
│ │ ├── parallel_training/
│ │ │ ├── ensembling.rst
│ │ │ ├── flax_on_pjit.ipynb
│ │ │ ├── flax_on_pjit.md
│ │ │ └── index.rst
│ │ ├── quantization/
│ │ │ ├── fp8_basics.ipynb
│ │ │ ├── fp8_basics.md
│ │ │ └── index.rst
│ │ └── training_techniques/
│ │ ├── batch_norm.rst
│ │ ├── dropout.rst
│ │ ├── index.rst
│ │ ├── lr_schedule.rst
│ │ ├── transfer_learning.ipynb
│ │ ├── transfer_learning.md
│ │ ├── use_checkpointing.ipynb
│ │ └── use_checkpointing.md
│ ├── index.rst
│ ├── linen_intro.ipynb
│ ├── linen_intro.md
│ ├── quick_start.ipynb
│ ├── quick_start.md
│ └── robots.txt
├── docs_nnx/
│ ├── .gitignore
│ ├── .readthedocs.yaml
│ ├── Makefile
│ ├── README.md
│ ├── _ext/
│ │ ├── codediff.py
│ │ ├── codediff_test.py
│ │ └── flax_module.py
│ ├── _static/
│ │ └── css/
│ │ └── flax_theme.css
│ ├── _templates/
│ │ └── autosummary/
│ │ └── flax_module.rst
│ ├── api_reference/
│ │ ├── flax.config.rst
│ │ ├── flax.core.frozen_dict.rst
│ │ ├── flax.nnx/
│ │ │ ├── bridge.rst
│ │ │ ├── filterlib.rst
│ │ │ ├── graph.rst
│ │ │ ├── helpers.rst
│ │ │ ├── index.rst
│ │ │ ├── module.rst
│ │ │ ├── nn/
│ │ │ │ ├── activations.rst
│ │ │ │ ├── attention.rst
│ │ │ │ ├── dtypes.rst
│ │ │ │ ├── index.rst
│ │ │ │ ├── initializers.rst
│ │ │ │ ├── linear.rst
│ │ │ │ ├── lora.rst
│ │ │ │ ├── normalization.rst
│ │ │ │ ├── recurrent.rst
│ │ │ │ └── stochastic.rst
│ │ │ ├── object.rst
│ │ │ ├── rnglib.rst
│ │ │ ├── spmd.rst
│ │ │ ├── state.rst
│ │ │ ├── summary.rst
│ │ │ ├── training/
│ │ │ │ ├── index.rst
│ │ │ │ ├── metrics.rst
│ │ │ │ └── optimizer.rst
│ │ │ ├── transforms.rst
│ │ │ ├── variables.rst
│ │ │ └── visualization.rst
│ │ ├── flax.struct.rst
│ │ ├── flax.training.rst
│ │ ├── flax.traverse_util.rst
│ │ └── index.rst
│ ├── conf.py
│ ├── conf_sphinx_patch.py
│ ├── contributing.md
│ ├── examples/
│ │ ├── core_examples.rst
│ │ ├── gemma.ipynb
│ │ ├── gemma.md
│ │ └── index.rst
│ ├── faq.rst
│ ├── flip/
│ │ ├── 0000-template.md
│ │ ├── 1009-optimizer-api.md
│ │ ├── 1777-default-dtype.md
│ │ ├── 2396-rnn.md
│ │ ├── 2434-general-metadata.md
│ │ ├── 2974-kw-only-dataclasses.md
│ │ ├── 3099-rnnbase-refactor.md
│ │ ├── 4105-jax-style-nnx-transforms.md
│ │ ├── 4844-var-eager-sharding.md
│ │ ├── 5310-tree-mode-nnx.md
│ │ └── README.md
│ ├── guides/
│ │ ├── blog.md
│ │ ├── bridge_guide.ipynb
│ │ ├── bridge_guide.md
│ │ ├── checkpointing.ipynb
│ │ ├── checkpointing.md
│ │ ├── demo.ipynb
│ │ ├── demo.md
│ │ ├── extracting_intermediates.ipynb
│ │ ├── extracting_intermediates.md
│ │ ├── filters_guide.ipynb
│ │ ├── filters_guide.md
│ │ ├── flax_gspmd.ipynb
│ │ ├── flax_gspmd.md
│ │ ├── index.rst
│ │ ├── jax_and_nnx_transforms.rst
│ │ ├── performance.ipynb
│ │ ├── performance.md
│ │ ├── pytree.ipynb
│ │ ├── pytree.md
│ │ ├── randomness.ipynb
│ │ ├── randomness.md
│ │ ├── surgery.ipynb
│ │ ├── surgery.md
│ │ ├── tiny_nnx.ipynb
│ │ ├── transforms.ipynb
│ │ ├── transforms.md
│ │ ├── view.ipynb
│ │ └── view.md
│ ├── guides_advanced.rst
│ ├── guides_basic.rst
│ ├── hijax/
│ │ ├── hijax.ipynb
│ │ ├── hijax.md
│ │ └── index.rst
│ ├── index.rst
│ ├── key_concepts.ipynb
│ ├── key_concepts.md
│ ├── migrating/
│ │ ├── convert_pytorch_to_flax.rst
│ │ ├── haiku_to_flax.rst
│ │ ├── index.rst
│ │ ├── linen_to_nnx.rst
│ │ └── nnx_010_to_nnx_011.rst
│ ├── mnist_tutorial.ipynb
│ ├── mnist_tutorial.md
│ ├── nnx_basics.ipynb
│ ├── nnx_basics.md
│ ├── nnx_glossary.rst
│ ├── philosophy.md
│ ├── robots.txt
│ └── why.rst
├── examples/
│ ├── README.md
│ ├── __init__.py
│ ├── cloud/
│ │ ├── README.md
│ │ ├── launch_gce.py
│ │ └── startup_script.sh
│ ├── gemma/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── default.py
│ │ │ ├── gemma3_4b.py
│ │ │ ├── small.py
│ │ │ └── tiny.py
│ │ ├── helpers.py
│ │ ├── helpers_test.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── layers.py
│ │ ├── layers_test.py
│ │ ├── main.py
│ │ ├── modules.py
│ │ ├── modules_test.py
│ │ ├── params.py
│ │ ├── positional_embeddings.py
│ │ ├── positional_embeddings_test.py
│ │ ├── requirements.txt
│ │ ├── sampler.py
│ │ ├── sampler_test.py
│ │ ├── sow_lib.py
│ │ ├── tokenizer.py
│ │ ├── train.py
│ │ ├── transformer.py
│ │ ├── transformer_test.py
│ │ └── utils.py
│ ├── imagenet/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── default.py
│ │ │ ├── fake_data_benchmark.py
│ │ │ ├── tpu.py
│ │ │ ├── v100_x8.py
│ │ │ └── v100_x8_mixed_precision.py
│ │ ├── imagenet.ipynb
│ │ ├── imagenet_benchmark.py
│ │ ├── imagenet_fake_data_benchmark.py
│ │ ├── input_pipeline.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── models_test.py
│ │ ├── requirements.txt
│ │ ├── train.py
│ │ └── train_test.py
│ ├── linen_design_test/
│ │ ├── attention_simple.py
│ │ ├── autoencoder.py
│ │ ├── dense.py
│ │ ├── linear_regression.py
│ │ ├── mlp_explicit.py
│ │ ├── mlp_inline.py
│ │ └── mlp_lazy.py
│ ├── lm1b/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── requirements.txt
│ │ ├── temperature_sampler.py
│ │ ├── temperature_sampler_test.py
│ │ ├── tokenizer.py
│ │ ├── train.py
│ │ ├── train_test.py
│ │ └── utils.py
│ ├── mnist/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── main.py
│ │ ├── mnist.ipynb
│ │ ├── mnist_benchmark.py
│ │ ├── requirements.txt
│ │ ├── train.py
│ │ └── train_test.py
│ ├── nlp_seq/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── requirements.txt
│ │ └── train.py
│ ├── nnx_toy_examples/
│ │ ├── 01_functional_api.py
│ │ ├── 02_lifted_transforms.py
│ │ ├── 03_train_state.py
│ │ ├── 04_data_parallel_with_jit.py
│ │ ├── 05_vae.py
│ │ ├── 06_scan_over_layers.py
│ │ ├── 07_array_leaves.py
│ │ ├── 08_save_load_checkpoints.py
│ │ ├── 09_parameter_surgery.py
│ │ ├── 10_fsdp_and_optimizer.py
│ │ ├── hijax_basic.py
│ │ ├── hijax_demo.py
│ │ └── requirements.txt
│ ├── ogbg_molpcba/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── default.py
│ │ │ ├── default_graph_net.py
│ │ │ ├── hparam_sweep.py
│ │ │ └── test.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── models_test.py
│ │ ├── ogbg_molpcba.ipynb
│ │ ├── ogbg_molpcba_benchmark.py
│ │ ├── requirements.txt
│ │ ├── train.py
│ │ └── train_test.py
│ ├── ppo/
│ │ ├── README.md
│ │ ├── agent.py
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── env_utils.py
│ │ ├── models.py
│ │ ├── ppo_lib.py
│ │ ├── ppo_lib_test.py
│ │ ├── ppo_main.py
│ │ ├── requirements.txt
│ │ ├── seed_rl_atari_preprocessing.py
│ │ └── test_episodes.py
│ ├── seq2seq/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── requirements.txt
│ │ ├── seq2seq.ipynb
│ │ ├── train.py
│ │ └── train_test.py
│ ├── sst2/
│ │ ├── README.md
│ │ ├── build_vocabulary.py
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── input_pipeline_test.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── models_test.py
│ │ ├── requirements.txt
│ │ ├── sst2.ipynb
│ │ ├── train.py
│ │ ├── train_test.py
│ │ ├── vocab.txt
│ │ └── vocabulary.py
│ ├── vae/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── input_pipeline.py
│ │ ├── main.py
│ │ ├── models.py
│ │ ├── requirements.txt
│ │ ├── results/
│ │ │ └── .gitignore
│ │ ├── train.py
│ │ └── utils.py
│ └── wmt/
│ ├── README.md
│ ├── bleu.py
│ ├── configs/
│ │ └── default.py
│ ├── decode.py
│ ├── input_pipeline.py
│ ├── input_pipeline_test.py
│ ├── main.py
│ ├── models.py
│ ├── requirements.txt
│ ├── tokenizer.py
│ ├── train.py
│ └── train_test.py
├── flax/
│ ├── __init__.py
│ ├── configurations.py
│ ├── core/
│ │ ├── __init__.py
│ │ ├── axes_scan.py
│ │ ├── flax_functional_engine.ipynb
│ │ ├── frozen_dict.py
│ │ ├── lift.py
│ │ ├── meta.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── linear.py
│ │ │ ├── normalization.py
│ │ │ └── stochastic.py
│ │ ├── partial_eval.py
│ │ ├── scope.py
│ │ ├── spmd.py
│ │ ├── tracers.py
│ │ └── variables.py
│ ├── cursor.py
│ ├── errors.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ └── nnx.py
│ ├── ids.py
│ ├── io.py
│ ├── jax_utils.py
│ ├── linen/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── activation.py
│ │ ├── attention.py
│ │ ├── batch_apply.py
│ │ ├── combinators.py
│ │ ├── dtypes.py
│ │ ├── experimental/
│ │ │ └── layers_with_named_axes.py
│ │ ├── fp8_ops.py
│ │ ├── initializers.py
│ │ ├── kw_only_dataclasses.py
│ │ ├── linear.py
│ │ ├── module.py
│ │ ├── normalization.py
│ │ ├── partitioning.py
│ │ ├── pooling.py
│ │ ├── recurrent.py
│ │ ├── spmd.py
│ │ ├── stochastic.py
│ │ ├── summary.py
│ │ └── transforms.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ └── tensorboard.py
│ ├── nnx/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── bridge/
│ │ │ ├── __init__.py
│ │ │ ├── interop.py
│ │ │ ├── module.py
│ │ │ ├── variables.py
│ │ │ └── wrappers.py
│ │ ├── compat.py
│ │ ├── extract.py
│ │ ├── filterlib.py
│ │ ├── graph.py
│ │ ├── graphlib.py
│ │ ├── helpers.py
│ │ ├── ids.py
│ │ ├── module.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ ├── activations.py
│ │ │ ├── attention.py
│ │ │ ├── dtypes.py
│ │ │ ├── initializers.py
│ │ │ ├── linear.py
│ │ │ ├── lora.py
│ │ │ ├── normalization.py
│ │ │ ├── recurrent.py
│ │ │ └── stochastic.py
│ │ ├── proxy_caller.py
│ │ ├── pytreelib.py
│ │ ├── reprlib.py
│ │ ├── rnglib.py
│ │ ├── scripts/
│ │ │ ├── requirements.txt
│ │ │ └── run-all-examples.bash
│ │ ├── spmd.py
│ │ ├── statelib.py
│ │ ├── summary.py
│ │ ├── tracers.py
│ │ ├── training/
│ │ │ ├── __init__.py
│ │ │ ├── metrics.py
│ │ │ └── optimizer.py
│ │ ├── transforms/
│ │ │ ├── __init__.py
│ │ │ ├── autodiff.py
│ │ │ ├── compilation.py
│ │ │ ├── general.py
│ │ │ ├── iteration.py
│ │ │ └── transforms.py
│ │ ├── traversals.py
│ │ ├── variablelib.py
│ │ └── visualization.py
│ ├── oss/
│ │ └── .git-blame-ignore-revs
│ ├── py.typed
│ ├── serialization.py
│ ├── struct.py
│ ├── testing/
│ │ ├── __init__.py
│ │ └── benchmark.py
│ ├── traceback_util.py
│ ├── training/
│ │ ├── __init__.py
│ │ ├── checkpoints.py
│ │ ├── common_utils.py
│ │ ├── dynamic_scale.py
│ │ ├── early_stopping.py
│ │ ├── lr_schedule.py
│ │ ├── orbax_utils.py
│ │ ├── prefetch_iterator.py
│ │ └── train_state.py
│ ├── traverse_util.py
│ ├── typing.py
│ └── version.py
├── flaxlib_src/
│ ├── .gitignore
│ ├── CMakeLists.txt
│ ├── Cargo.toml
│ ├── LICENSE
│ ├── README.md
│ ├── pyproject.toml
│ └── src/
│ ├── flaxlib/
│ │ ├── __init__.py
│ │ └── flaxlib_cpp.pyi
│ └── lib.cc
├── nnx.py
├── pylintrc
├── pyproject.toml
└── tests/
├── checkpoints_test.py
├── colab_tpu_jax_version.ipynb
├── configurations_test.py
├── core/
│ ├── core_frozen_dict_test.py
│ ├── core_lift_test.py
│ ├── core_meta_test.py
│ ├── core_scope_test.py
│ └── design/
│ ├── core_attention_test.py
│ ├── core_auto_encoder_test.py
│ ├── core_big_resnets_test.py
│ ├── core_custom_vjp_test.py
│ ├── core_dense_test.py
│ ├── core_flow_test.py
│ ├── core_resnet_test.py
│ ├── core_scan_test.py
│ ├── core_tied_autoencoder_test.py
│ ├── core_vmap_test.py
│ └── core_weight_std_test.py
├── cursor_test.py
├── download_dataset_metadata.sh
├── early_stopping_test.py
├── flaxlib_test.py
├── import_test.ipynb
├── io_test.py
├── jax_utils_test.py
├── linen/
│ ├── initializers_test.py
│ ├── kw_only_dataclasses_test.py
│ ├── linen_activation_test.py
│ ├── linen_attention_test.py
│ ├── linen_batch_apply_test.py
│ ├── linen_combinators_test.py
│ ├── linen_dtypes_test.py
│ ├── linen_linear_test.py
│ ├── linen_meta_test.py
│ ├── linen_module_test.py
│ ├── linen_recurrent_test.py
│ ├── linen_test.py
│ ├── linen_transforms_test.py
│ ├── partitioning_test.py
│ └── summary_test.py
├── nnx/
│ ├── __init__.py
│ ├── bridge/
│ │ ├── module_test.py
│ │ └── wrappers_test.py
│ ├── containers_test.py
│ ├── filters_test.py
│ ├── graph_utils_test.py
│ ├── helpers_test.py
│ ├── ids_test.py
│ ├── integration_test.py
│ ├── metrics_test.py
│ ├── module_test.py
│ ├── mutable_array_test.py
│ ├── nn/
│ │ ├── attention_test.py
│ │ ├── conv_test.py
│ │ ├── embed_test.py
│ │ ├── linear_test.py
│ │ ├── lora_test.py
│ │ ├── normalization_test.py
│ │ ├── recurrent_test.py
│ │ └── stochastic_test.py
│ ├── optimizer_test.py
│ ├── partitioning_test.py
│ ├── rngs_test.py
│ ├── spmd_test.py
│ ├── state_test.py
│ ├── summary_test.py
│ ├── test_traversals.py
│ ├── transforms_test.py
│ └── variable_test.py
├── pickle_test.py
├── run_all_tests.sh
├── serialization_test.py
├── struct_test.py
├── tensorboard_test.py
├── traceback_util_test.py
└── traverse_util_test.py
Showing preview only (398K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (5058 symbols across 319 files)
FILE: .github/analytics/get_repo_metrics.py
function load_query_from_file (line 38) | def load_query_from_file(fname, repo_owner, repo_name) -> str:
function send_query (line 47) | def send_query(query, query_type, cursor=None):
function get_all_responses (line 93) | def get_all_responses(query, query_type):
function parse_single_query (line 109) | def parse_single_query(data, query_type):
class GithubGrabber (line 129) | class GithubGrabber:
method __init__ (line 134) | def __init__(self, query_fname, query_type, repo_owner, repo_name):
method load_query (line 162) | def load_query(self):
method get (line 167) | def get(self):
function _to_datetime (line 176) | def _to_datetime(date_str: str) -> datetime:
function _get_issues_features (line 180) | def _get_issues_features(issues):
function _get_pr_features (line 205) | def _get_pr_features(prs):
function _start_of_month (line 251) | def _start_of_month(date: datetime) -> datetime:
function _shift_n_months (line 255) | def _shift_n_months(date: datetime, n: int) -> datetime:
function _rolling_window (line 267) | def _rolling_window(
function _process_prs (line 297) | def _process_prs(df: pd.DataFrame) -> pd.Series:
function _process_issues (line 306) | def _process_issues(df: pd.DataFrame) -> pd.Series:
function main (line 323) | def main(_):
FILE: benchmarks/nnx_graph_overhead.py
class Linear (line 35) | class Linear(nnx.Module):
method __init__ (line 36) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 40) | def __call__(self, x):
class Block (line 44) | class Block(nnx.Module):
method __init__ (line 45) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 49) | def __call__(self, x):
class Count (line 53) | class Count(nnx.Variable):
class MLP (line 57) | class MLP(nnx.Module):
method __init__ (line 58) | def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
method __call__ (line 66) | def __call__(self, x):
function main (line 75) | def main(argv):
FILE: benchmarks/nnx_mlpmixer_training.py
class MlpBlock (line 41) | class MlpBlock(nnx.Module):
method __init__ (line 42) | def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs):
method __call__ (line 47) | def __call__(self, x):
class MixerBlock (line 51) | class MixerBlock(nnx.Module):
method __init__ (line 52) | def __init__(
method __call__ (line 67) | def __call__(self, x):
class MlpMixer (line 77) | class MlpMixer(nnx.Module):
method __init__ (line 78) | def __init__(
method __call__ (line 111) | def __call__(self, *, x, t):
function main (line 129) | def main(argv):
FILE: benchmarks/nnx_simple_training.py
function dataset (line 38) | def dataset(X, Y, batch_size):
class Linear (line 44) | class Linear(nnx.Module):
method __init__ (line 45) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 49) | def __call__(self, x):
class Block (line 53) | class Block(nnx.Module):
method __init__ (line 54) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 58) | def __call__(self, x):
class Count (line 62) | class Count(nnx.Variable):
class MLP (line 66) | class MLP(nnx.Module):
method __init__ (line 67) | def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
method __call__ (line 75) | def __call__(self, x):
function main (line 84) | def main(argv):
FILE: benchmarks/nnx_state_traversal.py
class NestedClass (line 34) | class NestedClass(nnx.Module):
method __init__ (line 35) | def __init__(self, width, depth):
function main (line 42) | def main(argv):
FILE: benchmarks/tracing/gemma.py
function rsqrt_schedule (line 32) | def rsqrt_schedule(init_value: float, shift: int = 0):
function create_learning_rate_schedule (line 38) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
function compute_weighted_cross_entropy (line 52) | def compute_weighted_cross_entropy(
function compute_weighted_accuracy (line 82) | def compute_weighted_accuracy(logits, targets, weights=None):
function compute_metrics (line 97) | def compute_metrics(logits, labels, weights, label_smoothing=0.0):
function train_step (line 110) | def train_step(
function get_fake_batch (line 157) | def get_fake_batch(batch_size: int) -> Any:
function get_apply_fn_and_args (line 172) | def get_apply_fn_and_args(
function test_flax_gemma_trace (line 243) | def test_flax_gemma_trace(state):
function test_flax_gemma_lower (line 251) | def test_flax_gemma_lower(state):
FILE: benchmarks/tracing/imagenet.py
class TrainState (line 35) | class TrainState(train_state.TrainState):
function create_model (line 40) | def create_model(*, model_cls, half_precision, **kwargs):
function initialized (line 52) | def initialized(key, image_size, model):
function cross_entropy_loss (line 63) | def cross_entropy_loss(logits, labels):
function create_train_state (line 69) | def create_train_state(
function get_fake_batch (line 93) | def get_fake_batch(batch_size: int = 128) -> dict[str, jnp.ndarray]:
class BenchmarkResNet (line 103) | class BenchmarkResNet(models.ResNet):
method __call__ (line 106) | def __call__(self, x, train: bool = True):
function get_apply_fn_and_args (line 143) | def get_apply_fn_and_args(
function bench_train_step (line 177) | def bench_train_step(state, batch, learning_rate_fn):
function test_flax_imagenet_trace (line 240) | def test_flax_imagenet_trace(state):
function test_flax_imagenet_lower (line 248) | def test_flax_imagenet_lower(state):
FILE: benchmarks/tracing/lm1b.py
function rsqrt_schedule (line 33) | def rsqrt_schedule(init_value: float, shift: int = 0):
function create_learning_rate_schedule (line 39) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
function compute_weighted_cross_entropy (line 53) | def compute_weighted_cross_entropy(
function compute_weighted_accuracy (line 83) | def compute_weighted_accuracy(logits, targets, weights=None):
function get_fake_batch (line 98) | def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, jnp.n...
function bench_train_step (line 122) | def bench_train_step(state, batch, config, learning_rate_fn):
function get_apply_fn_and_args (line 170) | def get_apply_fn_and_args(
function test_flax_lm1b_trace (line 239) | def test_flax_lm1b_trace(state):
function test_flax_lm1b_lower (line 247) | def test_flax_lm1b_lower(state):
FILE: benchmarks/tracing/mnist.py
class CNN (line 29) | class CNN(nnx.Module):
method __init__ (line 31) | def __init__(self, rngs: nnx.Rngs):
method __call__ (line 42) | def __call__(self, x, rngs: nnx.Rngs):
function loss_fn (line 51) | def loss_fn(model: CNN, batch, rngs):
function get_fake_batch (line 59) | def get_fake_batch(batch_size: int) -> dict[str, Any]:
function get_apply_fn_and_args (line 66) | def get_apply_fn_and_args(
function test_flax_mnist_trace (line 82) | def test_flax_mnist_trace(state):
function test_flax_mnist_lower (line 90) | def test_flax_mnist_lower(state):
FILE: benchmarks/tracing/nlp_seq.py
function create_learning_rate_scheduler (line 33) | def create_learning_rate_scheduler(
function compute_weighted_cross_entropy (line 71) | def compute_weighted_cross_entropy(logits, targets, weights=None):
function compute_weighted_accuracy (line 87) | def compute_weighted_accuracy(logits, targets, weights=None):
function get_fake_batch (line 102) | def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, jnp.n...
function bench_train_step (line 129) | def bench_train_step(state, batch, config, learning_rate_fn):
function get_apply_fn_and_args (line 176) | def get_apply_fn_and_args(
function test_flax_nlp_seq_trace (line 227) | def test_flax_nlp_seq_trace(state):
function test_flax_nlp_seq_lower (line 235) | def test_flax_nlp_seq_lower(state):
FILE: benchmarks/tracing/ogbg_molpcba.py
function create_model (line 33) | def create_model(
function create_optimizer (line 62) | def create_optimizer(
function binary_cross_entropy_with_mask (line 74) | def binary_cross_entropy_with_mask(
function predictions_match_labels (line 88) | def predictions_match_labels(
function replace_globals (line 96) | def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
function get_predicted_logits (line 100) | def get_predicted_logits(
function get_valid_mask (line 110) | def get_valid_mask(
class TrainMetrics (line 119) | class TrainMetrics(metrics.Collection):
function ogbg_train_step (line 125) | def ogbg_train_step(
function get_fake_graphs (line 157) | def get_fake_graphs(config: ml_collections.ConfigDict) -> jraph.GraphsTu...
function get_apply_fn_and_args (line 184) | def get_apply_fn_and_args(
function test_flax_ogbg_molpcba_trace (line 212) | def test_flax_ogbg_molpcba_trace(state):
function test_flax_ogbg_molpcba_lower (line 220) | def test_flax_ogbg_molpcba_lower(state):
FILE: benchmarks/tracing/ppo.py
function get_fake_batch (line 28) | def get_fake_batch(batch_size: int = 256) -> tuple[jnp.ndarray, ...]:
function get_apply_fn_and_args (line 69) | def get_apply_fn_and_args(
function test_flax_ppo_trace (line 116) | def test_flax_ppo_trace(state):
function test_flax_ppo_lower (line 124) | def test_flax_ppo_lower(state):
FILE: benchmarks/tracing/seq2seq.py
function cross_entropy_loss (line 33) | def cross_entropy_loss(logits, labels, lengths):
function compute_metrics (line 39) | def compute_metrics(logits, labels, eos_id):
function seq2seq_train_step (line 55) | def seq2seq_train_step(state, batch, lstm_rng, eos_id):
function get_fake_batch (line 79) | def get_fake_batch(batch_size: int, ctable: CharacterTable) -> dict[str,...
function get_apply_fn_and_args (line 83) | def get_apply_fn_and_args(
function test_flax_seq2seq_trace (line 114) | def test_flax_seq2seq_trace(state):
function test_flax_seq2seq_lower (line 122) | def test_flax_seq2seq_lower(state):
FILE: benchmarks/tracing/sst2.py
function sigmoid_cross_entropy_with_logits (line 33) | def sigmoid_cross_entropy_with_logits(*, labels: Array, logits: Array) -...
function model_from_config (line 41) | def model_from_config(config: ml_collections.ConfigDict):
function get_initial_params (line 54) | def get_initial_params(rng, model):
function create_train_state (line 61) | def create_train_state(rng, config: ml_collections.ConfigDict, model):
function compute_metrics (line 71) | def compute_metrics(*, labels: Array, logits: Array):
function train_step (line 84) | def train_step(
function get_fake_batch (line 119) | def get_fake_batch(batch_size: int) -> dict[str, Any]:
function get_apply_fn_and_args (line 134) | def get_apply_fn_and_args(
function test_flax_sst2_trace (line 152) | def test_flax_sst2_trace(state):
function test_flax_sst2_lower (line 160) | def test_flax_sst2_lower(state):
FILE: benchmarks/tracing/tracing_benchmark.py
function clear_caches (line 34) | def clear_caches(state):
function benchmark_tracing (line 40) | def benchmark_tracing(
function benchmark_lowering (line 54) | def benchmark_lowering(
function run_single_example (line 70) | def run_single_example(
function run_benchmarks (line 91) | def run_benchmarks() -> None:
FILE: benchmarks/tracing/vae.py
function binary_cross_entropy_with_logits (line 31) | def binary_cross_entropy_with_logits(logits, labels):
function kl_divergence (line 39) | def kl_divergence(mean, logvar):
function train_step (line 43) | def train_step(state, batch, z_rng, latents):
function get_fake_batch (line 57) | def get_fake_batch(batch_size: int) -> Any:
function get_apply_fn_and_args (line 61) | def get_apply_fn_and_args(
function test_flax_vae_trace (line 83) | def test_flax_vae_trace(state):
function test_flax_vae_lower (line 91) | def test_flax_vae_lower(state):
FILE: benchmarks/tracing/wmt.py
class TrainState (line 34) | class TrainState(train_state.TrainState):
function rsqrt_schedule (line 38) | def rsqrt_schedule(init_value: float, shift: int = 0):
function create_learning_rate_schedule (line 44) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
function preferred_dtype (line 58) | def preferred_dtype(config):
function compute_weighted_cross_entropy (line 68) | def compute_weighted_cross_entropy(
function compute_weighted_accuracy (line 98) | def compute_weighted_accuracy(logits, targets, weights=None):
function compute_metrics (line 113) | def compute_metrics(logits, labels, weights, label_smoothing=0.0):
function wmt_train_step (line 126) | def wmt_train_step(
function get_fake_batch (line 202) | def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, Any]:
function get_apply_fn_and_args (line 224) | def get_apply_fn_and_args(
function test_flax_wmt_trace (line 298) | def test_flax_wmt_trace(state):
function test_flax_wmt_lower (line 306) | def test_flax_wmt_lower(state):
FILE: docs/_ext/codediff.py
class CodeDiffParser (line 39) | class CodeDiffParser:
method parse (line 40) | def parse(
method _code_block (line 140) | def _code_block(self, lines):
method _tabs (line 156) | def _tabs(self, *contents: tuple[str, list[str]], sync):
class CodeDiffDirective (line 172) | class CodeDiffDirective(SphinxDirective):
method run (line 182) | def run(self):
function setup (line 214) | def setup(app):
FILE: docs/_ext/codediff_test.py
class CodeDiffTest (line 21) | class CodeDiffTest(parameterized.TestCase):
method test_parse (line 22) | def test_parse(self):
method test_parse_errors (line 114) | def test_parse_errors(self, input_text, title, groups, error_msg):
FILE: docs/_ext/flax_module.py
function render_module (line 36) | def render_module(modname: str, qualname: str, app):
class FlaxModuleDirective (line 59) | class FlaxModuleDirective(SphinxDirective):
method run (line 66) | def run(self):
function setup (line 80) | def setup(app):
FILE: docs/conf_sphinx_patch.py
function generate_autosummary_content (line 32) | def generate_autosummary_content(
FILE: docs_nnx/_ext/codediff.py
class CodeDiffParser (line 39) | class CodeDiffParser:
method parse (line 40) | def parse(
method _code_block (line 140) | def _code_block(self, lines):
method _tabs (line 156) | def _tabs(self, *contents: tuple[str, list[str]], sync):
class CodeDiffDirective (line 172) | class CodeDiffDirective(SphinxDirective):
method run (line 182) | def run(self):
function setup (line 214) | def setup(app):
FILE: docs_nnx/_ext/codediff_test.py
class CodeDiffTest (line 21) | class CodeDiffTest(parameterized.TestCase):
method test_parse (line 22) | def test_parse(self):
method test_parse_errors (line 114) | def test_parse_errors(self, input_text, title, groups, error_msg):
FILE: docs_nnx/_ext/flax_module.py
function render_module (line 36) | def render_module(modname: str, qualname: str, app):
class FlaxModuleDirective (line 59) | class FlaxModuleDirective(SphinxDirective):
method run (line 66) | def run(self):
function setup (line 80) | def setup(app):
FILE: docs_nnx/conf_sphinx_patch.py
function generate_autosummary_content (line 32) | def generate_autosummary_content(
FILE: examples/cloud/launch_gce.py
function generate_startup_file (line 137) | def generate_startup_file(vm_name: str) -> str:
function launch_gce (line 162) | def launch_gce(*, vm_name: str, startup_script: str):
function print_howto (line 204) | def print_howto(login_args: Sequence[str]):
function main (line 232) | def main(_):
FILE: examples/gemma/configs/default.py
class Config (line 23) | class Config:
function get_config (line 132) | def get_config() -> TrainConfig:
FILE: examples/gemma/configs/gemma3_4b.py
class Config (line 23) | class Config:
method replace (line 131) | def replace(self, **kwargs):
function get_config (line 135) | def get_config() -> TrainConfig:
FILE: examples/gemma/configs/small.py
class Config (line 23) | class Config:
method replace (line 152) | def replace(self, **kwargs):
function get_config (line 156) | def get_config() -> TrainConfig:
FILE: examples/gemma/configs/tiny.py
class Config (line 23) | class Config:
method replace (line 143) | def replace(self, **kwargs):
function get_config (line 147) | def get_config() -> TrainConfig:
FILE: examples/gemma/helpers.py
function _flatten_path (line 29) | def _flatten_path(path: tuple[str | int, ...]) -> str:
function module_from_linen_variables (line 41) | def module_from_linen_variables(
FILE: examples/gemma/helpers_test.py
class ModuleFromLinenVariablesTest (line 30) | class ModuleFromLinenVariablesTest(parameterized.TestCase):
method test_same_structure (line 44) | def test_same_structure(self, inputs_shape, num_features, use_bias):
method test_different_structure (line 79) | def test_different_structure(self, inputs_shape, num_features, use_bias):
FILE: examples/gemma/input_pipeline.py
class NormalizeFeatureNamesOp (line 28) | class NormalizeFeatureNamesOp:
method __call__ (line 31) | def __call__(self, features: Features) -> Features:
function get_raw_dataset (line 38) | def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset:
function pack_dataset (line 56) | def pack_dataset(
function _pack_with_tf_ops (line 139) | def _pack_with_tf_ops(
function shift_data_by_truncation (line 260) | def shift_data_by_truncation(x):
function preprocess_data (line 270) | def preprocess_data(
function get_datasets (line 324) | def get_datasets(
FILE: examples/gemma/input_pipeline_test.py
class InputPipelineTest (line 32) | class InputPipelineTest(absltest.TestCase):
method setUp (line 34) | def setUp(self):
method _get_datasets (line 40) | def _get_datasets(self):
method test_train_ds (line 62) | def test_train_ds(self):
method test_eval_ds (line 79) | def test_eval_ds(self):
FILE: examples/gemma/layers.py
class Einsum (line 31) | class Einsum(nnx.Module):
method __init__ (line 34) | def __init__(
method __call__ (line 46) | def __call__(self, x: ArrayLike) -> Array:
method shape (line 50) | def shape(self) -> Shape:
class RMSNorm (line 54) | class RMSNorm(nnx.Module):
method __init__ (line 57) | def __init__(
method __call__ (line 67) | def __call__(self, x: Array) -> Array:
FILE: examples/gemma/layers_test.py
class EinsumTest (line 25) | class EinsumTest(parameterized.TestCase):
method test_einsum (line 40) | def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape):
method test_shape (line 55) | def test_shape(self, shape):
class RMSNormTest (line 60) | class RMSNormTest(parameterized.TestCase):
method test_rmsnorm (line 62) | def test_rmsnorm(self, x, expected):
FILE: examples/gemma/main.py
function main (line 43) | def main(argv):
FILE: examples/gemma/modules.py
class AttentionType (line 39) | class AttentionType(enum.Enum):
class Embedder (line 44) | class Embedder(nnx.Module):
method __init__ (line 47) | def __init__(
method encode (line 60) | def encode(self, x: ArrayLike) -> Array:
method decode (line 65) | def decode(self, x: ArrayLike) -> Array:
method embed_dim (line 69) | def embed_dim(self):
method num_embed (line 73) | def num_embed(self):
class Attention (line 77) | class Attention(nnx.Module):
method __init__ (line 80) | def __init__(
method __call__ (line 173) | def __call__(
method head_dim (line 282) | def head_dim(self):
method num_heads (line 286) | def num_heads(self):
method num_kv_heads (line 294) | def num_kv_heads(self):
method use_qkv_einsum (line 302) | def use_qkv_einsum(self):
method init_cache (line 305) | def init_cache(
class FeedForward (line 324) | class FeedForward(nnx.Module):
method __init__ (line 327) | def __init__(
method __call__ (line 363) | def __call__(self, x: ArrayLike) -> Array:
class Block (line 375) | class Block(nnx.Module):
method __init__ (line 378) | def __init__(
method __call__ (line 507) | def __call__(
method init_cache (line 538) | def init_cache(
function maybe_with_partitioning (line 551) | def maybe_with_partitioning(fn, axis_rules, axis_rules_args=()):
FILE: examples/gemma/modules_test.py
class EmbedderTest (line 27) | class EmbedderTest(parameterized.TestCase):
method test_encode (line 37) | def test_encode(self, vocab_size, embed_dim, inputs, expected):
method test_decode (line 55) | def test_decode(self, vocab_size, embed_dim, inputs, expected):
class AttentionTest (line 66) | class AttentionTest(parameterized.TestCase):
method test_head_dim (line 76) | def test_head_dim(self, head_dim):
method test_use_qkv_einsum (line 101) | def test_use_qkv_einsum(
method test_attention (line 133) | def test_attention(
method test_sliding_window (line 170) | def test_sliding_window(self, sliding_window_size):
class FeedForwardTest (line 211) | class FeedForwardTest(parameterized.TestCase):
method test_ffw (line 222) | def test_ffw(
class BlockTest (line 243) | class BlockTest(parameterized.TestCase):
method test_block (line 258) | def test_block(
method test_post_attention_norm (line 313) | def test_post_attention_norm(
method test_post_ffw_norm (line 386) | def test_post_ffw_norm(
FILE: examples/gemma/params.py
function load_and_format_params (line 28) | def load_and_format_params(path: str) -> Params:
function load_metadata (line 37) | def load_metadata(path: str) -> Any | None:
function load_params (line 45) | def load_params(path: str) -> Params:
function param_remapper (line 52) | def param_remapper(orig_params: Params) -> Params:
function nest_params (line 77) | def nest_params(params: Params) -> Params:
FILE: examples/gemma/positional_embeddings.py
function add_positional_embedding (line 23) | def add_positional_embedding(
function apply_rope (line 45) | def apply_rope(
FILE: examples/gemma/positional_embeddings_test.py
class PositionalEmbeddingsTest (line 29) | class PositionalEmbeddingsTest(parameterized.TestCase):
method test_adds_positional_embeddings (line 40) | def test_adds_positional_embeddings(
method test_rope_positional_embeddings (line 66) | def test_rope_positional_embeddings(
FILE: examples/gemma/sampler.py
function _sample_top_p (line 38) | def _sample_top_p(probs: jnp.ndarray, p: float, key: jax.Array) -> jnp.n...
function _compute_attention_masks (line 53) | def _compute_attention_masks(
class _SamplingState (line 81) | class _SamplingState:
class SamplerOutput (line 125) | class SamplerOutput:
class Sampler (line 141) | class Sampler:
method __init__ (line 144) | def __init__(
method transformer (line 168) | def transformer(self) -> transformer_lib.Transformer:
method transformer_state (line 172) | def transformer_state(self) -> statelib.State:
method transformer_state (line 176) | def transformer_state(self, state: statelib.State) -> statelib.State:
method dtype (line 201) | def dtype(self) -> jnp.dtype:
method _sample_step (line 206) | def _sample_step(
method init_sample_state (line 289) | def init_sample_state(
method tokenize (line 354) | def tokenize(self, input_string: str) -> jax.Array:
method mask_tokens_after_eos_ids (line 362) | def mask_tokens_after_eos_ids(self, token_buffer):
method _sample_fn (line 378) | def _sample_fn(
method __call__ (line 397) | def __call__(
FILE: examples/gemma/sampler_test.py
class MockVocab (line 35) | class MockVocab(spm.SentencePieceProcessor):
method __init__ (line 37) | def __init__(self):
method pad_id (line 58) | def pad_id(self) -> int:
method bos_id (line 61) | def bos_id(self) -> int:
method eos_id (line 64) | def eos_id(self) -> int:
method GetPieceSize (line 67) | def GetPieceSize(self) -> int: # pylint: disable=invalid-name
method DecodeIds (line 70) | def DecodeIds(self, ids: Iterable[int]) -> str: # pylint: disable=inv...
method EncodeAsIds (line 74) | def EncodeAsIds(self, text: str) -> list[int]: # pylint: disable=inva...
class SamplerTest (line 79) | class SamplerTest(parameterized.TestCase):
method assertReasonableTensor (line 81) | def assertReasonableTensor(self, array, expected_shape=None):
method test_samples (line 86) | def test_samples(self):
method test_state_update (line 134) | def test_state_update(self):
method test_invalid_state_update (line 172) | def test_invalid_state_update(self):
method test_forbidden_tokens (line 216) | def test_forbidden_tokens(self):
method test_forward_equivalence (line 264) | def test_forward_equivalence(self):
method test_sampler_init_sample_state (line 324) | def test_sampler_init_sample_state(self):
method test_sampler_mask_tokens_after_eos_ids (line 365) | def test_sampler_mask_tokens_after_eos_ids(self):
method test_sampler_sows_intermediates (line 409) | def test_sampler_sows_intermediates(self):
method test_compute_attention_mask (line 497) | def test_compute_attention_mask(self):
method test_models_from_kaggle (line 530) | def test_models_from_kaggle(self, url):
FILE: examples/gemma/sow_lib.py
class LayerIntermediates (line 25) | class LayerIntermediates:
method merge (line 38) | def merge(self, decoding_step, layer: nnx.Module):
method trim (line 70) | def trim(self, max_length: int):
class TransformerIntermediates (line 80) | class TransformerIntermediates:
method merge (line 89) | def merge(self, decoding_step, transformer: nnx.Module):
method trim (line 109) | def trim(self, max_length: int):
class SowConfig (line 118) | class SowConfig:
method maybe_sow_embeddings (line 139) | def maybe_sow_embeddings(
method maybe_sow_rs_after_attention (line 148) | def maybe_sow_rs_after_attention(
method maybe_sow_rs_after_ffw (line 157) | def maybe_sow_rs_after_ffw(
method maybe_sow_mlp_hidden_topk (line 166) | def maybe_sow_mlp_hidden_topk(
method maybe_sow_attn_logits_topk (line 178) | def maybe_sow_attn_logits_topk(
FILE: examples/gemma/tokenizer.py
function _dump_chars_to_textfile (line 38) | def _dump_chars_to_textfile(
function _train_sentencepiece (line 67) | def _train_sentencepiece(
function _load_sentencepiece_tokenizer (line 145) | def _load_sentencepiece_tokenizer(
function load_or_train_tokenizer (line 160) | def load_or_train_tokenizer(
class TokenizeOp (line 184) | class TokenizeOp:
method __call__ (line 188) | def __call__(self, features: Features) -> Features:
function load_sentencepiece_processor (line 194) | def load_sentencepiece_processor(vocab_path: str):
FILE: examples/gemma/train.py
class MeshRules (line 47) | class MeshRules:
method __call__ (line 53) | def __call__(self, *keys: str) -> tuple[str, ...]:
class TrainConfig (line 60) | class TrainConfig:
method replace (line 152) | def replace(self, **kwargs):
method __post_init__ (line 155) | def __post_init__(self):
function rsqrt_schedule (line 160) | def rsqrt_schedule(
function create_learning_rate_schedule (line 183) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
function compute_weighted_cross_entropy (line 198) | def compute_weighted_cross_entropy(
function compute_weighted_accuracy (line 240) | def compute_weighted_accuracy(logits, targets, weights=None):
function compute_metrics (line 265) | def compute_metrics(logits, labels, weights, label_smoothing=0.0):
function train_step (line 283) | def train_step(
function eval_step (line 341) | def eval_step(
function evaluate (line 368) | def evaluate(
function train_and_evaluate (line 393) | def train_and_evaluate(config: TrainConfig, workdir: str):
FILE: examples/gemma/transformer.py
function make_attention_layers_types (line 37) | def make_attention_layers_types(
class QueryPreAttentionNormalisation (line 50) | class QueryPreAttentionNormalisation(enum.Enum):
class TransformerConfig (line 83) | class TransformerConfig:
method query_pre_attn_scalar (line 111) | def query_pre_attn_scalar(self) -> float:
method from_path (line 122) | def from_path(cls, path: str) -> TransformerConfig:
method from_params (line 129) | def from_params(cls, params: params_lib.Params) -> TransformerConfig:
method from_version_name (line 185) | def from_version_name(cls, name: str, **override) -> TransformerConfig:
method from_dict (line 206) | def from_dict(cls, **config: Any) -> TransformerConfig:
method gemma_2b (line 215) | def gemma_2b(cls, **override) -> TransformerConfig:
method gemma_7b (line 235) | def gemma_7b(cls, **override):
method gemma2_2b (line 255) | def gemma2_2b(cls, **override):
method gemma2_9b (line 282) | def gemma2_9b(cls, **override):
method gemma2_27b (line 307) | def gemma2_27b(cls, **override):
method gemma3_1b (line 332) | def gemma3_1b(cls, **override):
method gemma3_4b (line 361) | def gemma3_4b(cls, **override):
method gemma3_12b (line 391) | def gemma3_12b(cls, **override):
method gemma3_27b (line 421) | def gemma3_27b(cls, **override):
method __post_init__ (line 450) | def __post_init__(self):
function _map_linen_var_names (line 459) | def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
function _assign_linen_params_to_nnx_state (line 480) | def _assign_linen_params_to_nnx_state(
class Transformer (line 497) | class Transformer(nnx.Module):
method from_params (line 501) | def from_params(
method __init__ (line 522) | def __init__(
method __call__ (line 563) | def __call__(
method embed_dim (line 612) | def embed_dim(self) -> int:
method num_embed (line 616) | def num_embed(self) -> int:
method num_layers (line 620) | def num_layers(self) -> int:
method init_cache (line 623) | def init_cache(
method init_intermediates (line 639) | def init_intermediates(
function make_causal_attn_mask (line 693) | def make_causal_attn_mask(
function build_positions_from_mask (line 713) | def build_positions_from_mask(input_mask: Array) -> Array:
FILE: examples/gemma/transformer_test.py
function create_fake_params (line 28) | def create_fake_params(config: transformer_lib.TransformerConfig):
class TransformerTest (line 86) | class TransformerTest(parameterized.TestCase):
method test_transformer (line 133) | def test_transformer(
method test_logit_softcap (line 186) | def test_logit_softcap(
method test_creates_cache (line 274) | def test_creates_cache(self, config, cache_size, keys, k_shape, v_shape):
method test_forward_no_cache (line 306) | def test_forward_no_cache(
method test_attention_types (line 340) | def test_attention_types(
method test_load_from_params (line 414) | def test_load_from_params(self, config):
method test_sow_intermediates (line 432) | def test_sow_intermediates(self, sow_config):
FILE: examples/gemma/utils.py
class TrainState (line 36) | class TrainState(train_state.TrainState):
function create_device_mesh (line 44) | def create_device_mesh(config: Any):
function fill_unspecified_mesh_axes (line 98) | def fill_unspecified_mesh_axes(
function _to_array (line 131) | def _to_array(x):
function setup_initial_state (line 137) | def setup_initial_state(
FILE: examples/imagenet/configs/default.py
function get_config (line 19) | def get_config():
function metrics (line 52) | def metrics():
FILE: examples/imagenet/configs/fake_data_benchmark.py
function get_config (line 22) | def get_config():
FILE: examples/imagenet/configs/tpu.py
function get_config (line 19) | def get_config():
FILE: examples/imagenet/configs/v100_x8.py
function get_config (line 20) | def get_config():
FILE: examples/imagenet/configs/v100_x8_mixed_precision.py
function get_config (line 20) | def get_config():
FILE: examples/imagenet/imagenet_benchmark.py
class ImagenetBenchmark (line 37) | class ImagenetBenchmark(Benchmark):
method _test_8x_v100_half_precision (line 41) | def _test_8x_v100_half_precision(
method test_8x_v100_half_precision_short (line 76) | def test_8x_v100_half_precision_short(self):
method test_8x_v100_half_precision_full (line 88) | def test_8x_v100_half_precision_full(self):
FILE: examples/imagenet/imagenet_fake_data_benchmark.py
class ImagenetBenchmarkFakeData (line 37) | class ImagenetBenchmarkFakeData(Benchmark):
method test_fake_data (line 40) | def test_fake_data(self):
FILE: examples/imagenet/input_pipeline.py
function distorted_bounding_box_crop (line 28) | def distorted_bounding_box_crop(
function _resize (line 80) | def _resize(image, image_size):
function _at_least_x_are_equal (line 86) | def _at_least_x_are_equal(a, b, x):
function _decode_and_random_crop (line 93) | def _decode_and_random_crop(image_bytes, image_size):
function _decode_and_center_crop (line 116) | def _decode_and_center_crop(image_bytes, image_size):
function normalize_image (line 144) | def normalize_image(image):
function preprocess_for_train (line 150) | def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE...
function preprocess_for_eval (line 169) | def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_...
function create_split (line 187) | def create_split(
FILE: examples/imagenet/main.py
function main (line 43) | def main(argv):
FILE: examples/imagenet/models.py
class ResNetBlock (line 30) | class ResNetBlock(nn.Module):
method __call__ (line 40) | def __call__(
class BottleneckResNetBlock (line 60) | class BottleneckResNetBlock(nn.Module):
method __call__ (line 70) | def __call__(self, x):
class ResNet (line 90) | class ResNet(nn.Module):
method __call__ (line 102) | def __call__(self, x, train: bool = True):
FILE: examples/imagenet/models_test.py
class ResNetTest (line 29) | class ResNetTest(parameterized.TestCase):
method test_resnet_model (line 32) | def test_resnet_model(self):
method test_resnet_18_model (line 46) | def test_resnet_18_model(self, model):
FILE: examples/imagenet/train.py
function create_model (line 50) | def create_model(*, model_cls, half_precision, **kwargs):
function initialized (line 62) | def initialized(key, image_size, model):
function cross_entropy_loss (line 73) | def cross_entropy_loss(logits, labels):
function compute_metrics (line 79) | def compute_metrics(logits, labels):
function create_learning_rate_fn (line 90) | def create_learning_rate_fn(
function train_step (line 112) | def train_step(state, batch, learning_rate_fn):
function eval_step (line 174) | def eval_step(state, batch):
function prepare_tf_data (line 180) | def prepare_tf_data(xs):
function create_input_iter (line 195) | def create_input_iter(
class TrainState (line 220) | class TrainState(train_state.TrainState):
function restore_checkpoint (line 225) | def restore_checkpoint(state, workdir):
function save_checkpoint (line 229) | def save_checkpoint(state, workdir):
function create_train_state (line 243) | def create_train_state(
function train_and_evaluate (line 268) | def train_and_evaluate(
FILE: examples/imagenet/train_test.py
class TrainTest (line 37) | class TrainTest(parameterized.TestCase):
method setUp (line 39) | def setUp(self):
method test_create_model (line 44) | def test_create_model(self):
method test_create_model_local (line 53) | def test_create_model_local(self):
method test_train_and_evaluate (line 68) | def test_train_and_evaluate(self, model):
FILE: examples/linen_design_test/attention_simple.py
class Dense (line 27) | class Dense(Module):
method __call__ (line 36) | def __call__(self, inputs):
class SoftmaxAttn (line 55) | class SoftmaxAttn(Module):
method __call__ (line 58) | def __call__(self, weights):
class Dropout (line 63) | class Dropout(Module):
method __call__ (line 67) | def __call__(self, x, deterministic=False, rng=None):
class SoftmaxAttnWDropout (line 81) | class SoftmaxAttnWDropout(Module):
method __call__ (line 86) | def __call__(self, x):
class RawDotProductAttention (line 92) | class RawDotProductAttention(Module):
method __call__ (line 96) | def __call__(self, query, key, value, bias=None, dtype=jnp.float32):
class DotProductAttention (line 115) | class DotProductAttention(Module):
method __call__ (line 121) | def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
function concise_vmap (line 142) | def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs):
class MultiHeadDotProductAttention (line 157) | class MultiHeadDotProductAttention(Module):
method __call__ (line 166) | def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
FILE: examples/linen_design_test/autoencoder.py
class MLP (line 26) | class MLP(Module):
method __call__ (line 30) | def __call__(self, x):
class AutoEncoder (line 38) | class AutoEncoder(Module):
method setup (line 43) | def setup(self):
method __call__ (line 48) | def __call__(self, x):
method encode (line 51) | def encode(self, x):
method decode (line 55) | def decode(self, z):
FILE: examples/linen_design_test/dense.py
class Dense (line 21) | class Dense(Module):
method __call__ (line 28) | def __call__(self, inputs):
FILE: examples/linen_design_test/linear_regression.py
function predict (line 27) | def predict(params):
function loss_fn (line 32) | def loss_fn(params):
function init_params (line 37) | def init_params(rng):
FILE: examples/linen_design_test/mlp_explicit.py
class DenseExplicit (line 26) | class DenseExplicit(Dense):
method setup (line 29) | def setup(self):
class MLP (line 40) | class MLP(Module):
method setup (line 42) | def setup(self):
method __call__ (line 52) | def __call__(self, x):
FILE: examples/linen_design_test/mlp_inline.py
class MLP (line 25) | class MLP(Module):
method __call__ (line 29) | def __call__(self, x):
FILE: examples/linen_design_test/mlp_lazy.py
class MLP (line 25) | class MLP(Module):
method setup (line 27) | def setup(self):
method __call__ (line 35) | def __call__(self, x):
FILE: examples/lm1b/configs/default.py
function get_config (line 20) | def get_config():
FILE: examples/lm1b/input_pipeline.py
class NormalizeFeatureNamesOp (line 31) | class NormalizeFeatureNamesOp:
method __init__ (line 34) | def __init__(self, ds_info: tfds.core.DatasetInfo):
method __call__ (line 37) | def __call__(self, features: Features) -> Features:
function get_raw_dataset (line 44) | def get_raw_dataset(
function pack_dataset (line 69) | def pack_dataset(
function _pack_with_tf_ops (line 152) | def _pack_with_tf_ops(
function preprocess_data (line 276) | def preprocess_data(
function get_datasets (line 321) | def get_datasets(
FILE: examples/lm1b/input_pipeline_test.py
class InputPipelineTest (line 33) | class InputPipelineTest(absltest.TestCase):
method setUp (line 35) | def setUp(self):
method _get_datasets (line 41) | def _get_datasets(self):
method test_train_ds (line 63) | def test_train_ds(self):
method test_eval_ds (line 80) | def test_eval_ds(self):
method test_predict_ds (line 91) | def test_predict_ds(self):
FILE: examples/lm1b/main.py
function main (line 43) | def main(argv):
FILE: examples/lm1b/models.py
class TransformerConfig (line 37) | class TransformerConfig:
function shift_right (line 60) | def shift_right(x, axis=1):
function shift_inputs (line 70) | def shift_inputs(x, segment_ids=None, axis=1):
function sinusoidal_init (line 80) | def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0):
class AddPositionEmbs (line 108) | class AddPositionEmbs(nn.Module):
method __call__ (line 120) | def __call__(self, inputs, inputs_positions=None):
class MlpBlock (line 171) | class MlpBlock(nn.Module):
method __call__ (line 183) | def __call__(self, inputs):
class EncoderDecoder1DBlock (line 213) | class EncoderDecoder1DBlock(nn.Module):
method __call__ (line 223) | def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):
class Decoder (line 281) | class Decoder(nn.Module):
method __call__ (line 293) | def __call__(
class TransformerLM (line 377) | class TransformerLM(nn.Module):
method __call__ (line 387) | def __call__(self, inputs, inputs_positions=None, inputs_segmentation=...
FILE: examples/lm1b/temperature_sampler.py
function temperature_sample (line 27) | def temperature_sample(
FILE: examples/lm1b/temperature_sampler_test.py
class TestTemperatureSampler (line 26) | class TestTemperatureSampler(absltest.TestCase):
method test_temperature_sampler (line 28) | def test_temperature_sampler(self):
FILE: examples/lm1b/tokenizer.py
function _dump_chars_to_textfile (line 35) | def _dump_chars_to_textfile(
function _train_sentencepiece (line 64) | def _train_sentencepiece(
function _load_sentencepiece_tokenizer (line 123) | def _load_sentencepiece_tokenizer(
function load_or_train_tokenizer (line 138) | def load_or_train_tokenizer(
class TokenizeOp (line 162) | class TokenizeOp:
method __call__ (line 166) | def __call__(self, features: Features) -> Features:
FILE: examples/lm1b/train.py
function rsqrt_schedule (line 47) | def rsqrt_schedule(
function create_learning_rate_schedule (line 70) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
function compute_weighted_cross_entropy (line 85) | def compute_weighted_cross_entropy(
function compute_weighted_accuracy (line 127) | def compute_weighted_accuracy(logits, targets, weights=None):
function compute_metrics (line 152) | def compute_metrics(logits, labels, weights, label_smoothing=0.0):
function train_step (line 170) | def train_step(
function eval_step (line 220) | def eval_step(params, batch, config, label_smoothing=0.0):
function predict_step (line 229) | def predict_step(
function pad_examples (line 270) | def pad_examples(x, desired_batch_size):
function tohost (line 276) | def tohost(x):
function evaluate (line 282) | def evaluate(
function generate_prediction (line 308) | def generate_prediction(
function train_and_evaluate (line 362) | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
FILE: examples/lm1b/train_test.py
class TrainTest (line 32) | class TrainTest(absltest.TestCase):
method setUp (line 35) | def setUp(self):
method test_train_and_evaluate (line 41) | def test_train_and_evaluate(self):
FILE: examples/lm1b/utils.py
function create_device_mesh (line 33) | def create_device_mesh(config):
function fill_unspecified_mesh_axes (line 79) | def fill_unspecified_mesh_axes(
function unbox_logicallypartioned_trainstate (line 112) | def unbox_logicallypartioned_trainstate(
function init_train_state (line 130) | def init_train_state(model, tx, config, key):
function setup_initial_state (line 149) | def setup_initial_state(model, tx, config, rng, mesh):
FILE: examples/mnist/configs/default.py
function get_config (line 20) | def get_config():
function metrics (line 31) | def metrics():
FILE: examples/mnist/main.py
function main (line 43) | def main(argv):
FILE: examples/mnist/mnist_benchmark.py
class MnistBenchmark (line 35) | class MnistBenchmark(Benchmark):
method test_cpu (line 39) | def test_cpu(self):
FILE: examples/mnist/train.py
class CNN (line 39) | class CNN(nnx.Module):
method __init__ (line 42) | def __init__(self, rngs: nnx.Rngs):
method __call__ (line 53) | def __call__(self, x, rngs: nnx.Rngs):
function loss_fn (line 63) | def loss_fn(model: CNN, batch, rngs):
function train_step (line 72) | def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiM...
function eval_step (line 83) | def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
function get_datasets (line 88) | def get_datasets(
function train_and_evaluate (line 122) | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) ...
FILE: examples/mnist/train_test.py
class TrainTest (line 36) | class TrainTest(absltest.TestCase):
method setUp (line 39) | def setUp(self):
method test_cnn (line 46) | def test_cnn(self):
method test_train_and_evaluate (line 55) | def test_train_and_evaluate(self):
FILE: examples/nlp_seq/configs/default.py
function get_config (line 20) | def get_config():
FILE: examples/nlp_seq/input_pipeline.py
class CoNLLAttributes (line 35) | class CoNLLAttributes(enum.Enum):
function create_vocabs (line 57) | def create_vocabs(filename, max_num_forms=100000):
function create_token (line 106) | def create_token(token, attributes, vocabs):
function create_sentence_with_root (line 142) | def create_sentence_with_root(attributes, vocabs):
function sentences_from_conll_data (line 163) | def sentences_from_conll_data(
function sentence_dataset_dict (line 199) | def sentence_dataset_dict(
FILE: examples/nlp_seq/input_pipeline_test.py
class InputPipelineTest (line 44) | class InputPipelineTest(absltest.TestCase):
method setUp (line 46) | def setUp(self):
method test_vocab_creation (line 57) | def test_vocab_creation(self):
method testInputBatch (line 74) | def testInputBatch(self):
method testInputTargetBatch (line 103) | def testInputTargetBatch(self):
FILE: examples/nlp_seq/main.py
function main (line 36) | def main(argv):
FILE: examples/nlp_seq/models.py
class TransformerConfig (line 27) | class TransformerConfig:
function sinusoidal_init (line 46) | def sinusoidal_init(max_len=2048):
class AddPositionEmbs (line 73) | class AddPositionEmbs(nn.Module):
method __call__ (line 83) | def __call__(self, inputs):
class MlpBlock (line 116) | class MlpBlock(nn.Module):
method __call__ (line 128) | def __call__(self, inputs, deterministic=True):
class Encoder1DBlock (line 152) | class Encoder1DBlock(nn.Module):
method __call__ (line 162) | def __call__(self, inputs, deterministic):
class Transformer (line 198) | class Transformer(nn.Module):
method __call__ (line 204) | def __call__(self, *, inputs, train):
FILE: examples/nlp_seq/train.py
function create_learning_rate_scheduler (line 82) | def create_learning_rate_scheduler(
function compute_weighted_cross_entropy (line 142) | def compute_weighted_cross_entropy(logits, targets, weights=None):
function compute_weighted_accuracy (line 168) | def compute_weighted_accuracy(logits, targets, weights=None):
function compute_metrics (line 193) | def compute_metrics(logits, labels, weights):
function train_step (line 206) | def train_step(state, batch, model, learning_rate_fn, dropout_rng=None):
function pad_examples (line 239) | def pad_examples(x, desired_batch_size):
function main (line 246) | def main(argv):
FILE: examples/nnx_toy_examples/01_functional_api.py
function dataset (line 27) | def dataset(batch_size):
class Linear (line 33) | class Linear(nnx.Module):
method __init__ (line 34) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 38) | def __call__(self, x):
class Count (line 42) | class Count(nnx.Variable[nnx.A]):
class MLP (line 46) | class MLP(nnx.Module):
method __init__ (line 47) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
method __call__ (line 52) | def __call__(self, x):
function train_step (line 63) | def train_step(params, counts, batch):
function test_step (line 81) | def test_step(params: nnx.State, counts: nnx.State, batch):
FILE: examples/nnx_toy_examples/02_lifted_transforms.py
function dataset (line 28) | def dataset(batch_size):
class Linear (line 34) | class Linear(nnx.Module):
method __init__ (line 35) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 39) | def __call__(self, x):
class Count (line 43) | class Count(nnx.Variable):
class MLP (line 47) | class MLP(nnx.Module):
method __init__ (line 48) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
method __call__ (line 53) | def __call__(self, x):
function train_step (line 67) | def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
function test_step (line 79) | def test_step(model: MLP, batch):
FILE: examples/nnx_toy_examples/03_train_state.py
function dataset (line 29) | def dataset(batch_size):
class Linear (line 35) | class Linear(nnx.Module):
method __init__ (line 36) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 40) | def __call__(self, x):
class Count (line 44) | class Count(nnx.Variable[nnx.A]):
class MLP (line 48) | class MLP(nnx.Module):
method __init__ (line 49) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
method __call__ (line 54) | def __call__(self, x):
class TrainState (line 61) | class TrainState(train_state.TrainState):
function train_step (line 79) | def train_step(state: TrainState, batch):
function test_step (line 97) | def test_step(state: nnx.TrainState[MLP], batch):
FILE: examples/nnx_toy_examples/04_data_parallel_with_jit.py
class MLP (line 36) | class MLP(nnx.Module):
method __init__ (line 37) | def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
method __call__ (line 41) | def __call__(self, x):
function train_step (line 59) | def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
function dataset (line 69) | def dataset(steps, batch_size):
FILE: examples/nnx_toy_examples/05_vae.py
class Loss (line 46) | class Loss(nnx.Variable):
class Encoder (line 51) | class Encoder(nnx.Module):
method __init__ (line 52) | def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 58) | def __call__(self, x: jax.Array) -> jax.Array:
class Decoder (line 76) | class Decoder(nnx.Module):
method __init__ (line 77) | def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 81) | def __call__(self, z: jax.Array) -> jax.Array:
class VAE (line 88) | class VAE(nnx.Module):
method __init__ (line 89) | def __init__(
method __call__ (line 104) | def __call__(self, x: jax.Array) -> jax.Array:
method generate (line 110) | def generate(self, z):
function train_step (line 129) | def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):
function forward (line 147) | def forward(model: VAE, x: jax.Array) -> jax.Array:
function sample (line 153) | def sample(model: VAE, z: jax.Array) -> jax.Array:
FILE: examples/nnx_toy_examples/06_scan_over_layers.py
class Block (line 22) | class Block(nnx.Module):
method __init__ (line 23) | def __init__(self, dim: int, *, rngs: nnx.Rngs):
method __call__ (line 28) | def __call__(self, x: jax.Array):
class ScanMLP (line 32) | class ScanMLP(nnx.Module):
method __init__ (line 39) | def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs):
method __call__ (line 49) | def __call__(self, x: jax.Array) -> jax.Array:
FILE: examples/nnx_toy_examples/07_array_leaves.py
function dataset (line 28) | def dataset(batch_size):
class Linear (line 33) | class Linear(nnx.Module):
method __init__ (line 34) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 38) | def __call__(self, x):
class MLP (line 42) | class MLP(nnx.Module):
method __init__ (line 43) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
method __call__ (line 48) | def __call__(self, x):
function is_param (line 52) | def is_param(path, value):
function train_step (line 62) | def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
function test_step (line 75) | def test_step(model: MLP, batch):
FILE: examples/nnx_toy_examples/08_save_load_checkpoints.py
class MLP (line 24) | class MLP(nnx.Module):
method __init__ (line 25) | def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 29) | def __call__(self, x: jax.Array) -> jax.Array:
function create_model (line 36) | def create_model(seed: int):
function create_and_save (line 40) | def create_and_save(seed: int, path: str):
function load_model (line 48) | def load_model(path: str) -> MLP:
FILE: examples/nnx_toy_examples/09_parameter_surgery.py
function load_pretrained (line 22) | def load_pretrained():
class Classifier (line 27) | class Classifier(nnx.Module):
method __init__ (line 28) | def __init__(self, *, rngs: nnx.Rngs):
method __call__ (line 32) | def __call__(self, x):
FILE: examples/nnx_toy_examples/10_fsdp_and_optimizer.py
function named_sharding (line 34) | def named_sharding(*names: str | None) -> NamedSharding:
class MeshRules (line 39) | class MeshRules:
method __call__ (line 44) | def __call__(self, *keys: str) -> tuple[str, ...]:
class MLP (line 55) | class MLP(nnx.Module):
method __init__ (line 56) | def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
method __call__ (line 70) | def __call__(self, x: jax.Array):
class SGDState (line 74) | class SGDState(nnx.Variable):
class SGD (line 78) | class SGD(nnx.Pytree):
method __init__ (line 79) | def __init__(self, params: nnx.State, lr, decay=0.9):
method update (line 94) | def update(self, grads: nnx.State):
function create_model (line 113) | def create_model():
function train_step (line 128) | def train_step(model: MLP, optimizer: SGD, x, y):
function dataset (line 143) | def dataset(batch_size, num_steps):
FILE: examples/nnx_toy_examples/hijax_basic.py
function dataset (line 28) | def dataset(batch_size):
class Linear (line 34) | class Linear(nnx.Module):
method __init__ (line 35) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 39) | def __call__(self, x):
class Count (line 43) | class Count(nnx.Variable[nnx.A]):
class MLP (line 47) | class MLP(nnx.Module):
method __init__ (line 48) | def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
method __call__ (line 53) | def __call__(self, x):
function train_step (line 64) | def train_step(model, optimizer, x, y):
function test_step (line 75) | def test_step(model: MLP, x, y):
FILE: examples/nnx_toy_examples/hijax_demo.py
function dataset (line 29) | def dataset(batch_size):
class Linear (line 43) | class Linear(nnx.Module):
method __init__ (line 44) | def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
method __call__ (line 52) | def __call__(self, x: jax.Array):
class Block (line 58) | class Block(nnx.Module):
method __init__ (line 59) | def __init__(
method __call__ (line 86) | def __call__(
class Model (line 116) | class Model(nnx.Module):
method __init__ (line 117) | def __init__(
method __call__ (line 148) | def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
class OptState (line 171) | class OptState(nnx.Variable): ...
class SGD (line 179) | class SGD(nnx.Pytree):
method __init__ (line 180) | def __init__(self, params, lr: float, decay: float = 0.9):
method update (line 195) | def update(self, params, grads):
function train_step (line 231) | def train_step(model: Model, optimizer: SGD, rngs: nnx.Rngs, x, y):
function test_step (line 249) | def test_step(model: Model, x, y):
FILE: examples/ogbg_molpcba/configs/default.py
function get_config (line 23) | def get_config():
FILE: examples/ogbg_molpcba/configs/default_graph_net.py
function get_config (line 23) | def get_config():
FILE: examples/ogbg_molpcba/configs/hparam_sweep.py
function get_config (line 20) | def get_config():
function sweep (line 51) | def sweep(add):
FILE: examples/ogbg_molpcba/configs/test.py
function get_config (line 20) | def get_config():
FILE: examples/ogbg_molpcba/input_pipeline.py
class GraphsTupleSize (line 26) | class GraphsTupleSize(NamedTuple):
function get_raw_datasets (line 34) | def get_raw_datasets() -> dict[str, tf.data.Dataset]:
function get_datasets (line 43) | def get_datasets(
function convert_to_graphs_tuple (line 111) | def convert_to_graphs_tuple(
function estimate_padding_budget_for_batch_size (line 168) | def estimate_padding_budget_for_batch_size(
function specs_from_graphs_tuple (line 213) | def specs_from_graphs_tuple(graph: jraph.GraphsTuple):
function get_graphs_tuple_size (line 236) | def get_graphs_tuple_size(graph: jraph.GraphsTuple):
FILE: examples/ogbg_molpcba/input_pipeline_test.py
function get_dummy_datasets (line 24) | def get_dummy_datasets(dataset_length: int):
class InputPipelineTest (line 53) | class InputPipelineTest(parameterized.TestCase):
method setUp (line 55) | def setUp(self):
method test_estimate_padding_budget_valid (line 63) | def test_estimate_padding_budget_valid(self, valid_batch_size):
method test_estimate_padding_budget_invalid (line 72) | def test_estimate_padding_budget_invalid(self, invalid_batch_size):
FILE: examples/ogbg_molpcba/main.py
function main (line 43) | def main(argv):
FILE: examples/ogbg_molpcba/models.py
function add_graphs_tuples (line 24) | def add_graphs_tuples(
class MLP (line 35) | class MLP(nn.Module):
method __call__ (line 44) | def __call__(self, inputs):
class GraphNet (line 55) | class GraphNet(nn.Module):
method __call__ (line 69) | def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
class GraphConvNet (line 137) | class GraphConvNet(nn.Module):
method pool (line 153) | def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
method __call__ (line 173) | def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
FILE: examples/ogbg_molpcba/models_test.py
class ModelsTest (line 26) | class ModelsTest(parameterized.TestCase):
method setUp (line 28) | def setUp(self):
method test_mlp (line 53) | def test_mlp(self, dropout_rate, output_size, num_layers):
method test_graph_net (line 91) | def test_graph_net(
method test_graph_conv_net (line 125) | def test_graph_conv_net(self, latent_size: int, output_globals_size: i...
FILE: examples/ogbg_molpcba/ogbg_molpcba_benchmark.py
class OgbgMolpcbaBenchmark (line 36) | class OgbgMolpcbaBenchmark(Benchmark):
method test_1x_v100 (line 39) | def test_1x_v100(self):
method test_cpu (line 92) | def test_cpu(self):
FILE: examples/ogbg_molpcba/train.py
function create_model (line 44) | def create_model(
function create_optimizer (line 74) | def create_optimizer(
function binary_cross_entropy_with_mask (line 87) | def binary_cross_entropy_with_mask(
function predictions_match_labels (line 106) | def predictions_match_labels(
function add_prefix_to_keys (line 115) | def add_prefix_to_keys(result: dict[str, Any], prefix: str) -> dict[str,...
class MeanAveragePrecision (line 121) | class MeanAveragePrecision(
method compute (line 126) | def compute(self):
class EvalMetrics (line 156) | class EvalMetrics(metrics.Collection):
class TrainMetrics (line 163) | class TrainMetrics(metrics.Collection):
function replace_globals (line 168) | def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
function get_predicted_logits (line 173) | def get_predicted_logits(
function get_valid_mask (line 184) | def get_valid_mask(
function train_step (line 203) | def train_step(
function evaluate_step (line 240) | def evaluate_step(
function evaluate_model (line 266) | def evaluate_model(
function train_and_evaluate (line 292) | def train_and_evaluate(
FILE: examples/ogbg_molpcba/train_test.py
function average_with_mask (line 40) | def average_with_mask(arr: jnp.ndarray, mask: jnp.ndarray):
function get_dummy_raw_datasets (line 46) | def get_dummy_raw_datasets(dataset_length) -> dict[str, tf.data.Dataset]:
function get_dummy_datasets (line 81) | def get_dummy_datasets(
class OgbgMolpcbaTrainTest (line 134) | class OgbgMolpcbaTrainTest(parameterized.TestCase):
method setUp (line 136) | def setUp(self):
method test_binary_cross_entropy_loss (line 161) | def test_binary_cross_entropy_loss(self, probs, labels):
method test_mean_average_precision (line 199) | def test_mean_average_precision(self, logits, labels, expected_result):
method test_eval_metrics (line 226) | def test_eval_metrics(self, loss, logits, labels, mask, expected_resul...
method test_train_metrics (line 251) | def test_train_metrics(self, loss, logits, labels, mask, expected_resu...
method test_train_step (line 263) | def test_train_step(self):
method test_evaluate_step (line 304) | def test_evaluate_step(self):
method test_train_and_evaluate (line 337) | def test_train_and_evaluate(self):
FILE: examples/ppo/agent.py
function policy_action (line 31) | def policy_action(
class RemoteSimulator (line 55) | class RemoteSimulator:
method __init__ (line 61) | def __init__(self, game: str):
function rcv_action_send_exp (line 72) | def rcv_action_send_exp(conn, game: str):
FILE: examples/ppo/configs/default.py
function get_config (line 20) | def get_config():
FILE: examples/ppo/env_utils.py
class ClipRewardEnv (line 25) | class ClipRewardEnv(gym.RewardWrapper):
method __init__ (line 31) | def __init__(self, env):
method reward (line 34) | def reward(self, reward):
class FrameStack (line 39) | class FrameStack:
method __init__ (line 45) | def __init__(
method reset (line 54) | def reset(self):
method step (line 60) | def step(self, action: int):
method _get_array (line 65) | def _get_array(self):
function create_env (line 70) | def create_env(game: str, clip_rewards: bool):
function get_num_actions (line 80) | def get_num_actions(game: str):
FILE: examples/ppo/models.py
class ActorCritic (line 21) | class ActorCritic(nn.Module):
method __call__ (line 27) | def __call__(self, x):
FILE: examples/ppo/ppo_lib.py
function gae_advantages (line 39) | def gae_advantages(
function loss_fn (line 81) | def loss_fn(
function train_step (line 134) | def train_step(
function get_experience (line 180) | def get_experience(
function process_experience (line 217) | def process_experience(
function get_initial_params (line 271) | def get_initial_params(key: jax.Array, model: nn.Module):
function create_train_state (line 278) | def create_train_state(
function train (line 299) | def train(
FILE: examples/ppo/ppo_lib_test.py
class TestGAE (line 35) | class TestGAE(absltest.TestCase):
method test_gae_shape_on_random (line 37) | def test_gae_shape_on_random(self):
method test_gae_hardcoded (line 52) | def test_gae_hardcoded(self):
class TestEnvironmentPreprocessing (line 68) | class TestEnvironmentPreprocessing(absltest.TestCase):
method choose_random_game (line 70) | def choose_random_game(self):
method test_creation (line 82) | def test_creation(self):
method test_step (line 89) | def test_step(self):
class TestModel (line 104) | class TestModel(absltest.TestCase):
method choose_random_outputs (line 106) | def choose_random_outputs(self):
method test_model (line 109) | def test_model(self):
class TestOptimizationStep (line 125) | class TestOptimizationStep(absltest.TestCase):
method generate_random_data (line 127) | def generate_random_data(self, num_actions):
method test_optimization_step (line 137) | def test_optimization_step(self):
FILE: examples/ppo/ppo_main.py
function main (line 48) | def main(argv):
FILE: examples/ppo/seed_rl_atari_preprocessing.py
class AtariPreprocessing (line 39) | class AtariPreprocessing:
method __init__ (line 54) | def __init__(
method observation_space (line 104) | def observation_space(self):
method action_space (line 115) | def action_space(self):
method reward_range (line 119) | def reward_range(self):
method metadata (line 123) | def metadata(self):
method close (line 126) | def close(self):
method apply_random_noops (line 129) | def apply_random_noops(self):
method reset (line 141) | def reset(self):
method render (line 155) | def render(self, mode):
method step (line 169) | def step(self, action):
method _fetch_grayscale_observation (line 214) | def _fetch_grayscale_observation(self, output):
method _pool_and_resize (line 225) | def _pool_and_resize(self):
FILE: examples/ppo/test_episodes.py
function policy_test (line 28) | def policy_test(
FILE: examples/seq2seq/configs/default.py
function get_config (line 20) | def get_config():
FILE: examples/seq2seq/input_pipeline.py
class CharacterTable (line 27) | class CharacterTable:
method __init__ (line 30) | def __init__(self, chars: str, max_len_query_digit: int = 3) -> None:
method pad_id (line 39) | def pad_id(self) -> int:
method eos_id (line 43) | def eos_id(self) -> int:
method vocab_size (line 47) | def vocab_size(self) -> int:
method max_input_len (line 52) | def max_input_len(self) -> int:
method max_output_len (line 59) | def max_output_len(self) -> int:
method encoder_input_shape (line 67) | def encoder_input_shape(self) -> tuple[int, int, int]:
method decoder_input_shape (line 71) | def decoder_input_shape(self) -> tuple[int, int, int]:
method encode (line 74) | def encode(self, inputs: str) -> np.ndarray:
method decode (line 80) | def decode(self, inputs: Array) -> str:
method one_hot (line 89) | def one_hot(self, tokens: np.ndarray) -> np.ndarray:
method encode_onehot (line 94) | def encode_onehot(
method decode_onehot (line 112) | def decode_onehot(self, batch_inputs: Array) -> np.ndarray:
method generate_examples (line 117) | def generate_examples(
method get_batch (line 130) | def get_batch(self, batch_size: int) -> dict[str, np.ndarray]:
function mask_sequences (line 139) | def mask_sequences(sequence_batch: Array, lengths: Array) -> Array:
function get_sequence_lengths (line 146) | def get_sequence_lengths(sequence_batch: Array, eos_id: int) -> Array:
FILE: examples/seq2seq/main.py
function main (line 33) | def main(argv):
FILE: examples/seq2seq/models.py
class DecoderLSTMCell (line 31) | class DecoderLSTMCell(nn.RNNCellBase):
method __call__ (line 44) | def __call__(
method num_feature_axes (line 65) | def num_feature_axes(self) -> int:
class Seq2seq (line 69) | class Seq2seq(nn.Module):
method __call__ (line 88) | def __call__(
method get_seq_lengths (line 131) | def get_seq_lengths(self, inputs: Array) -> Array:
FILE: examples/seq2seq/train.py
function get_model (line 74) | def get_model(ctable: CTable, *, teacher_force: bool = False) -> models....
function get_initial_params (line 83) | def get_initial_params(
function get_train_state (line 96) | def get_train_state(rng: PRNGKey, ctable: CTable) -> train_state.TrainSt...
function cross_entropy_loss (line 107) | def cross_entropy_loss(
function compute_metrics (line 116) | def compute_metrics(
function train_step (line 137) | def train_step(
function log_decode (line 164) | def log_decode(question: str, inferred: str, golden: str):
function decode (line 173) | def decode(
function decode_batch (line 188) | def decode_batch(
function train_and_evaluate (line 206) | def train_and_evaluate(workdir: str) -> train_state.TrainState:
function main (line 226) | def main(_):
FILE: examples/seq2seq/train_test.py
function create_ctable (line 33) | def create_ctable(chars='0123456789+= '):
function create_train_state (line 37) | def create_train_state(ctable):
class TrainTest (line 51) | class TrainTest(absltest.TestCase):
method test_character_table (line 53) | def test_character_table(self):
method test_mask_sequences (line 62) | def test_mask_sequences(self):
method test_get_sequence_lengths (line 70) | def test_get_sequence_lengths(self):
method test_train_one_step (line 87) | def test_train_one_step(self):
method test_decode_batch (line 98) | def test_decode_batch(self):
FILE: examples/sst2/build_vocabulary.py
function get_tokenized_sequences (line 28) | def get_tokenized_sequences(
FILE: examples/sst2/configs/default.py
function get_config (line 20) | def get_config():
FILE: examples/sst2/input_pipeline.py
function get_bucket_boundaries (line 34) | def get_bucket_boundaries(bucket_size: int, max_size: int) -> np.ndarray:
function get_num_examples (line 56) | def get_num_examples(dataset: tf.data.Dataset) -> int:
function get_bucketed_batches (line 61) | def get_bucketed_batches(
function vocab_to_hashtable (line 136) | def vocab_to_hashtable(
function vocab_to_inverse_hashtable (line 148) | def vocab_to_inverse_hashtable(
function _is_text_field (line 163) | def _is_text_field(feature_name_and_type):
function _is_class_label (line 169) | def _is_class_label(feature_name_and_type):
class TextDataset (line 175) | class TextDataset:
method __init__ (line 178) | def __init__(
method padded_shapes (line 208) | def padded_shapes(self):
method example_length_fn (line 213) | def example_length_fn(self, example: Example) -> tf.Tensor:
method add_bos_eos (line 217) | def add_bos_eos(self, sequence: tf.Tensor) -> tf.Tensor:
method prepare_example (line 221) | def prepare_example(self, example: Example) -> Example:
method get_batches (line 232) | def get_batches(
method get_bucketed_batches (line 256) | def get_bucketed_batches(
FILE: examples/sst2/input_pipeline_test.py
class InputPipelineTest (line 27) | class InputPipelineTest(absltest.TestCase):
method setUp (line 29) | def setUp(self):
method _get_vocab_path (line 36) | def _get_vocab_path(self):
method _get_dataset (line 48) | def _get_dataset(self, vocab_path):
method test_bucketed_dataset (line 56) | def test_bucketed_dataset(self):
method test_batched_dataset (line 72) | def test_batched_dataset(self):
method test_batched_dataset_fixed_length (line 85) | def test_batched_dataset_fixed_length(self):
FILE: examples/sst2/main.py
function main (line 43) | def main(argv):
FILE: examples/sst2/models.py
function sequence_mask (line 28) | def sequence_mask(lengths: Array, max_length: int) -> Array:
function flip_sequences (line 50) | def flip_sequences(inputs: Array, lengths: Array) -> Array:
class WordDropout (line 83) | class WordDropout(nn.Module):
method __call__ (line 95) | def __call__(self, inputs: Array, deterministic: bool | None = None):
class Embedder (line 106) | class Embedder(nn.Module):
method setup (line 129) | def setup(self):
method __call__ (line 141) | def __call__(
class SimpleLSTM (line 167) | class SimpleLSTM(nn.Module):
method __call__ (line 180) | def __call__(self, carry, x):
method initialize_carry (line 183) | def initialize_carry(self, input_shape):
class SimpleBiLSTM (line 190) | class SimpleBiLSTM(nn.Module):
method setup (line 195) | def setup(self):
method __call__ (line 199) | def __call__(self, embedded_inputs, lengths):
class MLP (line 219) | class MLP(nn.Module):
method setup (line 238) | def setup(self):
method __call__ (line 243) | def __call__(self, inputs: Array, deterministic: bool | None = None):
class KeysOnlyMlpAttention (line 263) | class KeysOnlyMlpAttention(nn.Module):
method __call__ (line 283) | def __call__(self, keys: Array, mask: Array) -> Array:
class AttentionClassifier (line 309) | class AttentionClassifier(nn.Module):
method setup (line 325) | def setup(self):
method __call__ (line 337) | def __call__(
class TextClassifier (line 376) | class TextClassifier(nn.Module):
method setup (line 389) | def setup(self):
method embed_token_ids (line 404) | def embed_token_ids(
method logits_from_embedded_inputs (line 412) | def logits_from_embedded_inputs(
method __call__ (line 424) | def __call__(
FILE: examples/sst2/models_test.py
class ModelTest (line 28) | class ModelTest(parameterized.TestCase):
method test_embedder_returns_correct_output_shape (line 30) | def test_embedder_returns_correct_output_shape(self):
method test_lstm_returns_correct_output_shape (line 42) | def test_lstm_returns_correct_output_shape(self):
method test_bilstm_returns_correct_output_shape (line 57) | def test_bilstm_returns_correct_output_shape(self):
method test_text_classifier_returns_correct_output_shape (line 73) | def test_text_classifier_returns_correct_output_shape(self):
FILE: examples/sst2/train.py
class Metrics (line 39) | class Metrics(struct.PyTreeNode):
function sigmoid_cross_entropy_with_logits (line 48) | def sigmoid_cross_entropy_with_logits(*, labels: Array, logits: Array) -...
function get_initial_params (line 57) | def get_initial_params(rng, model):
function create_train_state (line 65) | def create_train_state(rng, config: ml_collections.ConfigDict, model):
function compute_metrics (line 76) | def compute_metrics(*, labels: Array, logits: Array) -> Metrics:
function model_from_config (line 90) | def model_from_config(config: ml_collections.ConfigDict):
function train_step (line 104) | def train_step(
function eval_step (line 141) | def eval_step(
function normalize_batch_metrics (line 157) | def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics:
function batch_to_numpy (line 169) | def batch_to_numpy(batch: dict[str, Array]) -> dict[str, Array]:
function evaluate_model (line 176) | def evaluate_model(
function train_epoch (line 204) | def train_epoch(
function train_and_evaluate (line 232) | def train_and_evaluate(
FILE: examples/sst2/train_test.py
class TrainTest (line 32) | class TrainTest(parameterized.TestCase):
method test_train_step_updates_parameters (line 34) | def test_train_step_updates_parameters(self):
FILE: examples/sst2/vocabulary.py
class Vocabulary (line 24) | class Vocabulary:
method __init__ (line 27) | def __init__(
method build (line 57) | def build(
method _getitem__ (line 90) | def _getitem__(self, key: str):
method keys (line 93) | def keys(self):
method values (line 96) | def values(self):
method __len__ (line 99) | def __len__(self):
method pad_idx (line 103) | def pad_idx(self):
method unk_idx (line 108) | def unk_idx(self):
method bos_idx (line 113) | def bos_idx(self):
method eos_idx (line 118) | def eos_idx(self):
method load (line 122) | def load(self, path: str) -> None:
method save (line 132) | def save(self, path: str) -> None:
FILE: examples/vae/configs/default.py
function get_config (line 20) | def get_config():
FILE: examples/vae/input_pipeline.py
function build_train_set (line 23) | def build_train_set(batch_size, ds_builder):
function build_test_set (line 36) | def build_test_set(ds_builder):
function prepare_image (line 45) | def prepare_image(x):
FILE: examples/vae/main.py
function main (line 44) | def main(argv):
FILE: examples/vae/models.py
class Encoder (line 22) | class Encoder(nn.Module):
method __call__ (line 28) | def __call__(self, x):
class Decoder (line 36) | class Decoder(nn.Module):
method __call__ (line 40) | def __call__(self, z):
class VAE (line 47) | class VAE(nn.Module):
method setup (line 52) | def setup(self):
method __call__ (line 56) | def __call__(self, x, z_rng):
method generate (line 62) | def generate(self, z):
function reparameterize (line 66) | def reparameterize(rng, mean, logvar):
function model (line 72) | def model(latents):
FILE: examples/vae/train.py
function kl_divergence (line 33) | def kl_divergence(mean, logvar):
function binary_cross_entropy_with_logits (line 38) | def binary_cross_entropy_with_logits(logits, labels):
function compute_metrics (line 45) | def compute_metrics(recon_x, x, mean, logvar):
function train_step (line 51) | def train_step(state, batch, z_rng, latents):
function eval_f (line 67) | def eval_f(params, images, z, z_rng, latents):
function train_and_evaluate (line 84) | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
FILE: examples/vae/utils.py
function save_image (line 28) | def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format_img...
FILE: examples/wmt/bleu.py
class UnicodeRegex (line 49) | class UnicodeRegex:
method __init__ (line 52) | def __init__(self):
method property_chars (line 58) | def property_chars(self, prefix):
function bleu_tokenize (line 69) | def bleu_tokenize(string):
function _get_ngrams (line 98) | def _get_ngrams(segment, max_order):
function compute_bleu_matches (line 117) | def compute_bleu_matches(reference_corpus, translation_corpus, max_order...
function bleu_partial (line 165) | def bleu_partial(ref_lines, hyp_lines, case_sensitive=False):
function complete_bleu (line 179) | def complete_bleu(
function bleu_local (line 221) | def bleu_local(ref_lines, hyp_lines, case_sensitive=False):
FILE: examples/wmt/configs/default.py
function get_config (line 20) | def get_config():
function metrics (line 119) | def metrics():
FILE: examples/wmt/decode.py
function brevity_penalty (line 33) | def brevity_penalty(alpha, length):
function add_beam_dim (line 49) | def add_beam_dim(x, beam_size):
function flatten_beam_dim (line 59) | def flatten_beam_dim(x):
function unflatten_beam_dim (line 66) | def unflatten_beam_dim(x, batch_size, beam_size):
function flat_batch_beam_expand (line 74) | def flat_batch_beam_expand(x, beam_size):
function gather_beams (line 79) | def gather_beams(nested, beam_indices, batch_size, new_beam_size):
function gather_topk_beams (line 106) | def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_si...
class BeamState (line 129) | class BeamState:
function beam_init (line 147) | def beam_init(batch_size, beam_size, max_decode_len, cache):
function beam_search (line 175) | def beam_search(
FILE: examples/wmt/input_pipeline.py
class NormalizeFeatureNamesOp (line 32) | class NormalizeFeatureNamesOp:
method __init__ (line 35) | def __init__(self, ds_info: tfds.core.DatasetInfo, reverse_translation...
method __call__ (line 40) | def __call__(self, features: Features) -> Features:
function get_raw_dataset (line 46) | def get_raw_dataset(
function pack_dataset (line 79) | def pack_dataset(
function _pack_with_tf_ops (line 162) | def _pack_with_tf_ops(
function preprocess_wmt_data (line 286) | def preprocess_wmt_data(
function get_wmt_datasets (line 331) | def get_wmt_datasets(
FILE: examples/wmt/input_pipeline_test.py
class InputPipelineTest (line 34) | class InputPipelineTest(absltest.TestCase):
method setUp (line 36) | def setUp(self):
method _get_datasets (line 42) | def _get_datasets(self):
method test_train_ds (line 63) | def test_train_ds(self):
method test_eval_ds (line 80) | def test_eval_ds(self):
method test_predict_ds (line 91) | def test_predict_ds(self):
FILE: examples/wmt/main.py
function main (line 44) | def main(argv):
FILE: examples/wmt/models.py
class TransformerConfig (line 34) | class TransformerConfig:
function shift_right (line 57) | def shift_right(x, axis=1):
function sinusoidal_init (line 67) | def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0):
class AddPositionEmbs (line 95) | class AddPositionEmbs(nn.Module):
method __call__ (line 107) | def __call__(self, inputs, inputs_positions=None):
class MlpBlock (line 158) | class MlpBlock(nn.Module):
method __call__ (line 170) | def __call__(self, inputs):
class Encoder1DBlock (line 196) | class Encoder1DBlock(nn.Module):
method __call__ (line 206) | def __call__(self, inputs, encoder_mask=None):
class EncoderDecoder1DBlock (line 245) | class EncoderDecoder1DBlock(nn.Module):
method __call__ (line 255) | def __call__(
class Encoder (line 317) | class Encoder(nn.Module):
method __call__ (line 329) | def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
class Decoder (line 374) | class Decoder(nn.Module):
method __call__ (line 386) | def __call__(
class Transformer (line 463) | class Transformer(nn.Module):
method setup (line 472) | def setup(self):
method encode (line 495) | def encode(self, inputs, inputs_positions=None, inputs_segmentation=No...
method decode (line 526) | def decode(
method __call__ (line 596) | def __call__(
FILE: examples/wmt/tokenizer.py
function _dump_chars_to_textfile (line 35) | def _dump_chars_to_textfile(
function _train_sentencepiece (line 64) | def _train_sentencepiece(
function _load_sentencepiece_tokenizer (line 123) | def _load_sentencepiece_tokenizer(
function load_or_train_tokenizer (line 138) | def load_or_train_tokenizer(
class TokenizeOp (line 162) | class TokenizeOp:
method __call__ (line 166) | def __call__(self, features: Features) -> Features:
FILE: examples/wmt/train.py
class TrainState (line 50) | class TrainState(train_state.TrainState):
function rsqrt_schedule (line 54) | def rsqrt_schedule(
function create_learning_rate_schedule (line 77) | def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
function compute_weighted_cross_entropy (line 92) | def compute_weighted_cross_entropy(
function compute_weighted_accuracy (line 134) | def compute_weighted_accuracy(logits, targets, weights=None):
function compute_metrics (line 159) | def compute_metrics(logits, labels, weights, label_smoothing=0.0):
function train_step (line 178) | def train_step(
function eval_step (line 267) | def eval_step(params, batch, config, label_smoothing=0.0):
function initialize_cache (line 276) | def initialize_cache(inputs, max_decode_len, config):
function predict_step (line 287) | def predict_step(
function pad_examples (line 344) | def pad_examples(x, desired_batch_size):
function per_host_sum_pmap (line 350) | def per_host_sum_pmap(in_tree):
function tohost (line 373) | def tohost(x):
function evaluate (line 379) | def evaluate(
function translate_and_calculate_bleu (line 401) | def translate_and_calculate_bleu(
function preferred_dtype (line 455) | def preferred_dtype(config):
function train_and_evaluate (line 465) | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
FILE: examples/wmt/train_test.py
class TrainTest (line 32) | class TrainTest(absltest.TestCase):
method setUp (line 35) | def setUp(self):
method test_train_and_evaluate (line 41) | def test_train_and_evaluate(self):
FILE: flax/configurations.py
class Config (line 24) | class Config:
method __init__ (line 36) | def __init__(self):
method _add_option (line 39) | def _add_option(self, name, default):
method _read (line 44) | def _read(self, name):
method update (line 51) | def update(self, name: str, value: Any, /) -> None:
method update (line 55) | def update(self, holder: 'FlagHolder[_T]', value: _T, /) -> None:
method update (line 58) | def update(self, name_or_holder, value, /):
method __repr__ (line 73) | def __repr__(self):
method temp_flip_flag (line 78) | def temp_flip_flag(self, var_name: str, var_value: bool):
class FlagHolder (line 98) | class FlagHolder(Generic[_T]):
method __init__ (line 99) | def __init__(self, name, help):
method __bool__ (line 104) | def __bool__(self) -> NoReturn:
method value (line 111) | def value(self) -> _T:
function bool_flag (line 115) | def bool_flag(name: str, *, default: bool, help: str) -> FlagHolder[bool]:
function int_flag (line 147) | def int_flag(name: str, *, default: int | None, help: str) -> FlagHolder...
function static_bool_env (line 179) | def static_bool_env(varname: str, default: bool) -> bool:
function static_int_env (line 206) | def static_int_env(varname: str, default: int | None) -> int | None:
FILE: flax/core/axes_scan.py
class _Broadcast (line 31) | class _Broadcast:
function build_shaped_array (line 38) | def build_shaped_array(x, batch_dim: bool = False) -> core.ShapedArray:
function scan (line 60) | def scan(
FILE: flax/core/frozen_dict.py
class FrozenKeysView (line 27) | class FrozenKeysView(collections.abc.KeysView):
method __repr__ (line 30) | def __repr__(self):
class FrozenValuesView (line 34) | class FrozenValuesView(collections.abc.ValuesView):
method __repr__ (line 37) | def __repr__(self):
function _indent (line 45) | def _indent(x, num_spaces):
class FrozenDict (line 54) | class FrozenDict(Mapping[K, V]):
method __init__ (line 59) | def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # py...
method __getitem__ (line 69) | def __getitem__(self, key):
method __setitem__ (line 75) | def __setitem__(self, key, value):
method __contains__ (line 78) | def __contains__(self, key):
method __iter__ (line 81) | def __iter__(self):
method __len__ (line 84) | def __len__(self):
method __repr__ (line 87) | def __repr__(self):
method __reduce__ (line 90) | def __reduce__(self):
method get (line 93) | def get(self, key, default=None):
method pretty_repr (line 99) | def pretty_repr(self, num_spaces=4):
method __hash__ (line 115) | def __hash__(self):
method copy (line 123) | def copy(
method keys (line 129) | def keys(self):
method values (line 132) | def values(self):
method items (line 135) | def items(self):
method pop (line 139) | def pop(self, key: K) -> tuple['FrozenDict[K, V]', V]:
method unfreeze (line 159) | def unfreeze(self) -> dict[K, V]:
method tree_flatten_with_keys (line 167) | def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]:
method tree_unflatten (line 179) | def tree_unflatten(cls, keys, values):
function _prepare_freeze (line 185) | def _prepare_freeze(xs: Any) -> Any:
function freeze (line 198) | def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
function unfreeze (line 211) | def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]:
function copy (line 237) | def copy(
function pop (line 269) | def pop(
function pretty_repr (line 300) | def pretty_repr(x: Any, num_spaces: int = 4) -> str:
function _frozen_dict_state_dict (line 331) | def _frozen_dict_state_dict(xs):
function _restore_frozen_dict (line 341) | def _restore_frozen_dict(xs, states):
FILE: flax/core/lift.py
class TransformContext (line 60) | class TransformContext(Generic[A], threading.local):
method push (line 66) | def push(self, a: A):
method get (line 73) | def get(self) -> A:
function tree_map_rngs (line 77) | def tree_map_rngs(fn, tree):
function _dedup_scopes (line 87) | def _dedup_scopes(scopes):
function _dup_scopes (line 109) | def _dup_scopes(orig_scopes, scopes, paths):
function _transpose (line 121) | def _transpose(xs):
function _partial_pack (line 125) | def _partial_pack(
function pack (line 281) | def pack(
function map_variables (line 340) | def map_variables(
function swap_collection (line 409) | def swap_collection(fn: Callable[..., Any], col_a: str, col_b: str):
function _split_in_out_axes (line 421) | def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]):
function _bwd_wrapper (line 428) | def _bwd_wrapper(treedef, bwd_fn, tangent):
function vjp (line 434) | def vjp(
function value_and_grad (line 530) | def value_and_grad(
function jvp (line 619) | def jvp(
function vmap (line 710) | def vmap(
function scan (line 872) | def scan(
function while_loop (line 1072) | def while_loop(
function cond (line 1168) | def cond(
function switch (line 1232) | def switch(
function custom_vjp (line 1318) | def custom_vjp(
function checkpoint (line 1425) | def checkpoint(
function _hashable_filter (line 1510) | def _hashable_filter(x):
class CountsHolder (line 1523) | class CountsHolder:
method __init__ (line 1525) | def __init__(self, flat_d):
method make (line 1529) | def make(cls, d):
method sub (line 1534) | def sub(self, other):
method add (line 1542) | def add(self, other):
method unflat (line 1550) | def unflat(self):
function set_from_dict (line 1554) | def set_from_dict(original, updates):
class _SideEffectCache (line 1565) | class _SideEffectCache(threading.local):
method __init__ (line 1567) | def __init__(self):
function _restore_rng_counters (line 1574) | def _restore_rng_counters(scopes, fingerprint, capture_old_counts):
function jit (line 1598) | def jit(
function remat_scan (line 1716) | def remat_scan(
function _unzip2 (line 1793) | def _unzip2(xs):
function _broadcast_prefix_tree (line 1798) | def _broadcast_prefix_tree(prefix_tree: Any, full_tree: Any) -> list[Any]:
function fold_rngs (line 1810) | def fold_rngs(
FILE: flax/core/meta.py
class AxisMetadata (line 39) | class AxisMetadata(Generic[A], metaclass=abc.ABCMeta):
method unbox (line 58) | def unbox(self) -> A:
method replace_boxed (line 75) | def replace_boxed(self, val: B) -> 'AxisMetadata[B]':
method add_axis (line 88) | def add_axis(
method remove_axis (line 110) | def remove_axis(
function is_axis_metadata (line 132) | def is_axis_metadata(val: Any) -> bool:
function map_axis_meta (line 137) | def map_axis_meta(fn: Callable[[AxisMetadata[Any]], Any], tree: Any) -> ...
function add_axis (line 149) | def add_axis(tree: Any, index: int, params: dict[Any, Any]) -> Any:
function remove_axis (line 154) | def remove_axis(tree: Any, index: int, params: dict[Any, Any]) -> Any:
function unbox (line 159) | def unbox(tree: Any) -> Any:
function replace_boxed (line 164) | def replace_boxed(tree: Any, updates: Any) -> Any:
function get_global_mesh (line 181) | def get_global_mesh() -> jax.sharding.AbstractMesh | jax.sharding.Mesh |...
function global_mesh_defined (line 188) | def global_mesh_defined() -> bool:
class Partitioned (line 194) | class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
method unbox (line 256) | def unbox(self, apply_constraint=True) -> A:
method replace_boxed (line 267) | def replace_boxed(self, val: B) -> 'Partitioned[B]':
method _get_partition_name (line 270) | def _get_partition_name(self, params: dict[Any, Any]) -> str:
method add_axis (line 275) | def add_axis(self, index: int, params: dict[Any, Any]) -> 'Partitioned...
method remove_axis (line 283) | def remove_axis(self, index: int, params: dict[Any, Any]) -> 'Partitio...
method get_partition_spec (line 289) | def get_partition_spec(self) -> jax.sharding.PartitionSpec:
method get_sharding (line 293) | def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
method to_nnx_metadata (line 297) | def to_nnx_metadata(self) -> dict[str, Any]:
method from_nnx_metadata (line 304) | def from_nnx_metadata(cls, metadata: dict[str, Any]):
function with_partitioning (line 311) | def with_partitioning(
function _get_leaf_pspec (line 342) | def _get_leaf_pspec(x: Any) -> jax.sharding.PartitionSpec | None:
function get_partition_spec (line 352) | def get_partition_spec(tree: Any) -> Any:
function get_sharding (line 359) | def get_sharding(tree: Any, mesh: jax.sharding.Mesh) -> Any:
FILE: flax/core/nn/attention.py
function dot_product_attention (line 34) | def dot_product_attention(
function _invert_perm (line 156) | def _invert_perm(perm):
class CacheEntry (line 163) | class CacheEntry(struct.PyTreeNode):
function multi_head_dot_product_attention (line 169) | def multi_head_dot_product_attention(
function make_padding_mask (line 420) | def make_padding_mask(
function _make_causal_mask (line 476) | def _make_causal_mask(key, attention_axis=None, self_mask=False):
FILE: flax/core/nn/linear.py
function _normalize_axes (line 30) | def _normalize_axes(axes, ndim):
function dense_general (line 35) | def dense_general(
function dense (line 134) | def dense(
function _conv_dimension_numbers (line 174) | def _conv_dimension_numbers(input_shape):
function conv (line 183) | def conv(
function conv_transpose (line 261) | def conv_transpose(
class Embedding (line 331) | class Embedding:
method lookup (line 334) | def lookup(self, indices):
method attend (line 348) | def attend(self, query):
function embedding (line 364) | def embedding(
FILE: flax/core/nn/normalization.py
function _absolute_dims (line 24) | def _absolute_dims(ndim, dims):
function batch_norm (line 28) | def batch_norm(
function layer_norm (line 86) | def layer_norm(
function group_norm (line 129) | def group_norm(
FILE: flax/core/nn/stochastic.py
function dropout (line 21) | def dropout(scope, inputs, rate, deterministic=False, rng=None):
FILE: flax/core/partial_eval.py
function _maybe_unknown (line 26) | def _maybe_unknown(x: Any) -> pe.PartialVal:
function lazy_init (line 33) | def lazy_init(fn):
FILE: flax/core/scope.py
class DenyList (line 70) | class DenyList:
method __lt__ (line 82) | def __lt__(self, other):
method __gt__ (line 89) | def __gt__(self, other):
class LazyRng (line 101) | class LazyRng(struct.PyTreeNode):
method as_jax_rng (line 107) | def as_jax_rng(self) -> PRNGKey:
method create (line 111) | def create(
method clear_suffix (line 119) | def clear_suffix(self):
function _fold_in_static (line 124) | def _fold_in_static(
function is_filter_empty (line 157) | def is_filter_empty(filter_like: Filter) -> bool:
function in_filter (line 181) | def in_filter(filter_like: Filter, col: str) -> bool:
function filter_to_set (line 207) | def filter_to_set(x: Filter) -> set[str]:
function union_filters (line 226) | def union_filters(a: Filter, b: Filter) -> Filter:
function subtract_filters (line 251) | def subtract_filters(a: Filter, b: Filter) -> Filter:
function intersect_filters (line 276) | def intersect_filters(a: Filter, b: Filter) -> Filter:
function group_collections (line 302) | def group_collections(
class Variable (line 335) | class Variable(Generic[T]):
method __init__ (line 343) | def __init__(self, scope: 'Scope', collection: str, name: str, unbox: ...
method value (line 359) | def value(self) -> T:
method value (line 365) | def value(self, value: T):
method is_mutable (line 379) | def is_mutable(self) -> bool:
class _ChildRNGSentinel (line 384) | class _ChildRNGSentinel:
class _DefaultSentinel (line 392) | class _DefaultSentinel:
function _put_variable (line 402) | def _put_variable(target, key, val):
class Scope (line 414) | class Scope:
method __init__ (line 428) | def __init__(
method __eq__ (line 470) | def __eq__(self, other: Any) -> bool:
method __hash__ (line 484) | def __hash__(self) -> int:
method root (line 489) | def root(self) -> 'Scope':
method path_text (line 493) | def path_text(self) -> str:
method invalid (line 498) | def invalid(self) -> bool:
method _check_valid (line 502) | def _check_valid(self):
method temporary (line 507) | def temporary(self):
method invalidate (line 514) | def invalidate(self):
method mutable_variables (line 518) | def mutable_variables(self) -> VariableDict | dict[str, Any]:
method variables (line 528) | def variables(self) -> VariableDict | dict[str, Any]:
method _validate_trace_level (line 535) | def _validate_trace_level(self):
method rewound (line 538) | def rewound(self, rewind_rngs: bool = False) -> 'Scope':
method name_reserved (line 563) | def name_reserved(self, name: str, col: str | None = None) -> bool:
method reserve (line 581) | def reserve(self, name: str, col: str | None = None):
method default_name (line 598) | def default_name(self, prefix: str) -> str:
method push (line 614) | def push(
method child (line 654) | def child(
method is_mutable_collection (line 695) | def is_mutable_collection(self, col: str) -> bool:
method is_collection_empty (line 699) | def is_collection_empty(self, col: str) -> bool:
method _mutable_collection (line 705) | def _mutable_collection(self, col: str) -> MutableCollection:
method _collection (line 733) | def _collection(self, col: str) -> Collection:
method has_rng (line 746) | def has_rng(self, name: str) -> bool:
method make_rng (line 750) | def make_rng(self, name: str = 'params') -> PRNGKey:
method get_variable (line 762) | def get_variable(self, col: str, name: str, default: Any = None) -> Any:
method has_variable (line 781) | def has_variable(self, col: str, name: str) -> bool:
method put_variable (line 791) | def put_variable(self, col: str, name: str, value: Any):
method variable (line 808) | def variable(
method variable (line 818) | def variable(
method variable (line 830) | def variable(
method variable (line 842) | def variable(
method variable (line 853) | def variable(
method param (line 893) | def param(
method param (line 899) | def param(
method param (line 910) | def param(
method param (line 921) | def param(
method param (line 931) | def param(
method _populate_collections (line 989) | def _populate_collections(self):
method has_flag (line 994) | def has_flag(self, key) -> bool:
method get_flag (line 997) | def get_flag(self, key, default=no_flag) -> Any:
function _unfreeze_variables (line 1003) | def _unfreeze_variables(variables, mutable):
function bind (line 1013) | def bind(
function apply (line 1050) | def apply(
function init (line 1103) | def init(
function lazy_init (line 1137) | def lazy_init(
function _is_valid_collection (line 1173) | def _is_valid_collection(col: VariableDict):
function _is_valid_variables (line 1183) | def _is_valid_variables(variables: VariableDict) -> bool:
function _is_valid_rng (line 1200) | def _is_valid_rng(rng: Array):
function _is_valid_rngs (line 1223) | def _is_valid_rngs(rngs: PRNGKey | RNGSequences):
FILE: flax/core/spmd.py
function get_pspec (line 28) | def get_pspec(sharding, sharding_rules = None) -> PartitionSpec:
function map_sharding (line 32) | def map_sharding(f, sharding):
function get_mesh (line 40) | def get_mesh(sharding):
function apply_rules (line 48) | def apply_rules(sharding, sharding_rules):
function _apply_sharding (line 57) | def _apply_sharding(value, sharding, mesh):
function shard_value (line 69) | def shard_value(value, out_sharding, sharding_rules, mesh):
class _AxisRules (line 104) | class _AxisRules(threading.local):
function set_logical_axis_rules (line 114) | def set_logical_axis_rules(rules: LogicalRules):
function get_logical_axis_rules (line 119) | def get_logical_axis_rules() -> LogicalRules:
function logical_axis_rules (line 125) | def logical_axis_rules(rules: LogicalRules):
function composite_rules (line 135) | def composite_rules(rule1, rule2):
function from_sharding_rules (line 153) | def from_sharding_rules(
FILE: flax/core/tracers.py
function current_trace (line 21) | def current_trace():
function check_trace_level (line 32) | def check_trace_level(base_level):
FILE: flax/cursor.py
class Indexable (line 34) | class Indexable(Protocol):
method __getitem__ (line 35) | def __getitem__(self, key) -> Any:
class AccessType (line 39) | class AccessType(enum.Enum):
class ParentKey (line 45) | class ParentKey(Generic[A]):
function is_named_tuple (line 51) | def is_named_tuple(obj):
function _traverse_tree (line 60) | def _traverse_tree(path, obj, *, update_fn=None, cond_fn=None):
class Cursor (line 120) | class Cursor(Generic[A]):
method __init__ (line 125) | def __init__(self, obj: A, parent_key: ParentKey[A] | None):
method _root (line 133) | def _root(self) -> 'Cursor[A]':
method _path (line 140) | def _path(self) -> str:
method __getitem__ (line 152) | def __getitem__(self, key) -> 'Cursor[A]':
method __getattr__ (line 169) | def __getattr__(self, name) -> 'Cursor[A]':
method __setitem__ (line 182) | def __setitem__(self, key, value):
method __setattr__ (line 187) | def __setattr__(self, name, value):
method set (line 190) | def set(self, value) -> A:
method build (line 224) | def build(self) -> A:
method apply_update (line 284) | def apply_update(
method find (line 382) | def find(self, cond_fn: Callable[[str, Any], bool]) -> 'Cursor[A]':
method find_all (line 473) | def find_all(
method __str__ (line 545) | def __str__(self):
method __repr__ (line 548) | def __repr__(self):
method _pretty_repr (line 551) | def _pretty_repr(self, indent=2, _prefix_indent=0):
method __len__ (line 579) | def __len__(self):
method __iter__ (line 582) | def __iter__(self):
method __reversed__ (line 591) | def __reversed__(self):
method __add__ (line 600) | def __add__(self, other):
method __sub__ (line 603) | def __sub__(self, other):
method __mul__ (line 606) | def __mul__(self, other):
method __matmul__ (line 609) | def __matmul__(self, other):
method __truediv__ (line 612) | def __truediv__(self, other):
method __floordiv__ (line 615) | def __floordiv__(self, other):
method __mod__ (line 618) | def __mod__(self, other):
method __divmod__ (line 621) | def __divmod__(self, other):
method __pow__ (line 624) | def __pow__(self, other):
method __lshift__ (line 627) | def __lshift__(self, other):
method __rshift__ (line 630) | def __rshift__(self, other):
method __and__ (line 633) | def __and__(self, other):
method __xor__ (line 636) | def __xor__(self, other):
method __or__ (line 639) | def __or__(self, other):
method __radd__ (line 642) | def __radd__(self, other):
method __rsub__ (line 645) | def __rsub__(self, other):
method __rmul__ (line 648) | def __rmul__(self, other):
method __rmatmul__ (line 651) | def __rmatmul__(self, other):
method __rtruediv__ (line 654) | def __rtruediv__(self, other):
method __rfloordiv__ (line 657) | def __rfloordiv__(self, other):
method __rmod__ (line 660) | def __rmod__(self, other):
method __rdivmod__ (line 663) | def __rdivmod__(self, other):
method __rpow__ (line 666) | def __rpow__(self, other):
method __rlshift__ (line 669) | def __rlshift__(self, other):
method __rrshift__ (line 672) | def __rrshift__(self, other):
method __rand__ (line 675) | def __rand__(self, other):
method __rxor__ (line 678) | def __rxor__(self, other):
method __ror__ (line 681) | def __ror__(self, other):
method __neg__ (line 684) | def __neg__(self):
method __pos__ (line 687) | def __pos__(self):
method __abs__ (line 690) | def __abs__(self):
method __invert__ (line 693) | def __invert__(self):
method __round__ (line 696) | def __round__(self, ndigits=None):
method __lt__ (line 699) | def __lt__(self, other):
method __le__ (line 702) | def __le__(self, other):
method __eq__ (line 705) | def __eq__(self, other):
method __ne__ (line 708) | def __ne__(self, other):
method __gt__ (line 711) | def __gt__(self, other):
method __ge__ (line 714) | def __ge__(self, other):
function cursor (line 718) | def cursor(obj: A) -> Cursor[A]:
FILE: flax/errors.py
class FlaxError (line 52) | class FlaxError(Exception):
method __init__ (line 53) | def __init__(self, message):
method __reduce__ (line 63) | def __reduce__(self):
class TraceContextError (line 72) | class TraceContextError(FlaxError):
class LazyInitError (line 81) | class LazyInitError(FlaxError):
method __init__ (line 101) | def __init__(self, partial_val):
class InvalidRngError (line 113) | class InvalidRngError(FlaxError):
method __init__ (line 167) | def __init__(self, msg):
class ApplyScopeInvalidVariablesTypeError (line 174) | class ApplyScopeInvalidVariablesTypeError(FlaxError):
method __init__ (line 181) | def __init__(self):
class ApplyScopeInvalidVariablesStructureError (line 189) | class ApplyScopeInvalidVariablesStructureError(FlaxError):
method __init__ (line 196) | def __init__(self, variables):
class ScopeParamNotFoundError (line 205) | class ScopeParamNotFoundError(FlaxError):
method __init__ (line 228) | def __init__(self, param_name, scope_path):
class ScopeCollectionNotFound (line 235) | class ScopeCollectionNotFound(FlaxError):
method __init__ (line 249) | def __init__(self, col_name, var_name, scope_path):
class ScopeParamShapeError (line 256) | class ScopeParamShapeError(FlaxError):
method __init__ (line 284) | def __init__(self, param_name, scope_path, value_shape, init_shape):
class ScopeVariableNotFoundError (line 292) | class ScopeVariableNotFoundError(FlaxError):
method __init__ (line 300) | def __init__(self, name, col, scope_path):
class InvalidFilterError (line 307) | class InvalidFilterError(FlaxError):
method __init__ (line 310) | def __init__(self, filter_like):
class InvalidScopeError (line 314) | class InvalidScopeError(FlaxError):
method __init__ (line 323) | def __init__(self, scope_name):
class ModifyScopeVariableError (line 327) | class ModifyScopeVariableError(FlaxError):
method __init__ (line 347) | def __init__(self, col, variable_name, scope_path):
class ImmutableVariableError (line 354) | class ImmutableVariableError(FlaxError):
method __init__ (line 366) | def __init__(self, message):
class JaxTransformError (line 370) | class JaxTransformError(FlaxError):
method __init__ (line 379) | def __init__(self):
class PartitioningUnspecifiedError (line 388) | class PartitioningUnspecifiedError(FlaxError):
method __init__ (line 395) | def __init__(self, target):
class NameInUseError (line 407) | class NameInUseError(FlaxError):
method __init__ (line 456) | def __init__(self, key_type, value, module_name):
class AssignSubModuleError (line 464) | class AssignSubModuleError(FlaxError):
method __init__ (line 502) | def __init__(self, cls):
class SetAttributeInModuleSetupError (line 509) | class SetAttributeInModuleSetupError(FlaxError):
method __init__ (line 541) | def __init__(self):
class SetAttributeFrozenModuleError (line 545) | class SetAttributeFrozenModuleError(FlaxError):
method __init__ (line 576) | def __init__(self, module_cls, attr_name, attr_val):
class MultipleMethodsCompactError (line 584) | class MultipleMethodsCompactError(FlaxError):
method __init__ (line 598) | def __init__(self):
class ReservedModuleAttributeError (line 602) | class ReservedModuleAttributeError(FlaxError):
method __init__ (line 611) | def __init__(self, annotations):
class ApplyModuleInvalidMethodError (line 617) | class ApplyModuleInvalidMethodError(FlaxError):
method __init__ (line 628) | def __init__(self, method):
class CallCompactUnboundModuleError (line 634) | class CallCompactUnboundModuleError(FlaxError):
method __init__ (line 656) | def __init__(self):
class CallSetupUnboundModuleError (line 660) | class CallSetupUnboundModuleError(FlaxError):
method __init__ (line 689) | def __init__(self):
class CallUnbindOnUnboundModuleError (line 693) | class CallUnbindOnUnboundModuleError(FlaxError):
method __init__ (line 716) | def __init__(self):
class CallShareScopeOnUnboundModuleError (line 719) | class CallShareScopeOnUnboundModuleError(FlaxError):
method __init__ (line 735) | def __init__(self):
class InvalidInstanceModuleError (line 738) | class InvalidInstanceModuleError(FlaxError):
method __init__ (line 756) | def __init__(self):
class IncorrectPostInitOverrideError (line 763) | class IncorrectPostInitOverrideError(FlaxError):
method __init__ (line 784) | def __init__(self):
class DescriptorAttributeError (line 790) | class DescriptorAttributeError(FlaxError):
method __init__ (line 806) | def __init__(self):
class InvalidCheckpointError (line 813) | class InvalidCheckpointError(FlaxError):
method __init__ (line 822) | def __init__(self, path, step):
class MPACheckpointingRequiredError (line 829) | class MPACheckpointingRequiredError(FlaxError):
method __init__ (line 840) | def __init__(self, path, step):
class MPARestoreTargetRequiredError (line 848) | class MPARestoreTargetRequiredError(FlaxError):
method __init__ (line 858) | def __init__(self, path, step, key=None):
class MPARestoreDataCorruptedError (line 870) | class MPARestoreDataCorruptedError(FlaxError):
method __init__ (line 876) | def __init__(self, step, path):
class TransformedMethodReturnValueError (line 889) | class TransformedMethodReturnValueError(FlaxError):
method __init__ (line 892) | def __init__(self, name):
class TransformTargetError (line 898) | class TransformTargetError(FlaxError):
method __init__ (line 923) | def __init__(self, target):
class AlreadyExistsError (line 936) | class AlreadyExistsError(FlaxError):
method __init__ (line 943) | def __init__(self, path):
class CursorFindError (line 952) | class CursorFindError(FlaxError):
method __init__ (line 959) | def __init__(self, cursor=None, cursor2=None):
class TraverseTreeError (line 970) | class TraverseTreeError(FlaxError):
method __init__ (line 984) | def __init__(self, update_fn, cond_fn):
FILE: flax/ids.py
class UUIDManager (line 20) | class UUIDManager:
method __init__ (line 32) | def __init__(self):
method __call__ (line 36) | def __call__(self):
class FlaxId (line 45) | class FlaxId:
method __init__ (line 48) | def __init__(self, rawid):
method __eq__ (line 51) | def __eq__(self, other):
method __hash__ (line 54) | def __hash__(self):
method __repr__ (line 57) | def __repr__(self):
method __deepcopy__ (line 60) | def __deepcopy__(self, memo):
method __copy__ (line 64) | def __copy__(self):
FILE: flax/io.py
class BackendMode (line 33) | class BackendMode(Enum):
function override_mode (line 69) | def override_mode(override: BackendMode):
function set_mode (line 85) | def set_mode(override: BackendMode):
function GFile (line 97) | def GFile(name, mode): # pylint: disable=invalid-name
function listdir (line 109) | def listdir(path):
function isdir (line 118) | def isdir(path):
function copy (line 127) | def copy(src, dst, overwrite=False):
function rename (line 139) | def rename(src, dst, overwrite=False):
function exists (line 150) | def exists(path):
function makedirs (line 159) | def makedirs(path):
function glob (line 168) | def glob(pattern):
function remove (line 179) | def remove(path):
function rmtree (line 189) | def rmtree(path):
function getsize (line 199) | def getsize(path):
FILE: flax/jax_utils.py
function _pmap_device_order (line 30) | def _pmap_device_order():
function replicate (line 34) | def replicate(tree, devices=None):
function unreplicate (line 48) | def unreplicate(tree):
function pmean (line 68) | def pmean(xs, axis_name):
function partial_eval_by_shape (line 73) | def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
function _parse_spec (line 114) | def _parse_spec(spec):
function prefetch_to_device (line 123) | def prefetch_to_device(iterator, size, devices=None):
function _scan_nd (line 168) | def _scan_nd(body_fn, init, xs, n=1, unroll=(1,)):
function _invert_perm (line 190) | def _invert_perm(perm):
function scan_in_dim (line 197) | def scan_in_dim(body_fn, init, xs, axis=(0,), unroll=(1,), keepdims=False):
function pad_shard_unpad (line 256) | def pad_shard_unpad(
FILE: flax/linen/activation.py
class PReLU (line 59) | class PReLU(Module):
method __call__ (line 85) | def __call__(self, inputs: Array) -> Array:
FILE: flax/linen/attention.py
function dot_product_attention_weights (line 47) | def dot_product_attention_weights(
function dot_product_attention (line 166) | def dot_product_attention(
class MultiHeadDotProductAttention (line 298) | class MultiHeadDotProductAttention(Module):
method __call__ (line 409) | def __call__(
method __call__ (line 423) | def __call__(
method __call__ (line 436) | def __call__(
class MultiHeadAttention (line 692) | class MultiHeadAttention(MultiHeadDotProductAttention):
class SelfAttention (line 775) | class SelfAttention(MultiHeadDotProductAttention):
method __call__ (line 787) | def __call__( # type: ignore
function make_attention_mask (line 830) | def make_attention_mask(
function make_causal_mask (line 862) | def make_causal_mask(
function combine_masks (line 890) | def combine_masks(
FILE: flax/linen/batch_apply.py
function ndim_at_least (line 21) | def ndim_at_least(x, num_dims):
function arbitrary_mergeable_leaf (line 26) | def arbitrary_mergeable_leaf(min_num_dims, args, kwargs):
function merge_leading_dims (line 36) | def merge_leading_dims(x, num_dims):
function split_leading_dim (line 45) | def split_leading_dim(x, to_dim):
class BatchApply (line 49) | class BatchApply:
method __init__ (line 86) | def __init__(self, f, num_dims=2):
method __call__ (line 96) | def __call__(self, *args, **kwargs):
FILE: flax/linen/combinators.py
class Sequential (line 23) | class Sequential(Module):
method __post_init__ (line 94) | def __post_init__(self):
method __call__ (line 102) | def __call__(self, *args, **kwargs):
FILE: flax/linen/dtypes.py
function canonicalize_dtype (line 22) | def canonicalize_dtype(
function promote_dtype (line 54) | def promote_dtype(*args, dtype=None, inexact=True) -> list[Any]:
FILE: flax/linen/experimental/layers_with_named_axes.py
class Dense (line 45) | class Dense(nn.Module):
method __call__ (line 76) | def __call__(self, inputs: Array) -> Array:
class Embed (line 120) | class Embed(nn.Module):
method setup (line 147) | def setup(self):
method __call__ (line 156) | def __call__(self, inputs: Array) -> Array:
method attend (line 179) | def attend(self, query: Array) -> Array:
function _canonicalize_axes (line 196) | def _canonicalize_axes(rank: int, axes: Axes) -> Sequence[int]:
function _abs_sq (line 203) | def _abs_sq(x):
function _compute_stats (line 211) | def _compute_stats(x: Array, axes: Axes):
function _normalize (line 234) | def _normalize(
class LayerNorm (line 282) | class LayerNorm(nn.Module):
method __call__ (line 317) | def __call__(self, x):
FILE: flax/linen/fp8_ops.py
class Fp8MetaTyRules (line 43) | class Fp8MetaTyRules:
method physical_element_aval (line 46) | def physical_element_aval(dtype) -> core.ShapedArray:
method replicate_trailing_dims (line 51) | def replicate_trailing_dims(ctx, val, aval):
method logical_sharding (line 56) | def logical_sharding(aval, phys_sharding):
method physical_sharding (line 60) | def physical_sharding(aval, sharding):
method convert_from (line 65) | def convert_from(fp8_meta_dtype, other_dtype) -> bool:
method convert_to (line 69) | def convert_to(other_dtype, fp8_meta_dtype) -> bool:
method add (line 74) | def add(dt, x, y):
method zero (line 80) | def zero(dt):
method tangent_dtype (line 86) | def tangent_dtype(dtype):
method full (line 90) | def full(shape, fill_value, dtype):
method global_sharded_result_handler (line 96) | def global_sharded_result_handler(aval, out_sharding, committed):
class fp8_meta_dtype (line 108) | class fp8_meta_dtype(dtypes.extended): pass
class fp8_meta_dtype_wrapper (line 112) | class fp8_meta_dtype_wrapper(dtypes.ExtendedDType):
method __repr__ (line 117) | def __repr__(self) -> str:
function get_fp8_max (line 125) | def get_fp8_max(fp8_dtype, out_dtype):
function quantize (line 130) | def quantize(x, q_dtype, scale, compute_dtype):
function dequantize (line 139) | def dequantize(x, dq_dtype, scale):
function qdq (line 142) | def qdq(x, q_dtype, scale, compute_dtype):
function compute_scale (line 147) | def compute_scale(amax, scale, fp8_max, margin=0):
function compute_amax_history (line 161) | def compute_amax_history(x, amax_history):
function update_fp8_meta (line 167) | def update_fp8_meta(
function quantize_dequantize_update (line 188) | def quantize_dequantize_update(x, q_dtype, scale, amax_history, compute_...
function _fm32_to_float32 (line 193) | def _fm32_to_float32(value):
function dot_general_transpose_lhs (line 198) | def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
function dot_general_transpose_rhs (line 232) | def dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
function in_qdq (line 243) | def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history):
function in_qdq_fwd (line 250) | def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
function in_qdq_bwd (line 257) | def in_qdq_bwd(compute_dtype, q_dtype, res, g):
function out_qdq (line 267) | def out_qdq(compute_dtype, q_dtype, out, scale, amax_history):
function out_qdq_fwd (line 271) | def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history):
function out_qdq_bwd (line 275) | def out_qdq_bwd(compute_dtype, q_dtype, res, g):
function in_q (line 287) | def in_q(compute_dtype, q_dtype, inp, scale, amax_history):
function in_q_fwd (line 292) | def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
function in_q_bwd (line 297) | def in_q_bwd(compute_dtype, q_dtype, res, _):
function out_dq (line 306) | def out_dq(dq_type, lhs_scale, rhs_scale, out):
function out_dq_fwd (line 314) | def out_dq_fwd(dq_type, lhs_scale, rhs_scale, out):
function out_dq_bwd (line 317) | def out_dq_bwd(dq_type, _, g):
function quantized_dot (line 323) | def quantized_dot(
function quantized_dot_fwd (line 343) | def quantized_dot_fwd(
function quantized_dot_bwd (line 374) | def quantized_dot_bwd(
function fp8_scaled_dot_general (line 450) | def fp8_scaled_dot_general(
function dot_general_with_precision (line 497) | def dot_general_with_precision(
function dot_general_with_precision_jvp (line 512) | def dot_general_with_precision_jvp(
function _parse_dot_inputs (line 529) | def _parse_dot_inputs(*args, **kwargs):
class Fp8DotGeneralBase (line 542) | class Fp8DotGeneralBase(module.Module):
method setup (line 547) | def setup(self) -> None:
class Fp8DotGeneralOp (line 582) | class Fp8DotGeneralOp(Fp8DotGeneralBase):
method __post_init__ (line 583) | def __post_init__(self):
method __call__ (line 592) | def __call__(self, *args, **kwargs):
class Fp8DirectDotGeneralOp (line 614) | class Fp8DirectDotGeneralOp(Fp8DotGeneralBase):
method __call__ (line 615) | def __call__(self, *args, **kwargs):
class NANOOFp8DotGeneralOp (line 637) | class NANOOFp8DotGeneralOp(Fp8DotGeneralOp):
class Fp8Einsum (line 641) | class Fp8Einsum(Fp8DotGeneralBase):
method __call__ (line 643) | def __call__(self, eqn, lhs: jnp.ndarray, rhs: jnp.ndarray,
FILE: flax/linen/initializers.py
function zeros_init (line 43) | def zeros_init() -> Initializer:
function ones_init (line 56) | def ones_init() -> Initializer:
FILE: flax/linen/kw_only_dataclasses.py
class _KwOnlyType (line 70) | class _KwOnlyType:
method __repr__ (line 73) | def __repr__(self):
function field (line 80) | def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs):
function dataclass (line 106) | def dataclass(cls=None, extra_fields=None, **kwargs):
function _process_class (line 129) | def _process_class(cls: type[M], extra_fields=None, **kwargs):
FILE: flax/linen/linear.py
class PromoteDtypeFn (line 44) | class PromoteDtypeFn(Protocol):
method __call__ (line 45) | def __call__(
function _normalize_axes (line 52) | def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
function _canonicalize_tuple (line 57) | def _canonicalize_tuple(x: Sequence[int] | int) -> tuple[int, ...]:
class DenseGeneral (line 64) | class DenseGeneral(Module):
method __call__ (line 118) | def __call__(self, inputs: Array) -> Array:
class Dense (line 214) | class Dense(Module):
method __call__ (line 254) | def __call__(self, inputs: Array) -> Array:
class Einsum (line 298) | class Einsum(Module):
method __call__ (line 343) | def __call__(self, inputs: Array, einsum_str: str | None = None) -> Ar...
method _get_bias_shape (line 402) | def _get_bias_shape(self, einsum_str: str, lhs: Array, rhs: Array):
function _conv_dimension_numbers (line 434) | def _conv_dimension_numbers(input_shape):
function canonicalize_padding (line 443) | def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
class _Conv (line 467) | class _Conv(Module):
method shared_weights (line 529) | def shared_weights(self) -> bool: # type: ignore
method __call__ (line 541) | def __call__(self, inputs: Array) -> Array:
class Conv (line 734) | class Conv(_Conv):
method shared_weights (line 801) | def shared_weights(self) -> bool:
class ConvLocal (line 805) | class ConvLocal(_Conv):
method shared_weights (line 872) | def shared_weights(self) -> bool:
class ConvTranspose (line 876) | class ConvTranspose(Module):
method __call__ (line 955) | def __call__(self, inputs: Array) -> Array:
class Embed (line 1110) | class Embed(Module):
method setup (line 1165) | def setup(self):
method __call__ (line 1173) | def __call__(self, inputs: Array) -> Array:
method attend (line 1196) | def attend(self, query: Array) -> Array:
FILE: flax/linen/module.py
function _get_fn_name (line 86) | def _get_fn_name(fn):
function _indent (line 92) | def _indent(x: str, num_spaces: int):
function _attr_repr (line 100) | def _attr_repr(value: Any):
function _module_repr (line 111) | def _module_repr(module: 'Module', num_spaces: int = 4):
class _CallInfo (line 153) | class _CallInfo:
class _CallInfoContext (line 166) | class _CallInfoContext(threading.local):
method get_call_index (line 170) | def get_call_index(self) -> int:
function _tabulate_context (line 177) | def _tabulate_context():
class _DynamicContext (line 187) | class _DynamicContext(threading.local):
method __init__ (line 193) | def __init__(self):
class _Sentinel (line 205) | class _Sentinel:
method __copy__ (line 206) | def __copy__(self):
method __deepcopy__ (line 209) | def __deepcopy__(self, memo):
method __reduce__ (line 213) | def __reduce__(self):
function _get_unspecified_parent (line 217) | def _get_unspecified_parent():
function _derive_profiling_name (line 229) | def _derive_profiling_name(module, fn):
function enable_named_call (line 236) | def enable_named_call():
function disable_named_call (line 251) | def disable_named_call():
function override_named_call (line 261) | def override_named_call(enable: bool = True):
class InterceptorContext (line 282) | class InterceptorContext:
class ThreadLocalStack (line 297) | class ThreadLocalStack(threading.local):
method __init__ (line 300) | def __init__(self):
method push (line 303) | def push(self, elem: Any) -> None:
method pop (line 306) | def pop(self) -> Any:
method __iter__ (line 309) | def __iter__(self) -> Iterator[Any]:
method __len__ (line 312) | def __len__(self) -> int:
method __repr__ (line 315) | def __repr__(self) -> str:
function intercept_methods (line 327) | def intercept_methods(interceptor: Interceptor):
function run_interceptors (line 395) | def run_interceptors(
function _sorted_items (line 426) | def _sorted_items(x):
function _get_suffix_value_pairs (line 431) | def _get_suffix_value_pairs(
function _map_over_modules_in_tree (line 443) | def _map_over_modules_in_tree(fn, tree_or_leaf):
function _freeze_attr (line 458) | def _freeze_attr(val: Any) -> Any:
function compact (line 477) | def compact(fun: _CallableT) -> _CallableT:
function nowrap (line 505) | def nowrap(fun: _CallableT) -> _CallableT:
function compact_name_scope (line 548) | def compact_name_scope(fun: _CallableT) -> _CallableT:
function _get_local_method_names (line 629) | def _get_local_method_names(
function _get_local_descriptor_names (line 652) | def _get_local_descriptor_names(
function wrap_method_once (line 677) | def wrap_method_once(fun: Callable[..., Any]) -> Callable[..., Any]:
function wrap_descriptor_once (line 707) | def wrap_descriptor_once(descriptor) -> 'DescriptorWrapper':
function _wrap_hash (line 723) | def _wrap_hash(hash_fn: Callable[..., Any]) -> Callable[..., Any]:
function _get_unbound_fn (line 743) | def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., A...
function _map_submodules (line 772) | def _map_submodules(fn: Callable[['Module'], Any], tree):
class SetupState (line 778) | class SetupState(enum.IntEnum):
class _ModuleInternalState (line 788) | class _ModuleInternalState:
method reset (line 805) | def reset(self) -> None:
method export (line 815) | def export(self) -> '_ModuleInternalState':
method reimport (line 829) | def reimport(self, other: '_ModuleInternalState') -> None:
class ParentDescriptor (line 861) | class ParentDescriptor:
method __get__ (line 874) | def __get__(self, obj, objtype=None):
method __set__ (line 881) | def __set__(self, obj, value):
class Descriptor (line 886) | class Descriptor(tpe.Protocol):
method __get__ (line 889) | def __get__(self, obj, objtype=None) -> Any:
method __set__ (line 892) | def __set__(self, obj, value) -> None:
method __delete__ (line 895) | def __delete__(self, obj) -> None:
method __set_name__ (line 898) | def __set_name__(self, owner, name) -> None:
class DescriptorWrapper (line 902) | class DescriptorWrapper:
function create_descriptor_wrapper (line 906) | def create_descriptor_wrapper(descriptor: Descriptor):
function module_field (line 956) | def module_field(*, kw_only: bool = False, default: Any | None = ...) ->...
class ModuleBase (line 975) | class ModuleBase:
class Module (line 983) | class Module(ModuleBase):
method __init__ (line 1025) | def __init__(self, *args, **kwargs):
method __call__ (line 1029) | def __call__(self, *args, **kwargs) -> Any:
method __init_subclass__ (line 1034) | def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None:
method _customized_dataclass_transform (line 1055) | def _customized_dataclass_transform(cls, kw_only: bool):
method _find_compact_name_scope_methods (line 1127) | def _find_compact_name_scope_methods(cls):
method _wrap_module_attributes (line 1138) | def _wrap_module_attributes(cls):
method _call_wrapped_method (line 1170) | def _call_wrapped_method(self, fun, args, kwargs):
method __setattr__ (line 1255) | def __setattr__(self, name: str, val: Any):
method __getattr__ (line 1302) | def __getattr__(self, name: str) -> Any:
method __dir__ (line 1319) | def __dir__(self) -> list[str]:
method __post_init__ (line 1324) | def __post_init__(self) -> None:
method __repr__ (line 1392) | def __repr__(self) -> str:
method setup (line 1395) | def setup(self) -> None:
method _register_submodules (line 1430) | def _register_submodules(self, name, val):
method _try_setup (line 1481) | def _try_setup(self, shallow: bool = False) -> None:
method _validate_setup (line 1515) | def _validate_setup(self) -> None:
method _name_taken (line 1525) | def _name_taken(
method _initialization_allowed (line 1537) | def _initialization_allowed(self):
method path (line 1545) | def path(self):
method clone (line 1577) | def clone(
method copy (line 1653) | def copy(
method variable (line 1677) | def variable(
method variable (line 1687) | def variable(
method variable (line 1699) | def variable(
method variable (line 1711) | def variable(
method variable (line 1722) | def variable(
method param (line 1787) | def param(
method param (line 1793) | def param(
method param (line 1804) | def param(
method param (line 1815) | def param(
method param (line 1824) | def param(
method has_variable (line 1885) | def has_variable(self, col: str, name: str) -> bool:
method is_mutable_collection (line 1902) | def is_mutable_collection(self, col: str) -> bool:
method has_rng (line 1908) | def has_rng(self, name: str) -> bool:
method make_rng (line 1914) | def make_rng(self, name: str = 'params') -> PRNGKey:
method is_initializing (line 1957) | def is_initializing(self) -> bool:
method _module_checks (line 1971) | def _module_checks(self):
method bind (line 1982) | def bind(
method unbind (line 2043) | def unbind(self: M) -> tuple[M, VariableDict]:
method apply (line 2092) | def apply(
method init_with_output (line 2252) | def init_with_output(
method init (line 2316) | def init(
method lazy_init (line 2467) | def lazy_init(
method variables (line 2518) | def variables(self) -> VariableDict:
method get_variable (line 2524) | def get_variable(self, col: str, name: str, default: T | None = None) ...
method put_variable (line 2541) | def put_variable(self, col: str, name: str, value: Any):
method sow (line 2554) | def sow(self, col: str, name: str, value: Any) -> bool:
method sow (line 2558) | def sow(
method sow (line 2568) | def sow(
method perturb (line 2655) | def perturb(
method tabulate (line 2727) | def tabulate(
method module_paths (line 2857) | def module_paths(
function merge_param (line 2923) | def merge_param(name: str, a: T | None, b: T | None) -> T:
function apply (line 2968) | def apply(
function init_with_output (line 3038) | def init_with_output(
function init (line 3109) | def init(
class CompactNameScope (line 3176) | class CompactNameScope(Module):
method __call__ (line 3181) | def __call__(self, *args, **kwargs) -> Any:
method __call__ (line 3191) | def __call__(self, *args, **kwargs) -> Any:
class CompactNameScope (line 3186) | class CompactNameScope:
method __call__ (line 3181) | def __call__(self, *args, **kwargs) -> Any:
method __call__ (line 3191) | def __call__(self, *args, **kwargs) -> Any:
function share_scope (line 3195) | def share_scope(module: Module, other: Module, /):
FILE: flax/linen/normalization.py
function _canonicalize_axes (line 45) | def _canonicalize_axes(rank: int, axes: Axes) -> tuple[int, ...]:
function _abs_sq (line 52) | def _abs_sq(x):
function _compute_stats (line 60) | def _compute_stats(
function _normalize (line 154) | def _normalize(
function _l2_normalize (line 229) | def _l2_normalize(x, axis=None, eps=1e-12):
class BatchNorm (line 247) | class BatchNorm(Module):
method __call__ (line 324) | def __call__(
class LayerNorm (line 424) | class LayerNorm(Module):
method __call__ (line 501) | def __call__(self, x, *, mask: jax.Array | None = None):
class RMSNorm (line 541) | class RMSNorm(Module):
method __call__ (line 601) | def __call__(self, x, *, mask: jax.Array | None = None):
class GroupNorm (line 642) | class GroupNorm(Module):
method __call__ (line 727) | def __call__(self, x, *, mask: jax.Array | None = None):
class InstanceNorm (line 822) | class InstanceNorm(Module):
method __call__ (line 901) | def __call__(self, x, *, mask: jax.Array | None = None):
class SpectralNorm (line 947) | class SpectralNorm(Module):
method __call__ (line 1069) | def __call__(self, *args, update_stats: bool, **kwargs):
method _spectral_normalize (line 1104) | def _spectral_normalize(self, path, vs, update_stats):
class WeightNorm (line 1184) | class WeightNorm(Module):
method __call__ (line 1312) | def __call__(self, *args, **kwargs):
method _l2_normalize (line 1339) | def _l2_normalize(self, path, vs):
FILE: flax/linen/partitioning.py
class AxisMetadata (line 85) | class AxisMetadata:
function _param_with_axes_sow_reduce_fn (line 91) | def _param_with_axes_sow_reduce_fn(x, y):
function param_with_axes (line 123) | def param_with_axes(
class PartitionedVariable (line 174) | class PartitionedVariable(flax.core.scope.Variable):
method __init__ (line 184) | def __init__(
method value (line 208) | def value(self):
method value (line 216) | def value(self, value):
function _core_variable_with_axes (line 223) | def _core_variable_with_axes(
function variable_with_axes (line 245) | def variable_with_axes(
function get_axis_names (line 305) | def get_axis_names(axes_metadata):
function _tree_map_axes (line 338) | def _tree_map_axes(fn, tree):
function _is_mutable (line 346) | def _is_mutable(axis_col: str) -> bool:
function _add_axis_to_metadata (line 369) | def _add_axis_to_metadata(fn, axis_pos, axis_name, axis_col='params_axes'):
function scan_with_axes (line 416) | def scan_with_axes(
function vmap_with_axes (line 472) | def vmap_with_axes(
function core_remat_static (line 526) | def core_remat_static(
function remat (line 583) | def remat(
FILE: flax/linen/pooling.py
function pool (line 22) | def pool(inputs, init, reduce_fn, window_shape, strides, padding):
function avg_pool (line 79) | def avg_pool(
function max_pool (line 110) | def max_pool(inputs, window_shape, strides=None, padding='VALID'):
function min_pool (line 128) | def min_pool(inputs, window_shape, strides=None, padding='VALID'):
FILE: flax/linen/recurrent.py
class RNNCellBase (line 57) | class RNNCellBase(Module):
method initialize_carry (line 61) | def initialize_carry(
method num_feature_axes (line 76) | def num_feature_axes(self) -> int:
class LSTMCell (line 81) | class LSTMCell(RNNCellBase):
method __call__ (line 135) | def __call__(self, carry, inputs):
method initialize_carry (line 176) | def initialize_carry(
method num_feature_axes (line 195) | def num_feature_axes(self) -> int:
class DenseParams (line 199) | class DenseParams(Module):
method __call__ (line 210) | def __call__(self, inputs: Array) -> tuple[Array, Array | None]:
class OptimizedLSTMCell (line 224) | class OptimizedLSTMCell(RNNCellBase):
method __call__ (line 282) | def __call__(
method initialize_carry (line 365) | def initialize_carry(
method num_feature_axes (line 385) | def num_feature_axes(self) -> int:
class SimpleCell (line 389) | class SimpleCell(RNNCellBase):
method __call__ (line 446) | def __call__(self, carry, inputs):
method initialize_carry (line 484) | def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]):
method num_feature_axes (line 499) | def num_feature_axes(self) -> int:
class GRUCell (line 503) | class GRUCell(RNNCellBase):
method __call__ (line 555) | def __call__(self, carry, inputs):
method initialize_carry (line 598) | def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]):
method num_feature_axes (line 613) | def num_feature_axes(self) -> int:
class MGUCell (line 617) | class MGUCell(RNNCellBase):
method __call__ (line 686) | def __call__(self, carry, inputs):
method initialize_carry (line 733) | def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]):
method num_feature_axes (line 748) | def num_feature_axes(self) -> int:
class ConvLSTMCell (line 752) | class ConvLSTMCell(RNNCellBase):
method __call__ (line 814) | def __call__(self, carry, inputs):
method initialize_carry (line 858) | def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]):
method num_feature_axes (line 878) | def num_feature_axes(self) -> int:
class RNN (line 882) | class RNN(Module):
method __call__ (line 1016) | def __call__(
function _select_last_carry (line 1166) | def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A:
function _expand_dims_like (line 1175) | def _expand_dims_like(x, target):
function flip_sequences (line 1180) | def flip_sequences(
function _concatenate (line 1241) | def _concatenate(a: Array, b: Array) -> Array:
class RNNBase (line 1246) | class RNNBase(Protocol):
method __call__ (line 1247) | def __call__(
class Bidirectional (line 1262) | class Bidirectional(Module):
method __call__ (line 1282) | def __call__(
FILE: flax/linen/spmd.py
class _UnassignedAxis (line 52) | class _UnassignedAxis:
method __repr__ (line 55) | def __repr__(self):
method __bool__ (line 58) | def __bool__(self):
function _mesh_assignment_free (line 65) | def _mesh_assignment_free(new_assignment, existing_assignments):
function _logical_to_mesh_axes (line 76) | def _logical_to_mesh_axes(
function logical_to_mesh_axes (line 114) | def logical_to_mesh_axes(
function logical_to_mesh (line 161) | def logical_to_mesh(tree: Any, rules: LogicalRules | None = None) -> Any:
function logical_to_mesh_sharding (line 170) | def logical_to_mesh_sharding(
class RulesFallback (line 183) | class RulesFallback(enum.Enum):
function _with_sharding_constraint (line 191) | def _with_sharding_constraint(
function _with_sharding_constraint_one_fallback (line 206) | def _with_sharding_constraint_one_fallback(
function _is_axis_spec (line 231) | def _is_axis_spec(x):
function _is_logical_spec (line 239) | def _is_logical_spec(x):
function with_logical_constraint (line 245) | def with_logical_constraint(
class LogicallyPartitioned (line 276) | class LogicallyPartitioned(meta.Partitioned):
method __eq__ (line 283) | def __eq__(self, other):
method unbox (line 290) | def unbox(self, apply_constraint=True) -> Any:
method to_nnx_metadata (line 302) | def to_nnx_metadata(self) -> dict[str, Any]:
method from_nnx_metadata (line 312) | def from_nnx_metadata(cls, metadata: dict[str, Any]):
function with_logical_partitioning (line 320) | def with_logical_partitioning(
FILE: flax/linen/stochastic.py
class Dropout (line 26) | class Dropout(Module):
method __call__ (line 69) | def __call__(
FILE: flax/linen/summary.py
class _ValueRepresentation (line 52) | class _ValueRepresentation(ABC):
method render (line 56) | def render(self) -> str:
class _ArrayRepresentation (line 61) | class _ArrayRepresentation(_ValueRepresentation):
method from_array (line 66) | def from_array(cls, x: Array) -> '_ArrayRepresentation':
method render_array (line 70) | def render_array(cls, x) -> str:
method render (line 73) | def render(self):
class _PartitionedArrayRepresentation (line 79) | class _PartitionedArrayRepresentation(_ValueRepresentation):
method from_partitioned (line 84) | def from_partitioned(
method render (line 91) | def render(self):
class _ObjectRepresentation (line 96) | class _ObjectRepresentation(_ValueRepresentation):
method render (line 99) | def render(self):
class Row (line 104) | class Row:
method __post_init__ (line 134) | def __post_init__(self):
method size_and_bytes (line 140) | def size_and_bytes(
class Table (line 153) | class Table(list[Row]):
method __init__ (line 163) | def __init__(
function tabulate (line 174) | def tabulate(
function _get_flops (line 327) | def _get_flops(fn, *args, **kwargs):
function _get_call_flops (line 336) | def _get_call_flops(
function _get_module_table (line 425) | def _get_module_table(
function _get_module_variables (line 491) | def _get_module_variables(
function _get_path_variables (line 521) | def _get_path_variables(
function _process_inputs (line 543) | def _process_inputs(args, kwargs) -> Any:
function _render_table (line 559) | def _render_table(
function _summary_tree_map (line 659) | def _summary_tree_map(f, tree, *rest):
function _size_and_bytes_repr (line 663) | def _size_and_bytes_repr(size: int, num_bytes: int) -> str:
function _size_and_bytes (line 670) | def _size_and_bytes(pytree: Any) -> tuple[int, int]:
function _get_rich_repr (line 679) | def _get_rich_repr(obj, console_kwargs):
function _as_yaml_str (line 686) | def _as_yaml_str(value) -> str:
function _normalize_structure (line 702) | def _normalize_structure(obj):
function _bytes_repr (line 723) | def _bytes_repr(num_bytes):
function _get_value_representation (line 737) | def _get_value_representation(x: Any) -> _ValueRepresentation:
function _from_value_representation (line 750) | def _from_value_representation(x: _ValueRepresentation) -> Any:
function _represent_tree (line 765) | def _represent_tree(x):
function _maybe_render (line 775) | def _maybe_render(x):
FILE: flax/linen/transforms.py
function clean_clone (line 70) | def clean_clone(x):
class VariablePlaceholder (line 81) | class VariablePlaceholder:
class InstancePlaceholder (line 91) | class InstancePlaceholder:
function _memoize_by_id (line 99) | def _memoize_by_id(fn, refs):
function get_module_scopes (line 118) | def get_module_scopes(module, args=None, kwargs=None):
function set_module_scopes (line 195) | def set_module_scopes(module, args, kwargs, scopes):
function _test_transformed_return_values (line 285) | def _test_transformed_return_values(tree, method_name):
function module_class_lift_transform (line 299) | def module_class_lift_transform(
function decorator_lift_transform (line 385) | def decorator_lift_transform(
class _HashableProxy (line 439) | class _HashableProxy:
method from_module (line 450) | def from_module(cls, module: Module) -> '_HashableProxy':
method __hash__ (line 455) | def __hash__(self):
method __eq__ (line 458) | def __eq__(self, other):
method module (line 462) | def module(self):
function _module_fingerprint (line 466) | def _module_fingerprint(module: Module) -> tuple[type[Any], Any]:
function _fingerprint_recursive (line 470) | def _fingerprint_recursive(
function _check_field_is_hashable (line 551) | def _check_field_is_hashable(path: tuple[str, ...], x: Any):
function decorator_lift_transform_cached (line 560) | def decorator_lift_transform_cached(transform, class_fn, **trafo_kwargs):
function fork_rngs (line 643) | def fork_rngs(module: Module):
function module_class_lift_transform_cached (line 660) | def module_class_lift_transform_cached(
function _is_module_class (line 762) | def _is_module_class(target: TransformTarget) -> bool:
function lift_transform (line 771) | def lift_transform(
function lift_transform_cached (line 789) | def lift_transform_cached(
function lift_direct_transform (line 807) | def lift_direct_transform(
function vmap (line 834) | def vmap(
function jit (line 927) | def jit(
function checkpoint (line 997) | def checkpoint(
function remat_scan (line 1087) | def remat_scan(
function scan (line 1153) | def scan(
function map_variables (line 1339) | def map_variables(
function vjp (line 1417) | def vjp(
function value_and_grad (line 1502) | def value_and_grad(
function grad (line 1591) | def grad(
function jvp (line 1668) | def jvp(
function while_loop (line 1763) | def while_loop(
function _cond_wrapper (line 1833) | def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs):
function cond (line 1839) | def cond(
function _switch_wrapper (line 1901) | def _switch_wrapper(*args, variables, rngs, n_branches):
function switch (line 1911) | def switch(
function _custom_vjp_single_scope_fn (line 2002) | def _custom_vjp_single_scope_fn(
function custom_vjp (line 2015) | def custom_vjp(
function named_call (line 2099) | def named_call(class_fn, force=True):
function add_metadata_axis (line 2124) | def add_metadata_axis(
function fold_rngs (line 2166) | def fold_rngs(
FILE: flax/metrics/tensorboard.py
function _flatten_dict (line 29) | def _flatten_dict(input_dict, parent_key='', sep='.'):
function _as_default (line 68) | def _as_default(summary_writer: tf.summary.SummaryWriter, auto_flush: bo...
class SummaryWriter (line 82) | class SummaryWriter:
method __init__ (line 85) | def __init__(self, log_dir, auto_flush=True):
method close (line 102) | def close(self):
method flush (line 109) | def flush(self):
method scalar (line 112) | def scalar(self, tag, value, step):
method image (line 124) | def image(self, tag, image, step, max_outputs=3):
method audio (line 156) | def audio(self, tag, audiodata, step, sample_rate=44100, max_outputs=3):
method histogram (line 186) | def histogram(self, tag, values, step, bins=None):
method text (line 200) | def text(self, tag, textdata, step):
method write (line 214) | def write(self, tag, tensor, step, metadata=None):
method hparams (line 229) | def hparams(self, hparams):
FILE: flax/nnx/__init__.py
function __getattr__ (line 230) | def __getattr__(name):
FILE: flax/nnx/bridge/interop.py
function nnx_in_bridge_mdl (line 26) | def nnx_in_bridge_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Mod...
function linen_in_bridge_mdl (line 69) | def linen_in_bridge_mdl(linen_module: nn_module.Module,
FILE: flax/nnx/bridge/module.py
class ModuleStackEntry (line 45) | class ModuleStackEntry:
class ModuleContext (line 54) | class ModuleContext(threading.local):
class ModuleState (line 63) | class ModuleState(statelib.State):
class Scope (line 70) | class Scope(Pytree):
method __init__ (line 71) | def __init__(self, rngs: rnglib.Rngs, mutable: CollectionFilter):
method copy (line 75) | def copy(self):
class _HasSetup (line 79) | class _HasSetup(tp.Protocol):
method setup (line 80) | def setup(self) -> None: ...
function has_setup (line 83) | def has_setup(x: tp.Any) -> tp.TypeGuard[_HasSetup]:
function _maybe_call_setup (line 87) | def _maybe_call_setup(module: Module):
function _bind_module (line 104) | def _bind_module(parent: Module, module: Module) -> Module:
function current_context (line 115) | def current_context() -> ModuleStackEntry | None:
function current_module (line 119) | def current_module() -> Module | None:
function _auto_submodule_name (line 127) | def _auto_submodule_name(parent_ctx, cls):
class ModuleMeta (line 134) | class ModuleMeta(nnx_module.ModuleMeta):
method _pytree_meta_construct (line 136) | def _pytree_meta_construct(cls, self, *args, **kwargs):
function _module_meta_call (line 141) | def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
class AttrPriority (line 184) | class AttrPriority(enum.IntEnum):
class PriorityStr (line 191) | class PriorityStr(str):
method __new__ (line 194) | def __new__(cls, priority: AttrPriority, value: str):
method _check_and_get_priority (line 199) | def _check_and_get_priority(self, other) -> AttrPriority:
method __lt__ (line 208) | def __lt__(self, other) -> bool:
method __gt__ (line 214) | def __gt__(self, other) -> bool:
class ModuleBase (line 220) | class ModuleBase:
class Module (line 227) | class Module(nnx_module.Module, ModuleBase, metaclass=ModuleMeta):
method __init_subclass__ (line 228) | def __init_subclass__(cls) -> None:
method __getattribute__ (line 234) | def __getattribute__(self, name: str):
method _getattr (line 237) | def _getattr(self, name: str) -> tp.Any:
method _setattr (line 243) | def _setattr(self, name: str, value: tp.Any) -> None:
method _graph_node_flatten (line 255) | def _graph_node_flatten(self):
method set_attr_priority (line 264) | def set_attr_priority(self, name: str, value: AttrPriority):
method make_rng (line 267) | def make_rng(self, name: str = 'default') -> jax.Array:
method param (line 272) | def param( # type: ignore[invalid-annotation]
method variable (line 322) | def variable( # type: ignore[invalid-annotation]
method _get_variables (line 376) | def _get_variables(self) -> tp.Mapping:
method variables (line 411) | def variables(self):
method apply (line 415) | def apply(
method init (line 502) | def init(
method init_with_output (line 520) | def init_with_output(
method is_initializing (line 539) | def is_initializing(self) -> bool:
function compact (line 543) | def compact(f: F) -> F:
function _get_unbound_fn (line 561) | def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
FILE: flax/nnx/bridge/variables.py
function sort_variable_types (line 31) | def sort_variable_types(types: tp.Iterable[type]):
class NNXMeta (line 43) | class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
method unbox (line 50) | def unbox(self) -> A:
method replace_boxed (line 53) | def replace_boxed(self, val: B) -> 'NNXMeta[B]':
method add_axis (line 56) | def add_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
method remove_axis (line 60) | def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[...
method get_partition_spec (line 64) | def get_partition_spec(self) -> jax.sharding.PartitionSpec:
method to_nnx_variable (line 71) | def to_nnx_variable(self) -> variablelib.Variable:
function is_vanilla_variable (line 75) | def is_vanilla_variable(vs: variablelib.Variable) -> bool:
function to_linen_var (line 89) | def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata:
function get_col_name (line 101) | def get_col_name(keypath: tp.Sequence[Any]) -> str:
function to_nnx_var (line 108) | def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variablelib.Vari...
function _recursive_merge (line 123) | def _recursive_merge(dict1, dict2):
function linen_vars_to_nnx_attrs (line 130) | def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str...
function nnx_attrs_to_linen_vars (line 152) | def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
function with_partitioning (line 167) | def with_partitioning(
FILE: flax/nnx/bridge/wrappers.py
class Functional (line 41) | class Functional(tp.Generic[M]):
method init (line 47) | def init(self, *, rngs: tp.Optional[Rngs] = None) -> State:
method apply (line 56) | def apply(self, *states: tp.Any):
function functional (line 61) | def functional(cls: tp.Type[M]) -> tp.Callable[..., Functional[M]]:
function _set_initializing (line 68) | def _set_initializing(module: Module, initializing: bool):
function lazy_init (line 74) | def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
function current_linen_module (line 92) | def current_linen_module() -> linen.Module | None:
class ToNNX (line 98) | class ToNNX(Module):
method __init__ (line 128) | def __init__(
method rngs (line 144) | def rngs(self) -> Rngs | None:
method module (line 152) | def module(self) -> linen.Module:
method _setattr (line 159) | def _setattr(self, name, value):
method lazy_init (line 164) | def lazy_init(self, *args, **kwargs):
method __getattr__ (line 168) | def __getattr__(self, name: str):
method __call__ (line 178) | def __call__(
function linen_rngs_dict (line 248) | def linen_rngs_dict(linen_module: linen.Module, add_default: bool = False):
function _get_module_method (line 259) | def _get_module_method(module, method: tp.Callable[..., Any] | str | None):
class ToLinen (line 281) | class ToLinen(linen.Module):
method __call__ (line 327) | def __call__(
method __getattr__ (line 384) | def __getattr__(self, name: str):
method _update_variables (line 396) | def _update_variables(self, module):
class _Missing (line 435) | class _Missing:
function to_linen (line 442) | def to_linen(
function to_linen_class (line 463) | def to_linen_class(
FILE: flax/nnx/extract.py
class PrefixMapping (line 36) | class PrefixMapping(abc.ABC):
method map_prefix (line 38) | def map_prefix(
function check_consistent_aliasing (line 45) | def check_consistent_aliasing(
function check_consistent_aliasing2 (line 110) | def check_consistent_aliasing2(
function broadcast_prefix (line 159) | def broadcast_prefix(
function broadcast_prefix2 (line 183) | def broadcast_prefix2(
function broadcast_prefix_map (line 198) | def broadcast_prefix_map(
class GraphDefState (line 211) | class GraphDefState(struct.PyTreeNode):
class NodeStates (line 219) | class NodeStates(struct.PyTreeNode):
method graphdef (line 225) | def graphdef(self) -> graphlib.GraphDef[tp.Any]:
method state (line 231) | def state(self) -> tp.Any:
method from_split (line 239) | def from_split(
method from_states (line 250) | def from_states(
method from_prefixes (line 258) | def from_prefixes(
function default_split_fn (line 268) | def default_split_fn(
function to_tree (line 274) | def to_tree(
function to_tree2 (line 336) | def to_tree2(
function from_tree2 (line 398) | def from_tree2(tree: tp.Any, /) -> tp.Any:
function merge_tree_node (line 420) | def merge_tree_node(
function is_tree_node (line 428) | def is_tree_node(x):
function from_tree (line 432) | def from_tree(
function clear_non_graph_nodes (line 485) | def clear_non_graph_nodes(tree):
class Mask (line 495) | class Mask(tp.NamedTuple):
function mask_at (line 498) | def mask_at(t: tuple, index: int | None) -> tuple:
function replace_at (line 506) | def replace_at(t: tuple, index: int, value: tp.Any) -> tuple:
function updates_and_snapshot (line 512) | def updates_and_snapshot(args: A) -> tuple[A, A]:
function check_no_aliases (line 529) | def check_no_aliases(fn_name: str, /, **kwargs):
function check_prefix (line 562) | def check_prefix(prefix: tp.Any, prefix_name: str, fn_name: str):
function variable_changed (line 578) | def variable_changed(post: variablelib.Variable, pre: variablelib.Variab...
function mask_variable_updates (line 589) | def mask_variable_updates(
function apply_variable_updates (line 617) | def apply_variable_updates(args_tree: A, updates_tree: A):
function treemap_copy_args (line 628) | def treemap_copy_args(f):
function check_same_variables (line 636) | def check_same_variables(inputs, outputs, transform_name: str = ''):
function update_carry_variables (line 650) | def update_carry_variables(init_val, val_out):
FILE: flax/nnx/filterlib.py
function to_predicate (line 32) | def to_predicate(filter: Filter) -> Predicate:
function filters_to_predicates (line 57) | def filters_to_predicates(
class HasTag (line 71) | class HasTag(tp.Protocol):
function _has_tag (line 75) | def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]:
class WithTag (line 80) | class WithTag:
method __call__ (line 83) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 86) | def __repr__(self):
class PathContains (line 91) | class PathContains:
method __call__ (line 95) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 100) | def __repr__(self):
class PathIn (line 104) | class PathIn:
method __init__ (line 105) | def __init__(self, *paths: PathParts):
method __call__ (line 108) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 111) | def __repr__(self):
method __eq__ (line 115) | def __eq__(self, other):
method __hash__ (line 118) | def __hash__(self):
class OfType (line 123) | class OfType:
method __call__ (line 126) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 129) | def __repr__(self):
class Any (line 133) | class Any:
method __init__ (line 134) | def __init__(self, *filters: Filter):
method __call__ (line 139) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 142) | def __repr__(self):
method __eq__ (line 145) | def __eq__(self, other):
method __hash__ (line 148) | def __hash__(self):
class All (line 152) | class All:
method __init__ (line 153) | def __init__(self, *filters: Filter):
method __call__ (line 158) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 161) | def __repr__(self):
method __eq__ (line 164) | def __eq__(self, other):
method __hash__ (line 167) | def __hash__(self):
class Not (line 171) | class Not:
method __init__ (line 172) | def __init__(self, collection_filter: Filter, /):
method __call__ (line 175) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 178) | def __repr__(self):
method __eq__ (line 181) | def __eq__(self, other):
method __hash__ (line 184) | def __hash__(self):
class Everything (line 188) | class Everything:
method __call__ (line 189) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 192) | def __repr__(self):
method __eq__ (line 195) | def __eq__(self, other):
method __hash__ (line 198) | def __hash__(self):
class Nothing (line 202) | class Nothing:
method __call__ (line 203) | def __call__(self, path: PathParts, x: tp.Any):
method __repr__ (line 206) | def __repr__(self):
method __eq__ (line 209) | def __eq__(self, other):
method __hash__ (line 212) | def __hash__(self):
FILE: flax/nnx/graphlib.py
function _tree_mode_suggestion_api (line 53) | def _tree_mode_suggestion_api(fn_name: str) -> str:
function _tree_mode_suggestion_transform (line 63) | def _tree_mode_suggestion_transform(fn_name: str) -> str:
function _check_valid_pytree (line 74) | def _check_valid_pytree(
class NoUpdate (line 103) | class NoUpdate: ...
class Repeated (line 111) | class Repeated: ...
class ArrayRefOutput (line 119) | class ArrayRefOutput(reprlib.Representable):
method __nnx_repr__ (line 122) | def __nnx_repr__(self):
method __treescope_repr__ (line 126) | def __treescope_repr__(self, path, subtree_renderer):
function is_node_leaf (line 149) | def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[LeafType]:
class IndexMap (line 153) | class IndexMap(dict[Index, tp.Any]):
method from_refmap (line 155) | def from_refmap(refmap: RefMap) -> IndexMap:
class RefMap (line 166) | class RefMap(tp.MutableMapping[tp.Any, int], reprlib.MappingReprMixin):
method __init__ (line 169) | def __init__(
method from_indexmap (line 181) | def from_indexmap(indexmap: IndexMap) -> RefMap:
method get (line 186) | def get(self, key: tp.Any, default: int | None = None) -> int | None: ...
method __getitem__ (line 189) | def __getitem__(self, key: tp.Any) -> int:
method __setitem__ (line 192) | def __setitem__(self, key: tp.Any, value: int):
method __delitem__ (line 195) | def __delitem__(self, key: tp.Any):
method __len__ (line 198) | def __len__(self) -> int:
method __contains__ (line 201) | def __contains__(self, key: tp.Any) -> bool:
method __iter__ (line 204) | def __iter__(self) -> tp.Iterator[tp.Any]:
method items (line 208) | def items(self) -> tp.ItemsView[tp.Any, int]:
class NodeImplBase (line 222) | class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
method node_dict (line 226) | def node_dict(self, node: Node) -> dict[Key, tp.Any]:
class GraphNodeImpl (line 235) | class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
class PytreeNodeImpl (line 244) | class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
function register_graph_node_type (line 259) | def register_graph_node_type(
function register_pytree_node_type (line 282) | def register_pytree_node_type(
function is_node (line 302) | def is_node(x: tp.Any) -> bool:
function is_graph_node (line 310) | def is_graph_node(x: tp.Any) -> bool:
function is_node_type (line 318) | def is_node_type(x: type[tp.Any]) -> bool:
function get_node_impl (line 322) | def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None:
function get_node_impl_for_type (line 338) | def get_node_impl_for_type(
function _type_aware_sort (line 351) | def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
class NodeRef (line 363) | class NodeRef(tp.Generic[Node], reprlib.Representable):
method __nnx_repr__ (line 366) | def __nnx_repr__(self):
method __treescope_repr__ (line 370) | def __treescope_repr__(self, path, subtree_renderer):
class VariableDef (line 387) | class VariableDef(reprlib.Representable, tp.Generic[Node]):
method with_no_outer_index (line 394) | def with_no_outer_index(self) -> VariableDef:
method with_same_outer_index (line 405) | def with_same_outer_index(self) -> VariableDef:
method with_matching_outer_index (line 416) | def with_matching_outer_index(self, other) -> VariableDef:
method __nnx_repr__ (line 427) | def __nnx_repr__(self):
method __treescope_repr__ (line 434) | def __treescope_repr__(self, path, subtree_renderer):
class ArrayRefDef (line 456) | class ArrayRefDef(reprlib.Representable):
method with_no_outer_index (line 460) | def with_no_outer_index(self):
method with_same_outer_index (line 466) | def with_same_outer_index(self):
method with_matching_outer_index (line 472) | def with_matching_outer_index(self, other):
method __nnx_repr__ (line 478) | def __nnx_repr__(self):
method __treescope_repr__ (line 483) | def __treescope_repr__(self, path, subtree_renderer):
class NodeDef (line 497) | class NodeDef(tp.Generic[Node], reprlib.Representable):
method with_no_outer_index (line 508) | def with_no_outer_index(self) -> NodeDef[Node]:
method with_same_outer_index (line 517) | def with_same_outer_index(self) -> NodeDef[Node]:
method with_matching_outer_index (line 526) | def with_matching_outer_index(self, other) -> NodeDef[Node]:
method __nnx_repr__ (line 535) | def __nnx_repr__(self):
method __treescope_repr__ (line 544) | def __treescope_repr__(self, path, subtree_renderer):
class TreeNodeDef (line 567) | class TreeNodeDef(tp.Generic[Node]):
method with_no_outer_index (line 572) | def with_no_outer_index(self) -> TreeNodeDef[Node]:
method with_same_outer_index (line 575) | def with_same_outer_index(self) -> TreeNodeDef[Node]:
method with_matching_outer_index (line 578) | def with_matching_outer_index(self, other) -> TreeNodeDef[Node]:
class NodeAttr (line 591) | class NodeAttr:
class LeafAttr (line 598) | class LeafAttr:
class GraphDef (line 613) | class GraphDef(tp.Generic[Node]):
method __hash__ (line 618) | def __hash__(self) -> int:
method with_no_outer_index (line 621) | def with_no_outer_index(self) -> GraphDef[Node]:
method with_matching_outer_index (line 631) | def with_matching_outer_index(self, other) -> GraphDef[Node]:
method with_same_outer_index (line 641) | def with_same_outer_index(self) -> GraphDef[Node]:
method apply (line 652) | def apply(
function _tree_flatten (line 678) | def _tree_flatten(
function flatten (line 750) | def flatten( # type: ignore[invalid-annotation]
function flatten (line 759) | def flatten( # type: ignore[invalid-annotation]
function flatten (line 772) | def flatten( # type: ignore[invalid-annotation]
function flatten (line 785) | def flatten( # type: ignore[invalid-annotation]
function flatten (line 797) | def flatten( # type: ignore[invalid-annotation]
class DataElem (line 862) | class DataElem:
class StaticElem (line 867) | class StaticElem:
function _graph_flatten (line 870) | def _graph_flatten(
function _get_sorted_leaves (line 1036) | def _get_sorted_leaves(
function _tree_unflatten (line 1054) | def _tree_unflatten(
function unflatten (line 1083) | def unflatten( # type: ignore[invalid-annotation]
function _graph_unflatten (line 1164) | def _graph_unflatten(
function graph_pop (line 1356) | def graph_pop(
function _graph_pop (line 1372) | def _graph_pop(
function _graph_update_dynamic (line 1428) | def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]):
class StaticCache (line 1500) | class StaticCache(tp.NamedTuple):
method create (line 1509) | def create(
class GraphContext (line 1529) | class GraphContext(threading.local):
class set_graph_mode (line 1544) | class set_graph_mode(BaseConfigContext):
class set_graph_updates (line 1549) | class set_graph_updates(BaseConfigContext):
function static_cache (line 1555) | def static_cache(static_cache: tp.MutableMapping[tp.Any, StaticCache]):
function _cached_partial (line 1571) | def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args, graph: bo...
class SplitContext (line 1681) | class SplitContext:
method split (line 1687) | def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ....
method split (line 1690) | def split( # type: ignore[invalid-annotation]
method split (line 1695) | def split(
method split (line 1704) | def split(
method flatten (line 1722) | def flatten( # type: ignore[invalid-annotation]
method flatten (line 1731) | def flatten( # type: ignore[invalid-annotation]
method flatten (line 1738) | def flatten( # type: ignore[invalid-annotation]
method flatten (line 1746) | def flatten( # type: ignore[invalid-annotation]
method flatten (line 1759) | def flatten( # type: ignore[invalid-annotation]
function split_context (line 1841) | def split_context(ctxtag: tp.Hashable | None = None):
class MergeContext (line 1858) | class MergeContext:
method merge (line 1863) | def merge( # type: ignore[invalid-annotation]
method unflatten (line 1887) | def unflatten( # type: ignore[invalid-annotation]
function merge_context (line 1981) | def merge_context() -> tp.Generator[MergeContext, None, None]: ... # ty...
function merge_context (line 1984) | def merge_context(
function merge_context (line 1988) | def merge_context(ctxtag: tp.Hashable | None = None, inner: bool | None ...
class UpdateContext (line 2007) | class UpdateContext:
method __hash__ (line 2019) | def __hash__(self):
method __eq__ (line 2022) | def __eq__(self, other):
method flatten_end (line 2025) | def flatten_end(self, ref_index: RefMap):
method unflatten_end (line 2037) | def unflatten_end(self, index_ref: IndexMap, inner_merge: bool):
class UpdateContextManager (line 2045) | class UpdateContextManager:
method __enter__ (line 2048) | def __enter__(self):
method __exit__ (line 2069) | def __exit__(self, *args):
method __call__ (line 2086) | def __call__(self, f: F) -> F:
function update_context (line 2095) | def update_context(tag: tp.Hashable):
function current_update_context (line 2198) | def current_update_context(tag: tp.Hashable) -> UpdateContext:
function _split_state (line 2210) | def _split_state(
function split (line 2224) | def split( # type: ignore[invalid-annotation]
function split (line 2228) | def split( # type: ignore[invalid-annotation]
function split (line 2232) | def split( # type: ignore[invalid-annotation]
function split (line 2244) | def split( # type: ignore[invalid-annotation]
function _to_nested_state (line 2325) | def _to_nested_state(
function _merge_to_flat_state (line 2340) | def _merge_to_flat_state(states: tp.Iterable[tp.Any]):
function merge (line 2355) | def merge( # type: ignore[invalid-annotation]
function update (line 2416) | def update(node, state: tp.Any, /, *states: tp.Any) -> None:
function state (line 2461) | def state(node, /, *, graph: bool | None = None) -> GraphState: ...
function state (line 2463) | def state(node, first: filterlib.Filter, /, *, graph: bool | None = None...
function state (line 2465) | def state(
function state (line 2473) | def state(
function map (line 2530) | def map(
function graphdef (line 2572) | def graphdef(
function pop (line 2601) | def pop(
function pop (line 2609) | def pop(
function pop (line 2618) | def pop(
function clone (line 2681) | def clone(node: Node, variables: bool = True, *, graph: bool | None = No...
function vars_as (line 2708) | def vars_as(
function pure (line 2759) | def pure(tree: A) -> A:
function call (line 2810) | def call(
function set_metadata (line 2900) | def set_metadata(
function iter_graph (line 2937) | def iter_graph(
function _iter_graph (line 2987) | def _iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
function _iter_tree (line 3009) | def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
function iter_children (line 3059) | def iter_children(
function recursive_map (line 3131) | def recursive_map(
function _recursive_map_graph (line 3181) | def _recursive_map_graph(
function _recursive_map_tree (line 3221) | def _recursive_map_tree(
function find_duplicates (line 3275) | def find_duplicates(node: tp.Any, /, *, only: filterlib.Filter = ...) ->...
function _node_paths (line 3330) | def _node_paths(
class Static (line 3359) | class Static(tp.Generic[A]):
class GenericPytree (line 3370) | class GenericPytree: ...
function is_pytree_node (line 3376) | def is_pytree_node(
function _key_path_to_key (line 3391) | def _key_path_to_key(key: tp.Any) -> Key:
function jax_to_nnx_path (line 3408) | def jax_to_nnx_path(jax_path: tuple, /):
class IndexesPytreeDef (line 3412) | class IndexesPytreeDef(tp.NamedTuple):
function _flatten_pytree (line 3417) | def _flatten_pytree(pytree: tp.Any):
function _unflatten_pytree (line 3432) | def _unflatten_pytree(
function _list_set_key (line 3448) | def _list_set_key(x: list[tp.Any], key: int, value: tp.Any):
function _mutable_mapping_set_key (line 3467) | def _mutable_mapping_set_key(
function _mutable_mapping_pop_key (line 3473) | def _mutable_mapping_pop_key(x: tp.MutableMapping[Key, tp.Any], key: Key):
FILE: flax/nnx/helpers.py
class Dict (line 36) | class Dict(reprlib.MappingReprMixin, Module, tp.MutableMapping[str, A]):
method __init__ (line 54) | def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...
method __init__ (line 57) | def __init__(
method __init__ (line 61) | def __init__(self, *args, **kwargs):
method __getitem__ (line 65) | def __getitem__(self, key) -> A:
method __setitem__ (line 71) | def __setitem__(self, key, value):
method __iter__ (line 74) | def __iter__(self) -> tp.Iterator[str]:
method __len__ (line 77) | def __len__(self) -> int:
method __hash__ (line 83) | def __hash__(self) -> int:
method __delitem__ (line 86) | def __delitem__(self, key: str) -> None:
method __getattr__ (line 93) | def __getattr__(self, key: str) -> A:
method __setattr__ (line 95) | def __setattr__(self, key: str, value: A) -> None:
class List (line 99) | class List(reprlib.SequenceReprMixin, Module, tp.MutableSequence[A]):
method __init__ (line 116) | def __init__(self, it: tp.Iterable[A] | None = None, /):
method _get_elem (line 126) | def _get_elem(self, key: int) -> A:
method _set_elem (line 129) | def _set_elem(self, key: int, value: A) -> None:
method _del_elem (line 132) | def _del_elem(self, key: int) -> None:
method __len__ (line 135) | def __len__(self) -> int:
method append (line 138) | def append(self, value: A) -> None:
method insert (line 142) | def insert(self, index: int, value: A) -> None:
method __iter__ (line 158) | def __iter__(self) -> tp.Iterator[A]:
method __getitem__ (line 163) | def __getitem__(self, index: int) -> A: ...
method __getitem__ (line 165) | def __getitem__(self, index: slice) -> tp.List[A]: ...
method __getitem__ (line 166) | def __getitem__(self, index: int | slice) -> A | tp.List[A]:
method __setitem__ (line 179) | def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -...
method _graph_node_set_key (line 198) | def _graph_node_set_key(self, key: str, value: tp.Any):
method __delitem__ (line 209) | def __delitem__(self, index: int | slice) -> None:
class Sequential (line 230) | class Sequential(Module):
method __init__ (line 253) | def __init__(self, *fns: tp.Callable[..., tp.Any]):
method __call__ (line 260) | def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) ->...
class ModuleDefApply (line 294) | class ModuleDefApply(tp.Protocol, tp.Generic[M]):
method __call__ (line 295) | def __call__(
class TrainState (line 300) | class TrainState(tp.Generic[M], struct.PyTreeNode):
method create (line 308) | def create(
method __getattr__ (line 328) | def __getattr__(self, key: str) -> tp.Any: ...
method apply (line 330) | def apply(
method apply_gradients (line 349) | def apply_gradients(self: TS, grads: State, **kwargs) -> TS:
function has_keyword_arg (line 361) | def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool:
FILE: flax/nnx/ids.py
class UUIDManager (line 19) | class UUIDManager:
method __init__ (line 31) | def __init__(self):
method __call__ (line 35) | def __call__(self):
class UUID (line 44) | class UUID:
method __init__ (line 47) | def __init__(self, rawid):
method __eq__ (line 50) | def __eq__(self, other):
method __hash__ (line 53) | def __hash__(self):
method __repr__ (line 56) | def __repr__(self):
method __deepcopy__ (line 59) | def __deepcopy__(self, memo):
method __copy__ (line 63) | def __copy__(self):
FILE: flax/nnx/module.py
class ModuleMeta (line 49) | class ModuleMeta(PytreeMeta):
class Module (line 55) | class Module(Pytree, metaclass=ModuleMeta):
method sow (line 86) | def sow(
method perturb (line 187) | def perturb(
method iter_modules (line 279) | def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
method iter_children (line 293) | def iter_children(self) -> tp.Iterator[tuple[Key, Module]]:
method set_attributes (line 308) | def set_attributes(
method train (line 369) | def train(self, **attributes):
method eval (line 405) | def eval(self, **attributes):
function view (line 440) | def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found...
function with_attributes (line 518) | def with_attributes(
function _parse_docstring_args (line 589) | def _parse_docstring_args(doc_str: str) -> dict[str, str]:
function view_info (line 620) | def view_info(node: Module, /, *, only: filterlib.Filter = ..., graph: b...
function first_from (line 703) | def first_from(*args: tp.Optional[A], error_msg: str) -> A:
function iter_modules (line 719) | def iter_modules(
function capture (line 774) | def capture(
function capture (line 782) | def capture(
function capture (line 789) | def capture(fn: tp.Callable[P, R] | type[variableslib.Variable], *var_ty...
function _collect_state_by_path (line 928) | def _collect_state_by_path(state):
function _navigate_to_path (line 946) | def _navigate_to_path(state, path):
function _extract_captures (line 955) | def _extract_captures(module, state, var_types):
function _add_capturing (line 969) | def _add_capturing(cls, variable_type):
function _remove_capturing (line 986) | def _remove_capturing(cls):
FILE: flax/nnx/nn/activations.py
class PReLU (line 81) | class PReLU(nnx.Module):
method __init__ (line 112) | def __init__(
method __call__ (line 128) | def __call__(self, inputs: Array) -> Array:
FILE: flax/nnx/nn/attention.py
function dot_product_attention_weights (line 52) | def dot_product_attention_weights(
function dot_product_attention (line 190) | def dot_product_attention(
class MultiHeadAttention (line 322) | class MultiHeadAttention(Module):
method __init__ (line 407) | def __init__(
method __call__ (line 577) | def __call__(
method init_cache (line 747) | def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
method set_view (line 781) | def set_view(
function make_attention_mask (line 828) | def make_attention_mask(
function make_causal_mask (line 860) | def make_causal_mask(
function combine_masks (line 888) | def combine_masks(
FILE: flax/nnx/nn/dtypes.py
function canonicalize_dtype (line 22) | def canonicalize_dtype(
function promote_dtype (line 54) | def promote_dtype(args: T, /, *, dtype=None, inexact=True) -> T:
FILE: flax/nnx/nn/initializers.py
function zeros_init (line 41) | def zeros_init() -> Initializer:
function ones_init (line 54) | def ones_init() -> Initializer:
FILE: flax/nnx/nn/linear.py
function canonicalize_padding (line 52) | def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
function _conv_dimension_numbers (line 76) | def _conv_dimension_numbers(input_shape):
function _normalize_axes (line 85) | def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
function _canonicalize_tuple (line 90) | def _canonicalize_tuple(x: tp.Sequence[int] | int) -> tuple[int, ...]:
class LinearGeneral (line 97) | class LinearGeneral(Module):
method __init__ (line 156) | def __init__(
method __call__ (line 253) | def __call__(self, inputs: Array, out_sharding = None) -> Array:
class Linear (line 313) | class Linear(Module):
method __init__ (line 357) | def __init__(
method __call__ (line 400) | def __call__(self, inputs: Array, out_sharding = None) -> Array:
class Einsum (line 435) | class Einsum(Module):
method __init__ (line 482) | def __init__(
method __call__ (line 527) | def __call__(
method _infer_broadcasted_bias_shape (line 579) | def _infer_broadcasted_bias_shape(
method _einsum_str_check (line 611) | def _einsum_str_check(self, einsum_str):
class Conv (line 624) | class Conv(Module):
method __init__ (line 715) | def __init__(
method __call__ (line 782) | def __call__(self, inputs: Array, out_sharding=None) -> Array:
class ConvTranspose (line 912) | class ConvTranspose(Module):
method __init__ (line 1018) | def __init__(
method __call__ (line 1079) | def __call__(self, inputs: Array) -> Array:
class Embed (line 1217) | class Embed(Module):
method __init__ (line 1271) | def __init__(
method __call__ (line 1294) | def __call__(self, inputs: Array, out_sharding=None) -> Array:
method attend (line 1321) | def attend(self, query: Array, out_sharding=None) -> Array:
FILE: flax/nnx/nn/lora.py
class LoRAParam (line 36) | class LoRAParam(variablelib.Param[A]):
class LoRA (line 40) | class LoRA(Module):
method __init__ (line 89) | def __init__(
method __call__ (line 123) | def __call__(self, x: jax.Array):
class LoRALinear (line 135) | class LoRALinear(Linear):
method __init__ (line 179) | def __init__(
method __call__ (line 212) | def __call__(self, x: jax.Array, out_sharding = None):
FILE: flax/nnx/nn/normalization.py
function _canonicalize_axes (line 35) | def _canonicalize_axes(rank: int, axes: Axes) -> tp.Tuple[int, ...]:
function _abs_sq (line 42) | def _abs_sq(x):
function _compute_stats (line 50) | def _compute_stats(
function _normalize (line 134) | def _normalize(
function _l2_normalize (line 186) | def _l2_normalize(x, axis=None, eps=1e-12):
class BatchNorm (line 201) | class BatchNorm(Module):
method __init__ (line 289) | def __init__(
method __call__ (line 343) | def __call__(
method set_view (line 414) | def set_view(
class LayerNorm (line 428) | class LayerNorm(Module):
method __init__ (line 491) | def __init__(
method __call__ (line 541) | def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
class RMSNorm (line 579) | class RMSNorm(Module):
method __init__ (line 636) | def __init__(
method __call__ (line 675) | def __call__(self, x, mask: tp.Optional[jax.Array] = None):
class GroupNorm (line 713) | class GroupNorm(Module):
method __init__ (line 793) | def __init__(
method __call__ (line 872) | def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
class WeightNorm (line 928) | class WeightNorm(nnx.Module):
method __init__ (line 979) | def __init__(
method _weightnorm_inplace (line 1014) | def _weightnorm_inplace(self, path, param):
method __call__ (line 1051) | def __call__(self, x: Array, *args, **kwargs) -> Array:
class InstanceNorm (line 1071) | class InstanceNorm(Module):
method __init__ (line 1149) | def __init__(
method __call__ (line 1196) | def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
class SpectralNorm (line 1242) | class SpectralNorm(Module):
method __init__ (line 1304) | def __init__(
method __call__ (line 1370) | def __call__(
method _spectral_normalize_inplace (line 1400) | def _spectral_normalize_inplace(self, path, orig_param, update_stats):
FILE: flax/nnx/nn/recurrent.py
class RNNCellBase (line 47) | class RNNCellBase(Module):
method initialize_carry (line 50) | def initialize_carry(
method __call__ (line 68) | def __call__(
method num_feature_axes (line 86) | def num_feature_axes(self) -> int:
function modified_orthogonal (line 90) | def modified_orthogonal(key: Array, shape: Shape, dtype: Dtype = jnp.flo...
class LSTMCell (line 95) | class LSTMCell(RNNCellBase):
method __init__ (line 114) | def __init__(
method __call__ (line 195) | def __call__(
method initialize_carry (line 218) | def initialize_carry(
method num_feature_axes (line 252) | def num_feature_axes(self) -> int:
class OptimizedLSTMCell (line 256) | class OptimizedLSTMCell(RNNCellBase):
method __init__ (line 303) | def __init__(
method __call__ (line 373) | def __call__(
method initialize_carry (line 406) | def initialize_carry(
method num_feature_axes (line 441) | def num_feature_axes(self) -> int:
class SimpleCell (line 445) | class SimpleCell(RNNCellBase):
method __init__ (line 467) | def __init__(
method __call__ (line 537) | def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]...
method initialize_carry (line 544) | def initialize_carry(
method num_feature_axes (line 578) | def num_feature_axes(self) -> int:
class GRUCell (line 582) | class GRUCell(RNNCellBase):
method __init__ (line 624) | def __init__(
method __call__ (line 694) | def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]...
method initialize_carry (line 727) | def initialize_carry(
method num_feature_axes (line 762) | def num_feature_axes(self) -> int:
class RNN (line 766) | class RNN(Module):
method __init__ (line 774) | def __init__(
method __call__ (line 808) | def __call__(
function _select_last_carry (line 920) | def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A:
function _expand_dims_like (line 929) | def _expand_dims_like(x, target):
function flip_sequences (line 934) | def flip_sequences(
function _concatenate (line 994) | def _concatenate(a: Array, b: Array) -> Array:
class RNNBase (line 999) | class RNNBase(Protocol):
method __call__ (line 1000) | def __call__(
class Bidirectional (line 1014) | class Bidirectional(Module):
method __init__ (line 1046) | def __init__(
method __call__ (line 1075) | def __call__(
FILE: flax/nnx/nn/stochastic.py
class Dropout (line 27) | class Dropout(Module):
method __init__ (line 71) | def __init__(
method __call__ (line 96) | def __call__(
method set_view (line 159) | def set_view(
FILE: flax/nnx/proxy_caller.py
function _identity (line 25) | def _identity(x):
class GetItem (line 29) | class GetItem:
class GetAttr (line 34) | class GetAttr:
class DelayedAccessor (line 39) | class DelayedAccessor:
method __call__ (line 42) | def __call__(self, x):
method __getattr__ (line 50) | def __getattr__(self, name):
method __getitem__ (line 53) | def __getitem__(self, key):
class _AccessorCall (line 59) | class _AccessorCall(tp.Protocol):
method __call__ (line 60) | def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> t...
class CallableProxy (line 63) | class CallableProxy:
method __init__ (line 64) | def __init__(
method __call__ (line 70) | def __call__(self, *args, **kwargs):
method __getattr__ (line 73) | def __getattr__(self, name) -> CallableProxy:
method __getitem__ (line 76) | def __getitem__(self, key) -> CallableProxy:
class ApplyCaller (line 80) | class ApplyCaller(tp.Protocol, tp.Generic[A]):
method __getattr__ (line 81) | def __getattr__(self, __name) -> ApplyCaller[A]:
method __getitem__ (line 84) | def __getitem__(self, __name) -> ApplyCaller[A]:
method __call__ (line 87) | def __call__(self, *args, **kwargs) -> tuple[tp.Any, A]:
FILE: flax/nnx/pytreelib.py
function data (line 60) | def data(value: A, /) -> A: ...
function data (line 62) | def data(
function data (line 73) | def data(value: tp.Any = MISSING, /, **kwargs) -> tp.Any:
function register_data_type (line 116) | def register_data_type(type_: T, /) -> T:
function is_data (line 156) | def is_data(value: tp.Any, /) -> bool:
function has_data (line 201) | def has_data(value: tp.Any, /) -> list[tp.Any]:
function static (line 221) | def static(value: A, /) -> A: ...
function static (line 223) | def static(
function static (line 234) | def static(value: tp.Any = MISSING, /, **kwargs) -> tp.Any:
function dataclass (line 272) | def dataclass(cls: type[A], /) -> type[A]: ...
function dataclass (line 274) | def dataclass(
function dataclass (line 287) | def dataclass(
function _collect_stats (line 312) | def _collect_stats(
class ObjectContext (line 350) | class ObjectContext(threading.local):
class PytreeState (line 358) | class PytreeState(reprlib.Representable):
method __init__ (line 361) | def __init__(self, initializing: bool = False, is_setup: bool = False):
method trace_state (line 367) | def trace_state(self) -> tracers.TraceState:
method initializing (line 371) | def initializing(self) -> bool:
method is_setup (line 375) | def is_setup(self) -> bool:
method __nnx_repr__ (line 378) | def __nnx_repr__(self):
method __treescope_repr__ (line 382) | def __treescope_repr__(self, path, subtree_renderer):
function _flatten_pytree_state (line 391) | def _flatten_pytree_state(state: PytreeState):
function _unflatten_pytree_state (line 395) | def _unflatten_pytree_state(static: tuple[bool, bool], _):
function check_pytree (line 407) | def check_pytree(pytree):
class PytreeMeta (line 416) | class PytreeMeta(ABCMeta):
method __call__ (line 419) | def __call__(cls, *args: Any, **kwargs: Any) -> Any:
method _pytree_meta_construct (line 422) | def _pytree_meta_construct(cls, self, *args, **kwargs):
function _graph_node_meta_call (line 427) | def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P:
class ArrayRepr (line 447) | class ArrayRepr(reprlib.Representable):
method from_array (line 452) | def from_array(array: jax.Array | np.ndarray) -> ArrayRepr:
method __nnx_repr__ (line 455) | def __nnx_repr__(self):
class VariableRepr (line 461) | class VariableRepr(reprlib.Representable):
method __nnx_repr__ (line 466) | def __nnx_repr__(self):
class MutableArrayRepr (line 472) | class MutableArrayRepr(reprlib.Representable):
method from_array (line 477) | def from_array(array: jax.Array | np.ndarray) -> MutableArrayRepr:
method __nnx_repr__ (line 480) | def __nnx_repr__(self):
function _to_shape_dtype (line 485) | def _to_shape_dtype(x):
class AttributeStatus (line 500) | class AttributeStatus(tp.NamedTuple):
class Pytree (line 505) | class Pytree(reprlib.Representable, metaclass=PytreeMeta):
method __init_subclass__ (line 513) | def __init_subclass__(
method _object__nodes (line 625) | def _object__nodes(self):
method _object__state (line 634) | def _object__state(self):
method __setattr__ (line 644) | def __setattr__(self, name: str, value: Any) -> None:
method _setattr (line 647) | def _setattr(self, name, value: tp.Any) -> None:
method _check_value (line 671) | def _check_value(self, key, value, new_status: AttributeStatus | None):
method _check_valid_context (line 774) | def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
method __deepcopy__ (line 778) | def __deepcopy__(self: P, memo=None) -> P:
method __nnx_repr__ (line 784) | def __nnx_repr__(self):
method __treescope_repr__ (line 834) | def __treescope_repr__(self, path, subtree_renderer):
method __getstate__ (line 889) | def __getstate__(self):
method __setstate__ (line 892) | def __setstate__(self, state):
method _pytree__flatten_with_paths (line 900) | def _pytree__flatten_with_paths(self):
method _pytree__flatten (line 933) | def _pytree__flatten(self):
method _pytree__unflatten (line 962) | def _pytree__unflatten(
method _graph_node_flatten (line 978) | def _graph_node_flatten(self):
method _graph_node_set_key (line 999) | def _graph_node_set_key(self, key, value: tp.Any):
method _graph_node_pop_key (line 1013) | def _graph_node_pop_key(self, key):
method __delattr__ (line 1020) | def __delattr__(self, name: str) -> None:
method _graph_node_create_empty (line 1030) | def _graph_node_create_empty(node_type: tp.Type[P]) -> P:
method _graph_node_clear (line 1034) | def _graph_node_clear(self):
method _graph_node_init (line 1037) | def _graph_node_init(self, attributes: tp.Iterable[tuple[str | int, tp...
method __call__ (line 1042) | def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ...
class Object (line 1045) | class Object(Pytree, pytree=False):
method __init_subclass__ (line 1048) | def __init_subclass__(cls, **kwargs):
function _maybe_int (line 1057) | def _maybe_int(x):
function _get_str (line 1063) | def _get_str(x):
FILE: flax/nnx/reprlib.py
function supports_color (line 26) | def supports_color() -> bool:
class Color (line 42) | class Color(tp.NamedTuple):
class ReprContext (line 92) | class ReprContext(threading.local):
function colorized (line 100) | def colorized(x, /):
class Object (line 133) | class Object:
method elem_sep (line 144) | def elem_sep(self):
class Attr (line 149) | class Attr:
class Representable (line 158) | class Representable:
method __nnx_repr__ (line 161) | def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]:
method __repr__ (line 164) | def __repr__(self) -> str:
method __str__ (line 172) | def __str__(self) -> str:
function get_repr (line 176) | def get_repr(obj: Representable) -> str:
class MappingReprMixin (line 234) | class MappingReprMixin(Representable):
method __nnx_repr__ (line 235) | def __nnx_repr__(self):
class PrettyMapping (line 242) | class PrettyMapping(Representable):
method __nnx_repr__ (line 245) | def __nnx_repr__(self):
class SequenceReprMixin (line 253) | class SequenceReprMixin(Representable):
method __nnx_repr__ (line 254) | def __nnx_repr__(self):
class PrettySequence (line 262) | class PrettySequence(Representable):
method __nnx_repr__ (line 265) | def __nnx_repr__(self):
FILE: flax/nnx/rnglib.py
class KeylessInitializer (line 45) | class KeylessInitializer(tp.Protocol):
method __call__ (line 46) | def __call__(
function _to_keyless (line 55) | def _to_keyless(
function _function_to_method (line 61) | def _function_to_method(random_f):
function _initializer_to_method (line 69) | def _initializer_to_method(
class RngState (line 85) | class RngState(Variable[jax.Array]):
class RngCount (line 89) | class RngCount(RngState): ...
class RngKey (line 92) | class RngKey(RngState): ...
class RngStream (line 98) | class RngStream(Pytree):
method __init__ (line 100) | def __init__(
method __call__ (line 120) | def __call__(self) -> jax.Array:
method split (line 126) | def split(self, k: int | tuple[int, ...]):
method fork (line 130) | def fork(self, *, split: int | tuple[int, ...] | None = None):
class Rngs (line 323) | class Rngs(Pytree):
method __init__ (line 372) | def __init__(
method _get_stream (line 401) | def _get_stream(self, name: str, error_type: type[Exception]) -> RngSt...
method __getitem__ (line 412) | def __getitem__(self, name: str):
method __getattr__ (line 415) | def __getattr__(self, name: str):
method __call__ (line 418) | def __call__(self):
method __iter__ (line 421) | def __iter__(self) -> tp.Iterator[str]:
method __len__ (line 426) | def __len__(self) -> int:
method __contains__ (line 431) | def __contains__(self, name: tp.Any) -> bool:
method items (line 434) | def items(self):
method split (line 439) | def split(self, k: tp.Mapping[filterlib.Filter, int | tuple[int, ...]]...
method fork (line 487) | def fork(
class SplitBackups (line 716) | class SplitBackups(struct.PyTreeNode, tp.Iterable[StreamBackup]):
method __iter__ (line 719) | def __iter__(self) -> tp.Iterator[StreamBackup]:
method __enter__ (line 722) | def __enter__(self):
method __exit__ (line 725) | def __exit__(self, *args):
function split_rngs (line 730) | def split_rngs(
function split_rngs (line 740) | def split_rngs(
function split_rngs (line 750) | def split_rngs(
function split_rngs (line 757) | def split_rngs(
function _graph_split_rngs (line 899) | def _graph_split_rngs(
function _tree_split_rngs (line 933) | def _tree_split_rngs(
function fork_rngs (line 968) | def fork_rngs(
function fork_rngs (line 978) | def fork_rngs(
function fork_rngs (line 985) | def fork_rngs(
function backup_keys (line 1073) | def backup_keys(node: tp.Any, /, *, graph: bool | None = None):
function _scalars_only (line 1080) | def _scalars_only(
function _match_shape (line 1093) | def _match_shape(
function reseed (line 1101) | def reseed(
function restore_rngs (line 1172) | def restore_rngs(backups: tp.Iterable[StreamBackup], /):
FILE: flax/nnx/spmd.py
function add_axis (line 35) | def add_axis(tree: A, index: int, transform_metadata: tp.Mapping) -> A:
function remove_axis (line 65) | def remove_axis(
function _get_partition_name_and_metadata (line 101) | def _get_partition_name_and_metadata(
function with_partitioning (line 118) | def with_partitioning(
function get_var_pspec (line 133) | def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
function get_partition_spec (line 149) | def get_partition_spec(tree: A) -> A:
function get_named_sharding (line 164) | def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A:
function get_abstract_model (line 174) | def get_abstract_model(init_fn, mesh, *, graph: bool | None = None):
function abstract_with_sharding (line 185) | def abstract_with_sharding(
FILE: flax/nnx/statelib.py
class NestedStateRepr (line 38) | class NestedStateRepr(reprlib.Representable):
method __init__ (line 39) | def __init__(self, state: State):
method __nnx_repr__ (line 42) | def __nnx_repr__(self):
method __treescope_repr__ (line 50) | def __treescope_repr__(self, path, subtree_renderer):
class FlatState (line 59) | class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.Representable):
method __init__ (line 65) | def __init__(self, items: tp.Iterable[tuple[PathParts, V]], /, *, sort...
method from_sorted_keys_values (line 76) | def from_sorted_keys_values(
method paths (line 85) | def paths(self) -> tp.Tuple[PathParts, ...]:
method leaves (line 89) | def leaves(self) -> list[V]:
method __nnx_repr__ (line 92) | def __nnx_repr__(self):
method __getitem__ (line 99) | def __getitem__(self, index: int) -> tuple[PathParts, V]: ...
method __getitem__ (line 101) | def __getitem__(self, index: slice) -> FlatState[V]: ...
method __getitem__ (line 102) | def __getitem__(
method __len__ (line 109) | def __len__(self) -> int:
method __iter__ (line 112) | def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]:
method to_nested_state (line 115) | def to_nested_state(self) -> State[Key, V]:
method split (line 119) | def split(self, first: filterlib.Filter, /) -> FlatState[V]: ...
method split (line 122) | def split(
method split (line 131) | def split(
method split (line 135) | def split( # type: ignore[misc]
method filter (line 155) | def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ...
method filter (line 158) | def filter(
method filter (line 166) | def filter(
method merge (line 185) | def merge(
function _flat_state_pytree_flatten (line 201) | def _flat_state_pytree_flatten(x: FlatState[V]):
function _flat_state_pytree_unflatten (line 205) | def _flat_state_pytree_unflatten(
class State (line 221) | class State(MutableMapping[K, V], reprlib.Representable):
method __init__ (line 225) | def __init__(
method raw_mapping (line 251) | def raw_mapping(self) -> dict[K, tp.Mapping[K, tp.Any] | V]:
method __contains__ (line 254) | def __contains__(self, key) -> bool:
method __getitem__ (line 257) | def __getitem__(self, key: K) -> State | V: # type: ignore
method __getattr__ (line 263) | def __getattr__(self, key: K) -> State | V: # type: ignore[misc]
method __setitem__ (line 268) | def __setitem__(self, key: K, value: State | V) -> None:
method __delitem__ (line 278) | def __delitem__(self, key: K) -> None:
method __iter__ (line 281) | def __iter__(self) -> tp.Iterator[K]:
method __len__ (line 284) | def __len__(self) -> int:
method __nnx_repr__ (line 287) | def __nnx_repr__(self):
method __treescope_repr__ (line 295) | def __treescope_repr__(self, path, subtree_renderer):
method map (line 308) | def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]:
method flat_state (line 316) | def flat_state(self) -> FlatState[V]:
method from_flat_path (line 325) | def from_flat_path(
method to_pure_dict (line 337) | def to_pure_dict(self,
method replace_by_pure_dict (line 347) | def replace_by_pure_dict(self,
method split (line 359) | def split(self, first: filterlib.Filter, /) -> State[K, V]: ...
method split (line 362) | def split(
method split (line 371) | def split(
method split (line 375) | def split( # type: ignore[misc]
method filter (line 386) | def filter(
method filter (line 393) | def filter(
method filter (line 401) | def filter(
method merge (line 415) | def merge(cls, state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V]):
method __or__ (line 423) | def __or__(self, other: State[K, V]) -> State[K, V]:
method __sub__ (line 428) | def __sub__(self, other: State[K, V]) -> State[K, V]:
method __init_subclass__ (line 436) | def __init_subclass__(cls) -> None:
function _state_flatten_with_keys (line 447) | def _state_flatten_with_keys(x: State):
function _state_unflatten (line 453) | def _state_unflatten(
function map_state (line 467) | def map_state(f: tp.Callable[[tuple, tp.Any], tp.Any], state: State) -> ...
function to_flat_state (line 483) | def to_flat_state(state: State) -> FlatState:
function from_flat_state (line 494) | def from_flat_state(
function to_pure_dict (line 511) | def to_pure_dict(
function restore_int_paths (line 529) | def restore_int_paths(pure_dict: dict[str, tp.Any]):
function replace_by_pure_dict (line 571) | def replace_by_pure_dict(
function split_state (line 606) | def split_state(state: State, first: filterlib.Filter, /) -> State: ...
function split_state (line 610) | def split_state(
function split_state (line 620) | def split_state(
function split_state (line 625) | def split_state( # type: ignore[misc]
function filter_state (line 674) | def filter_state(
function filter_state (line 682) | def filter_state(
function filter_state (line 691) | def filter_state(
function merge_state (line 739) | def merge_state(state: tp.Mapping, /, *states: tp.Mapping,
function diff (line 788) | def diff(state: State, other: State) -> State:
function _split_state (line 799) | def _split_state(
function create_path_filters (line 831) | def create_path_filters(state: State):
FILE: flax/nnx/summary.py
class NoneDumper (line 48) | class NoneDumper(yaml.SafeDumper):
class SizeBytes (line 56) | class SizeBytes(typing.SizeBytes):
method __repr__ (line 57) | def __repr__(self) -> str:
class ObjectInfo (line 61) | class ObjectInfo(tp.NamedTuple):
function _collect_stats (line 70) | def _collect_stats(
class ArrayRepr (line 121) | class ArrayRepr:
method from_array (line 126) | def from_array(cls, x: jax.Array | np.ndarray):
method __str__ (line 129) | def __str__(self):
class CallInfo (line 135) | class CallInfo:
class SimpleObjectRepr (line 145) | class SimpleObjectRepr:
method __init__ (line 146) | def __init__(self, obj: tp.Any):
method __str__ (line 149) | def __str__(self):
method __repr__ (line 152) | def __repr__(self):
function _to_dummy_array (line 156) | def _to_dummy_array(x):
function _pure_nnx_vjp (line 166) | def _pure_nnx_vjp(f, model, *args, **kwargs):
function filter_rng_streams (line 174) | def filter_rng_streams(row: CallInfo):
function
Copy disabled (too large)
Download .json
Condensed preview — 618 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (10,642K chars).
[
{
"path": ".git-blame-ignore-revs",
"chars": 54,
"preview": "# apply pyink\n40a6e074e5224d733f964be00e21e0a1cb98bd2e"
},
{
"path": ".github/ISSUE_TEMPLATE/bug_report.md",
"chars": 792,
"preview": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: bug\nassignees: ''\n\n---\n\nProvide as much"
},
{
"path": ".github/analytics/README.md",
"chars": 544,
"preview": "# Repo Analytics\n\nTo run the repo analytics follow the steps below:\n\n1. You must have a Github token, if you don't have "
},
{
"path": ".github/analytics/get_repo_metrics.py",
"chars": 13820,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": ".github/analytics/issue_activity_since_date.gql",
"chars": 1677,
"preview": "{\n # Queries all the issues in a repo. For each issue, we get some basic data such as\n # the number, state, labels, an"
},
{
"path": ".github/analytics/pr_data_query.gql",
"chars": 2016,
"preview": "query {\n # Queries all the Pull Requests in a repo. For each issue, we get some basic data such as\n # the number, stat"
},
{
"path": ".github/analytics/requirements.txt",
"chars": 34,
"preview": "pandas\nabsl-py\nrequests\nmatplotlib"
},
{
"path": ".github/pull_request_template.md",
"chars": 1149,
"preview": "# What does this PR do?\n\n<!--\n\nGreat, you are contributing to Flax!\n\nBut... please read the following carefully so we ca"
},
{
"path": ".github/workflows/flax_publish.yml",
"chars": 2419,
"preview": "# This workflows will upload a Python Package using Twine when a release is created\n# For more information see: https://"
},
{
"path": ".github/workflows/flax_test.yml",
"chars": 7187,
"preview": "# This workflow will install Python dependencies, run tests and lint with a variety of Python versions\n# For more inform"
},
{
"path": ".github/workflows/flaxlib_publish.yml",
"chars": 3224,
"preview": "name: Flaxlib - Build and upload to PyPI\n\n# for testing only:\non:\n push:\n branches: [main]\n paths: ['flaxlib/**']"
},
{
"path": ".github/workflows/jax_nightly.yml",
"chars": 1675,
"preview": "name: CI - with JAX nightly\n\nconcurrency:\n group: ${{ github.workflow }}-${{ github.ref }}\n cancel-in-progress: true\n\n"
},
{
"path": ".gitignore",
"chars": 362,
"preview": "*~\n\\#*\\#\n*.pyc\n.tfds\n.DS_Store\ndist/\nbuild/\n*.egg-info\n*.rej\n.pytype\n.vscode/*\n/.devcontainer\ndocs*/**/_autosummary\ndocs"
},
{
"path": ".pre-commit-config.yaml",
"chars": 1358,
"preview": "# Install the pre-commit hooks below with\n# 'pre-commit install'\n\n# Auto-update the version of the hooks with\n# 'pre-com"
},
{
"path": ".readthedocs.yml",
"chars": 12,
"preview": "# deprecated"
},
{
"path": "AUTHORS",
"chars": 293,
"preview": "# This is the list the Flax authors for copyright purposes.\n#\n# This does not necessarily list everyone who has contribu"
},
{
"path": "CHANGELOG.md",
"chars": 28333,
"preview": "Changelog\n----------\n\nvNext\n------\n(Add your change to a random empty line to avoid merge conflicts)\n-\n-\n-\n-\n- removed G"
},
{
"path": "LICENSE",
"chars": 11309,
"preview": " Version 2.0, January 2004\n http://www.apache.org/licenses/\n\n TERMS A"
},
{
"path": "README.md",
"chars": 7905,
"preview": "<div align=\"center\">\n<img src=\"https://raw.githubusercontent.com/google/flax/main/images/flax_logo_250px.png\" alt=\"logo\""
},
{
"path": "benchmarks/README.md",
"chars": 312,
"preview": "# Benchmarks\n\nThese are mini benchmarks to measure the performance of NNX operations.\n\nSample profile command:\n\n```shell"
},
{
"path": "benchmarks/nnx_graph_overhead.py",
"chars": 4154,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "benchmarks/nnx_mlpmixer_training.py",
"chars": 6752,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "benchmarks/nnx_simple_training.py",
"chars": 5374,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "benchmarks/nnx_state_traversal.py",
"chars": 3266,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "benchmarks/tracing/README.md",
"chars": 494,
"preview": "# Tracing and lowering benchmarks for Flax examples\n\nSee Flax\n[documentation](https://flax.readthedocs.io/en/latest/exam"
},
{
"path": "benchmarks/tracing/__init__.py",
"chars": 581,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/gemma.py",
"chars": 7682,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/imagenet.py",
"chars": 7327,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/lm1b.py",
"chars": 7480,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/mnist.py",
"chars": 3132,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/nlp_seq.py",
"chars": 6923,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/ogbg_molpcba.py",
"chars": 6955,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/ppo.py",
"chars": 3890,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/requirements.txt",
"chars": 60,
"preview": "absl-py\nflax\ngoogle-benchmark\njax\nml_collections\nnumpy\noptax"
},
{
"path": "benchmarks/tracing/run_all_benchmarks.sh",
"chars": 441,
"preview": "#!/bin/bash\nset -e\n\nexport XLA_FLAGS=--xla_force_host_platform_device_count=8\n\nTARGETS=(\n mnist\n vae\n sst2\n gemma\n "
},
{
"path": "benchmarks/tracing/seq2seq.py",
"chars": 3961,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/sst2.py",
"chars": 5145,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/tracing_benchmark.py",
"chars": 2778,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/vae.py",
"chars": 2913,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "benchmarks/tracing/wmt.py",
"chars": 8926,
"preview": "# Copyright 2025 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "contributing.md",
"chars": 110,
"preview": "# How to Contribute\n\nPlease see https://flax.readthedocs.io/en/latest/contributing.html for more information.\n"
},
{
"path": "docs/.gitignore",
"chars": 18,
"preview": "_formatted_howtos\n"
},
{
"path": "docs/.readthedocs.yaml",
"chars": 626,
"preview": "# .readthedocs.yml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html fo"
},
{
"path": "docs/Makefile",
"chars": 634,
"preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the "
},
{
"path": "docs/README.md",
"chars": 5436,
"preview": "# Deprecation\n\nThis folder contains the deprecated Flax Linen documentation. For the latest Flax NNX docs, check out the"
},
{
"path": "docs/_ext/codediff.py",
"chars": 7135,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "docs/_ext/codediff_test.py",
"chars": 4133,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "docs/_ext/flax_module.py",
"chars": 2288,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "docs/_static/css/flax_theme.css",
"chars": 309,
"preview": "@import url(\"theme.css\");\n\n.wy-nav-content {\n max-width: 1290px;\n}\n\n.rst-content table.docutils {\n width: 100%;\n}\n\n.rs"
},
{
"path": "docs/_templates/autosummary/flax_module.rst",
"chars": 674,
"preview": "{{ fullname | escape | underline }}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n :exclude-members:\n"
},
{
"path": "docs/api_reference/flax.core.frozen_dict.rst",
"chars": 321,
"preview": "\nflax.core.frozen_dict package\n=============================\n\n.. currentmodule:: flax.core.frozen_dict\n\n.. autoclass:: F"
},
{
"path": "docs/api_reference/flax.cursor.rst",
"chars": 1679,
"preview": "\nflax.cursor package\n=============================\n\nThe Cursor API allows for mutability of pytrees. This API provides a"
},
{
"path": "docs/api_reference/flax.errors.rst",
"chars": 158,
"preview": "\nflax.errors package\n===================\n\nFlax has the following classes of errors.\n\n.. automodule:: flax.errors\n :me"
},
{
"path": "docs/api_reference/flax.jax_utils.rst",
"chars": 365,
"preview": "\nflax.jax_utils package\n========================\n\n.. currentmodule:: flax.jax_utils\n\n.. automodule:: flax.jax_utils\n\n\n.."
},
{
"path": "docs/api_reference/flax.linen/activation_functions.rst",
"chars": 827,
"preview": "\nActivation functions\n------------------------\n\n.. automodule:: flax.linen.activation\n.. currentmodule:: flax.linen.acti"
},
{
"path": "docs/api_reference/flax.linen/decorators.rst",
"chars": 117,
"preview": "Decorators\n----------------------\n\n.. currentmodule:: flax.linen\n\n.. autofunction:: compact\n.. autofunction:: nowrap\n"
},
{
"path": "docs/api_reference/flax.linen/index.rst",
"chars": 350,
"preview": "\nflax.linen\n==========\n\nLinen is the Flax Module system. Read more about our design goals in the `Linen README <https://"
},
{
"path": "docs/api_reference/flax.linen/init_apply.rst",
"chars": 141,
"preview": "\nInit/Apply\n==============\n\n.. currentmodule:: flax.linen\n\n.. autofunction:: apply\n.. autofunction:: init\n.. autofunctio"
},
{
"path": "docs/api_reference/flax.linen/initializers.rst",
"chars": 756,
"preview": "Initializers\n------------------------\n\n.. automodule:: flax.linen.initializers\n.. currentmodule:: flax.linen.initializer"
},
{
"path": "docs/api_reference/flax.linen/inspection.rst",
"chars": 94,
"preview": "\nInspection\n----------------------\n\n.. currentmodule:: flax.linen\n\n.. autofunction:: tabulate\n"
},
{
"path": "docs/api_reference/flax.linen/layers.rst",
"chars": 2360,
"preview": "Layers\n======\n\n.. currentmodule:: flax.linen\n\nLinear Modules\n------------------------\n\n.. flax_module::\n :module: flax."
},
{
"path": "docs/api_reference/flax.linen/module.rst",
"chars": 539,
"preview": "Module\n------------------------\n\n.. automodule:: flax.linen\n.. currentmodule:: flax.linen\n\n.. autoclass:: Module\n :mem"
},
{
"path": "docs/api_reference/flax.linen/profiling.rst",
"chars": 176,
"preview": "Profiling\n----------------------\n\n.. currentmodule:: flax.linen\n\n.. autofunction:: enable_named_call\n.. autofunction:: d"
},
{
"path": "docs/api_reference/flax.linen/spmd.rst",
"chars": 587,
"preview": "\nSPMD\n----------------------\n\n.. automodule:: flax.linen.spmd\n.. currentmodule:: flax.linen\n\n.. autofunction:: Partition"
},
{
"path": "docs/api_reference/flax.linen/transformations.rst",
"chars": 412,
"preview": "Transformations\n----------------------\n\n.. automodule:: flax.linen.transforms\n.. currentmodule:: flax.linen\n\n.. autofunc"
},
{
"path": "docs/api_reference/flax.linen/variable.rst",
"chars": 116,
"preview": "\nVariable dictionary\n----------------------\n\n.. automodule:: flax.core.variables\n.. autoclass:: flax.linen.Variable\n"
},
{
"path": "docs/api_reference/flax.serialization.rst",
"chars": 480,
"preview": "\nflax.serialization package\n============================\n\n.. currentmodule:: flax.serialization\n\n.. automodule:: flax.se"
},
{
"path": "docs/api_reference/flax.struct.rst",
"chars": 161,
"preview": "\nflax.struct package\n=====================\n\n.. currentmodule:: flax.struct\n\n.. automodule:: flax.struct\n\n\n.. autofunctio"
},
{
"path": "docs/api_reference/flax.traceback_util.rst",
"chars": 275,
"preview": "flax.traceback_util package\n============================\n\n.. currentmodule:: flax.traceback_util\n\n.. automodule:: flax.t"
},
{
"path": "docs/api_reference/flax.training.rst",
"chars": 1212,
"preview": "\nflax.training package\n=====================\n\nCheckpoints\n------------------------\n\n.. currentmodule:: flax.training.che"
},
{
"path": "docs/api_reference/index.rst",
"chars": 265,
"preview": "API Reference\n=============\n\n.. toctree::\n :maxdepth: 4\n\n flax.config\n flax.core.frozen_dict\n flax.cursor\n fla"
},
{
"path": "docs/conf.py",
"chars": 6263,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "docs/conf_sphinx_patch.py",
"chars": 6599,
"preview": "# Copyright 2024 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "docs/developer_notes/index.rst",
"chars": 153,
"preview": "Developer notes\n===============\n\n.. toctree::\n :maxdepth: 1\n\n module_lifecycle\n lift\n FLIPs <https://github.com/"
},
{
"path": "docs/developer_notes/lift.md",
"chars": 18043,
"preview": "# Lifted transformations\n\n⚠️ Advanced topic ⚠️\n\nThis design note explains the underlying implementation of `flax.linen.t"
},
{
"path": "docs/developer_notes/module_lifecycle.rst",
"chars": 21996,
"preview": "The Flax Module lifecycle\n#########################\n\n.. testsetup::\n\n from typing import Any, Callable, Iterable\n impo"
},
{
"path": "docs/examples/community_examples.rst",
"chars": 4692,
"preview": "Community examples\n==================\n\nIn addition to the `curated list of official Flax examples on GitHub <https://git"
},
{
"path": "docs/examples/core_examples.rst",
"chars": 4413,
"preview": "Core examples\n=============\n\nCore examples are hosted on the GitHub Flax repository in the `examples <https://github.com"
},
{
"path": "docs/examples/google_research_examples.rst",
"chars": 22578,
"preview": "########################\nGoogle Research examples\n########################\n\nA collection of research by Google Research "
},
{
"path": "docs/examples/index.rst",
"chars": 148,
"preview": "Examples\n========\n\n.. toctree::\n :maxdepth: 2\n\n core_examples\n google_research_examples\n repositories_that_use_f"
},
{
"path": "docs/examples/repositories_that_use_flax.rst",
"chars": 2016,
"preview": "Repositories that use Flax\n==========================\n\nThe following code bases use Flax and provide training frameworks"
},
{
"path": "docs/faq.rst",
"chars": 4130,
"preview": "Frequently Asked Questions (FAQ)\n================================\n\nThis is a collection of answers to frequently asked q"
},
{
"path": "docs/flip/0000-template.md",
"chars": 648,
"preview": "- Start Date: (fill me in with today's date, YYYY-MM-DD)\n- FLIP PR: [#0000](https://github.com/google/flax/pull/0000)\n- "
},
{
"path": "docs/flip/1009-optimizer-api.md",
"chars": 17252,
"preview": "- Start Date: 2021-02-08\n- FLIP PR: [#1011](https://github.com/google/flax/pull/1011)\n- FLIP Issue: [#1009](https://gith"
},
{
"path": "docs/flip/1777-default-dtype.md",
"chars": 8185,
"preview": "# FLIP: Default dtypes\n\n\n- Start Date: 2022-01-11\n- FLIP PR: [#1776](https://github.com/google/flax/pull/1776)\n- FLIP Is"
},
{
"path": "docs/flip/2396-rnn.md",
"chars": 11758,
"preview": "# RNN Flip\n\n- Start Date: 2022-08-18\n- FLIP PR: [#2604](https://github.com/google/flax/pull/2604)\n- FLIP Issue: [#2396]("
},
{
"path": "docs/flip/2434-general-metadata.md",
"chars": 10424,
"preview": "# FLIP: Axis Metadata\n\n\n- Start Date: 2022-08-08\n- FLIP Issue: [#2434](https://github.com/google/flax/issues/2434)\n- FLI"
},
{
"path": "docs/flip/2974-kw-only-dataclasses.md",
"chars": 4089,
"preview": "# FLIP: kw_only dataclasses\nAuthors: Brennan Saeta, Ivy Zheng\n\n - Start Date: Mar 23, 2023\n - FLIP Issue: [TBD]\n - FLIP "
},
{
"path": "docs/flip/3099-rnnbase-refactor.md",
"chars": 4067,
"preview": "# Refactor RNNCellBase in FLIP\n\nAuthors: Cristian Garcia, Marcus Chiam, Jasmijn Bastings\n\n - Start Date: May 1, 2023\n - "
},
{
"path": "docs/flip/4105-jax-style-nnx-transforms.md",
"chars": 8505,
"preview": "# JAX-style NNX Transforms\n\n- Authors: Cristian Garcia, Anselm Levskaya\n- Date: Jun/2024\n- FLIP PR: #4107\n- Status: Impl"
},
{
"path": "docs/flip/README.md",
"chars": 1404,
"preview": "# FLIP: Flax Improvement Process\n\nMost changes can be discussed with simple issues/discussions and pull requests.\n\nSome "
},
{
"path": "docs/glossary.rst",
"chars": 7012,
"preview": "*********\nGlossary\n*********\n\nFor additional terms, refer to the `Jax glossary <https://jax.readthedocs.io/en/latest/glo"
},
{
"path": "docs/guides/converting_and_upgrading/convert_pytorch_to_flax.rst",
"chars": 10599,
"preview": "Convert PyTorch models to Flax\n==============================\n\n.. testsetup::\n\n import numpy as np\n import jax\n from "
},
{
"path": "docs/guides/converting_and_upgrading/haiku_migration_guide.rst",
"chars": 26304,
"preview": "\nMigrating from Haiku to Flax\n============================\n\nThis guide will walk through the process of migrating Haiku "
},
{
"path": "docs/guides/converting_and_upgrading/index.rst",
"chars": 255,
"preview": "Converting and upgrading\n========================\n\n.. toctree::\n :maxdepth: 1\n\n haiku_migration_guide\n convert_pyt"
},
{
"path": "docs/guides/converting_and_upgrading/linen_upgrade_guide.rst",
"chars": 17336,
"preview": "Upgrading my codebase to Linen\n==============================\n\nAs of Flax v0.4.0, ``flax.nn`` no longer exists, and is r"
},
{
"path": "docs/guides/converting_and_upgrading/optax_update_guide.rst",
"chars": 10667,
"preview": "Upgrading my codebase to Optax\n==============================\n\nWe have proposed to replace :py:mod:`flax.optim` with `Op"
},
{
"path": "docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst",
"chars": 10245,
"preview": "Migrate checkpointing to Orbax\n==============================\n\nThis guide shows how to convert Flax's checkpoint saving "
},
{
"path": "docs/guides/converting_and_upgrading/regular_dict_upgrade_guide.rst",
"chars": 4106,
"preview": "Migrate to regular dicts\n========================\n\nFlax will migrate from returning ``FrozenDicts`` to regular dicts whe"
},
{
"path": "docs/guides/converting_and_upgrading/rnncell_upgrade_guide.rst",
"chars": 6105,
"preview": "RNNCellBase Upgrade Guide\n=========================\n\nThe ``RNNCellBase`` API has undergone some key updates aimed at enh"
},
{
"path": "docs/guides/data_preprocessing/full_eval.rst",
"chars": 6986,
"preview": "Processing the entire Dataset\n=============================\n\nFor efficiency reasons, we form batches that contain multip"
},
{
"path": "docs/guides/data_preprocessing/index.rst",
"chars": 101,
"preview": "Data preprocessing\n=================\n\n.. toctree::\n :maxdepth: 1\n\n full_eval\n loading_datasets\n"
},
{
"path": "docs/guides/data_preprocessing/loading_datasets.ipynb",
"chars": 8227,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Loading datasets\\n\",\n \"\\n\",\n "
},
{
"path": "docs/guides/data_preprocessing/loading_datasets.md",
"chars": 5146,
"preview": "---\njupytext:\n formats: ipynb,md:myst\n text_representation:\n extension: .md\n format_name: myst\n format_versio"
},
{
"path": "docs/guides/flax_fundamentals/arguments.md",
"chars": 4087,
"preview": "# Dealing with Flax Module arguments\n\n## Introduction\n\nIn Flax Linen we can define `Module` arguments either as dataclas"
},
{
"path": "docs/guides/flax_fundamentals/flax_basics.ipynb",
"chars": 38145,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"[, submodules"
},
{
"path": "docs/guides/flax_fundamentals/state_params.rst",
"chars": 5886,
"preview": "Managing Parameters and State\n=============================\n\nWe will show you how to...\n\n* manage the variables from ini"
},
{
"path": "docs/guides/flax_sharp_bits.ipynb",
"chars": 7326,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# 🔪 Flax - The Sharp Bits 🔪\\n\",\n "
},
{
"path": "docs/guides/flax_sharp_bits.md",
"chars": 5894,
"preview": "---\njupytext:\n formats: ipynb,md:myst\n text_representation:\n extension: .md\n format_name: myst\n format_versio"
},
{
"path": "docs/guides/index.rst",
"chars": 274,
"preview": "Guides\n======\n\n.. toctree::\n :maxdepth: 2\n\n flax_fundamentals/index\n data_preprocessing/index\n training_techniqu"
},
{
"path": "docs/guides/model_inspection/extracting_intermediates.rst",
"chars": 12728,
"preview": "Extracting intermediate values\n==============================\n\nThis guide will show you how to extract intermediate valu"
},
{
"path": "docs/guides/model_inspection/index.rst",
"chars": 110,
"preview": "Model inspection\n================\n\n.. toctree::\n :maxdepth: 1\n\n model_surgery\n extracting_intermediates\n"
},
{
"path": "docs/guides/model_inspection/model_surgery.ipynb",
"chars": 7138,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"120e57f5\",\n \"metadata\": {},\n \"source\": [\n \"Model surgery\\"
},
{
"path": "docs/guides/model_inspection/model_surgery.md",
"chars": 4368,
"preview": "---\njupyter:\n jupytext:\n formats: md,ipynb\n main_language: python\n text_representation:\n extension: .md\n "
},
{
"path": "docs/guides/parallel_training/ensembling.rst",
"chars": 10423,
"preview": "Ensembling on multiple devices\n==============================\n\nWe show how to train an ensemble of CNNs on the MNIST dat"
},
{
"path": "docs/guides/parallel_training/flax_on_pjit.ipynb",
"chars": 57177,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Scale up Flax Modules on multiple"
},
{
"path": "docs/guides/parallel_training/flax_on_pjit.md",
"chars": 24364,
"preview": "---\njupytext:\n formats: ipynb,md:myst\n text_representation:\n extension: .md\n format_name: myst\n format_versio"
},
{
"path": "docs/guides/parallel_training/index.rst",
"chars": 97,
"preview": "Parallel training\n=================\n\n.. toctree::\n :maxdepth: 1\n\n ensembling\n flax_on_pjit\n"
},
{
"path": "docs/guides/quantization/fp8_basics.ipynb",
"chars": 17748,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"ca360491\",\n \"metadata\": {},\n \"source\": [\n \"# User Guide o"
},
{
"path": "docs/guides/quantization/fp8_basics.md",
"chars": 13018,
"preview": "---\njupytext:\n formats: ipynb,md:myst\n text_representation:\n extension: .md\n format_name: myst\n format_versio"
},
{
"path": "docs/guides/quantization/index.rst",
"chars": 71,
"preview": "Quantization\n============\n\n.. toctree::\n :maxdepth: 1\n\n fp8_basics\n"
},
{
"path": "docs/guides/training_techniques/batch_norm.rst",
"chars": 9089,
"preview": "Batch normalization\n===================\n\nIn this guide, you will learn how to apply `batch normalization <https://arxiv."
},
{
"path": "docs/guides/training_techniques/dropout.rst",
"chars": 10935,
"preview": "Dropout\n=======\n\nThis guide provides an overview of how to apply\n`dropout <https://jmlr.org/papers/volume15/srivastava14"
},
{
"path": "docs/guides/training_techniques/index.rst",
"chars": 152,
"preview": "Training techniques\n===================\n\n.. toctree::\n :maxdepth: 1\n\n batch_norm\n dropout\n lr_schedule\n transf"
},
{
"path": "docs/guides/training_techniques/lr_schedule.rst",
"chars": 8226,
"preview": "Learning rate scheduling\n=============================\n\nThe learning rate is considered one of the most important hyperp"
},
{
"path": "docs/guides/training_techniques/transfer_learning.ipynb",
"chars": 11290,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Transfer learning\"\n ]\n },\n {\n"
},
{
"path": "docs/guides/training_techniques/transfer_learning.md",
"chars": 7843,
"preview": "---\njupytext:\n formats: ipynb,md:myst\n text_representation:\n extension: .md\n format_name: myst\n format_versio"
},
{
"path": "docs/guides/training_techniques/use_checkpointing.ipynb",
"chars": 53415,
"preview": "{\n \"cells\": [\n {\n \"attachments\": {},\n \"cell_type\": \"markdown\",\n \"id\": \"6e9134fa\",\n \"metadata\": {},\n \"source\":"
},
{
"path": "docs/guides/training_techniques/use_checkpointing.md",
"chars": 25757,
"preview": "---\njupyter:\n jupytext:\n formats: ipynb,md\n main_language: python\n text_representation:\n extension: .md\n "
},
{
"path": "docs/index.rst",
"chars": 8849,
"preview": ".. Flax documentation main file, created by\n sphinx-quickstart on Mon Feb 17 11:41:38 2020.\n You can adapt this file"
},
{
"path": "docs/linen_intro.ipynb",
"chars": 39585,
"preview": "{\n \"cells\": [\n {\n \"attachments\": {},\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"[\n- "
},
{
"path": "docs_nnx/flip/1009-optimizer-api.md",
"chars": 17252,
"preview": "- Start Date: 2021-02-08\n- FLIP PR: [#1011](https://github.com/google/flax/pull/1011)\n- FLIP Issue: [#1009](https://gith"
},
{
"path": "docs_nnx/flip/1777-default-dtype.md",
"chars": 8185,
"preview": "# FLIP: Default dtypes\n\n\n- Start Date: 2022-01-11\n- FLIP PR: [#1776](https://github.com/google/flax/pull/1776)\n- FLIP Is"
},
{
"path": "docs_nnx/flip/2396-rnn.md",
"chars": 11751,
"preview": "# RNN Flip\n\n- Start Date: 2022-08-18\n- FLIP PR: [#2604](https://github.com/google/flax/pull/2604)\n- FLIP Issue: [#2396]("
},
{
"path": "docs_nnx/flip/2434-general-metadata.md",
"chars": 10424,
"preview": "# FLIP: Axis Metadata\n\n\n- Start Date: 2022-08-08\n- FLIP Issue: [#2434](https://github.com/google/flax/issues/2434)\n- FLI"
},
{
"path": "docs_nnx/flip/2974-kw-only-dataclasses.md",
"chars": 4089,
"preview": "# FLIP: kw_only dataclasses\nAuthors: Brennan Saeta, Ivy Zheng\n\n - Start Date: Mar 23, 2023\n - FLIP Issue: [TBD]\n - FLIP "
},
{
"path": "docs_nnx/flip/3099-rnnbase-refactor.md",
"chars": 4067,
"preview": "# Refactor RNNCellBase in FLIP\n\nAuthors: Cristian Garcia, Marcus Chiam, Jasmijn Bastings\n\n - Start Date: May 1, 2023\n - "
},
{
"path": "docs_nnx/flip/4105-jax-style-nnx-transforms.md",
"chars": 8505,
"preview": "# JAX-style NNX Transforms\n\n- Authors: Cristian Garcia, Anselm Levskaya\n- Date: Jun/2024\n- FLIP PR: #4107\n- Status: Impl"
},
{
"path": "docs_nnx/flip/4844-var-eager-sharding.md",
"chars": 2480,
"preview": "- Start Date: 2025-09-12\n- FLIP PR: [#4844](https://github.com/google/flax/pull/4844)\n\n# FLIP 4844: Variable eager shard"
},
{
"path": "docs_nnx/flip/5310-tree-mode-nnx.md",
"chars": 13171,
"preview": "# Tree Mode NNX\n\nMar 4, 2026\nCristian Garcia, Samuel Anklesaria, Flax Team\n\n## Motivation\n\nCurrent NNX APIs allow genera"
}
]
// ... and 418 more files (download for full content)
About this extraction
This page contains the full source code of the google/flax GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 618 files (9.9 MB), approximately 2.6M tokens, and a symbol index with 5058 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.